Skip to content

Commit 3057e51

Browse files
【PPSCI Export&Infer No.9】Bubble (#887)
* 【PPSCI Export&Infer No.9】 * update examples/bubble/conf/bubble.yaml * fix codestyle bugs * Update examples/bubble/bubble.py * update examples/bubble/bubble.py --------- Co-authored-by: HydrogenSulfate <[email protected]>
1 parent 0b67435 commit 3057e51

File tree

3 files changed

+140
-1
lines changed

3 files changed

+140
-1
lines changed

docs/zh/examples/bubble.md

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,22 @@
2020
python bubble.py mode=eval EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/bubble/bubble_pretrained.pdparams
2121
```
2222

23+
=== "模型导出命令"
24+
25+
``` sh
26+
python bubble.py mode=export
27+
```
28+
29+
=== "模型推理命令"
30+
31+
``` sh
32+
# linux
33+
wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/BubbleNet/bubble.mat
34+
# windows
35+
# curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/BubbleNet/bubble.mat --output bubble.mat
36+
python bubble.py mode=infer
37+
```
38+
2339
| 预训练模型 | 指标 |
2440
|:--| :--|
2541
| [bubble_pretrained.pdparams](https://paddle-org.bj.bcebos.com/paddlescience/models/bubble/bubble_pretrained.pdparams) | loss(bubble_mse): 0.00558<br>MSE.u(bubble_mse): 0.00090<br>MSE.v(bubble_mse): 0.00322<br>MSE.p(bubble_mse): 0.00066<br>MSE.phil(bubble_mse): 0.00079 |

examples/bubble/bubble.py

Lines changed: 105 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -403,14 +403,118 @@ def transform_out(in_, out):
403403
)
404404

405405

406+
def export(cfg: DictConfig):
407+
# set model
408+
model_psi = ppsci.arch.MLP(**cfg.MODEL.psi_net)
409+
model_p = ppsci.arch.MLP(**cfg.MODEL.p_net)
410+
model_phil = ppsci.arch.MLP(**cfg.MODEL.phil_net)
411+
412+
# transform
413+
def transform_out(in_, out):
414+
psi_y = out["psi"]
415+
y = in_["y"]
416+
x = in_["x"]
417+
u = jacobian(psi_y, y, create_graph=False)
418+
v = -jacobian(psi_y, x, create_graph=False)
419+
return {"u": u, "v": v}
420+
421+
# register transform
422+
model_psi.register_output_transform(transform_out)
423+
model_list = ppsci.arch.ModelList((model_psi, model_p, model_phil))
424+
425+
# initialize solver
426+
solver = ppsci.solver.Solver(
427+
model_list,
428+
pretrained_model_path=cfg.INFER.pretrained_model_path,
429+
)
430+
# export model
431+
from paddle.static import InputSpec
432+
433+
input_spec = [
434+
{
435+
key: InputSpec([None, 1], "float32", name=key)
436+
for key in model_list.input_keys
437+
},
438+
]
439+
solver.export(input_spec, cfg.INFER.export_path)
440+
441+
442+
def inference(cfg: DictConfig):
443+
# load Data
444+
data = scipy.io.loadmat(cfg.DATA_PATH)
445+
# normalize data
446+
p_max = data["p"].max(axis=0)
447+
p_min = data["p"].min(axis=0)
448+
u_max = data["u"].max(axis=0)
449+
u_min = data["u"].min(axis=0)
450+
v_max = data["v"].max(axis=0)
451+
v_min = data["v"].min(axis=0)
452+
453+
from deploy.python_infer import pinn_predictor
454+
455+
predictor = pinn_predictor.PINNPredictor(cfg)
456+
# set time-geometry
457+
timestamps = np.linspace(0, 126, 127, endpoint=True)
458+
geom = {
459+
"time_rect_visu": ppsci.geometry.TimeXGeometry(
460+
ppsci.geometry.TimeDomain(1, 126, timestamps=timestamps),
461+
ppsci.geometry.Rectangle((0, 0), (15, 5)),
462+
),
463+
}
464+
NTIME_ALL = len(timestamps)
465+
NPOINT_PDE, NTIME_PDE = 300 * 100, NTIME_ALL - 1
466+
input_dict = geom["time_rect_visu"].sample_interior(
467+
NPOINT_PDE * NTIME_PDE, evenly=True
468+
)
469+
output_dict = predictor.predict(input_dict, cfg.INFER.batch_size)
470+
471+
# mapping data to cfg.INFER.output_keys
472+
output_dict = {
473+
store_key: output_dict[infer_key]
474+
for store_key, infer_key in zip(cfg.MODEL.output_keys, output_dict.keys())
475+
}
476+
477+
# inverse normalization
478+
p_pred = output_dict["p"].reshape([NTIME_PDE, NPOINT_PDE]).T
479+
u_pred = output_dict["u"].reshape([NTIME_PDE, NPOINT_PDE]).T
480+
v_pred = output_dict["v"].reshape([NTIME_PDE, NPOINT_PDE]).T
481+
pred = {
482+
"p": (p_pred * (p_max - p_min) + p_min).T.reshape([-1, 1]),
483+
"u": (u_pred * (u_max - u_min) + u_min).T.reshape([-1, 1]),
484+
"v": (v_pred * (v_max - v_min) + v_min).T.reshape([-1, 1]),
485+
"phil": output_dict["phil"],
486+
}
487+
ppsci.visualize.save_vtu_from_dict(
488+
"./visual/bubble_pred.vtu",
489+
{
490+
"t": input_dict["t"],
491+
"x": input_dict["x"],
492+
"y": input_dict["y"],
493+
"u": pred["u"],
494+
"v": pred["v"],
495+
"p": pred["p"],
496+
"phil": pred["phil"],
497+
},
498+
("t", "x", "y"),
499+
("u", "v", "p", "phil"),
500+
NTIME_PDE,
501+
)
502+
503+
406504
@hydra.main(version_base=None, config_path="./conf", config_name="bubble.yaml")
407505
def main(cfg: DictConfig):
408506
if cfg.mode == "train":
409507
train(cfg)
410508
elif cfg.mode == "eval":
411509
evaluate(cfg)
510+
elif cfg.mode == "export":
511+
export(cfg)
512+
elif cfg.mode == "infer":
513+
inference(cfg)
412514
else:
413-
raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'")
515+
raise ValueError(
516+
f"cfg.mode should in ['train', 'eval', 'export', 'infer'], but got '{cfg.mode}'"
517+
)
414518

415519

416520
if __name__ == "__main__":

examples/bubble/conf/bubble.yaml

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ MODEL:
4646
num_layers: 9
4747
hidden_size: 30
4848
activation: "tanh"
49+
output_keys: ["u", "v", "p", "phil"]
4950

5051
# training settings
5152
TRAIN:
@@ -65,3 +66,21 @@ TRAIN:
6566
EVAL:
6667
pretrained_model_path: null
6768
eval_with_no_grad: true
69+
70+
# inference settings
71+
INFER:
72+
pretrained_model_path: https://paddle-org.bj.bcebos.com/paddlescience/models/bubble/bubble_pretrained.pdparams
73+
export_path: ./inference/bubble
74+
pdmodel_path: ${INFER.export_path}.pdmodel
75+
pdiparams_path: ${INFER.export_path}.pdiparams
76+
onnx_path: ${INFER.export_path}.onnx
77+
device: gpu
78+
engine: native
79+
precision: fp32
80+
ir_optim: true
81+
min_subgraph_size: 5
82+
gpu_mem: 2000
83+
gpu_id: 0
84+
max_batch_size: 8192
85+
num_cpu_threads: 10
86+
batch_size: 8192

0 commit comments

Comments
 (0)