Skip to content

Commit 3f26f48

Browse files
authored
Add export & inference for hPINNs (#902)
* feat: add export and infer functions for hpinns * fix:register transform brfore export
1 parent 1701459 commit 3f26f48

File tree

3 files changed

+106
-1
lines changed

3 files changed

+106
-1
lines changed

docs/zh/examples/hpinns.md

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,26 @@
2626
python holography.py mode=eval EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/hPINNs/hpinns_pretrained.pdparams
2727
```
2828

29+
=== "模型导出命令"
30+
31+
``` sh
32+
python holography.py mode=export
33+
```
34+
35+
=== "模型推理命令"
36+
37+
``` sh
38+
# linux
39+
wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/hPINNs/hpinns_holo_train.mat -P ./datasets/
40+
wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/hPINNs/hpinns_holo_valid.mat -P ./datasets/
41+
# windows
42+
# curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/hPINNs/hpinns_holo_train.mat --output ./datasets/hpinns_holo_train.mat
43+
# curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/hPINNs/hpinns_holo_valid.mat --output ./datasets/hpinns_holo_valid.mat
44+
python holography.py mode=infer
45+
```
46+
47+
48+
2949
| 预训练模型 | 指标 |
3050
|:--| :--|
3151
| [hpinns_pretrained.pdparams](https://paddle-org.bj.bcebos.com/paddlescience/models/hPINNs/hpinns_pretrained.pdparams) | loss(opt_sup): 0.05352<br>MSE.eval_metric(opt_sup): 0.00002<br>loss(val_sup): 0.02205<br>MSE.eval_metric(val_sup): 0.00001 |

examples/hpinns/conf/hpinns.yaml

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ seed: 42
2525
output_dir: ${hydra:run.dir}
2626
DATASET_PATH: ./datasets/hpinns_holo_train.mat
2727
DATASET_PATH_VALID: ./datasets/hpinns_holo_valid.mat
28+
log_freq: 20
2829

2930
# set working condition
3031
TRAIN_MODE: aug_lag # "soft", "penalty", "aug_lag"
@@ -65,3 +66,22 @@ TRAIN:
6566
# evaluation settings
6667
EVAL:
6768
pretrained_model_path: null
69+
70+
# inference settings
71+
INFER:
72+
pretrained_model_path: "https://paddle-org.bj.bcebos.com/paddlescience/models/hPINNs/hpinns_pretrained.pdparams"
73+
export_path: ./inference/hpinns
74+
pdmodel_path: ${INFER.export_path}.pdmodel
75+
pdiparams_path: ${INFER.export_path}.pdiparams
76+
output_keys: ["e_re", "e_im", "eps"]
77+
device: gpu
78+
engine: native
79+
precision: fp32
80+
onnx_path: ${INFER.export_path}.onnx
81+
ir_optim: true
82+
min_subgraph_size: 10
83+
gpu_mem: 8000
84+
gpu_id: 0
85+
batch_size: 128
86+
max_batch_size: 128
87+
num_cpu_threads: 4

examples/hpinns/holography.py

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -409,14 +409,79 @@ def evaluate(cfg: DictConfig):
409409
solver.eval()
410410

411411

412+
def export(cfg: DictConfig):
413+
# set model
414+
model_re = ppsci.arch.MLP(**cfg.MODEL.re_net)
415+
model_im = ppsci.arch.MLP(**cfg.MODEL.im_net)
416+
model_eps = ppsci.arch.MLP(**cfg.MODEL.eps_net)
417+
418+
# register transform
419+
model_re.register_input_transform(func_module.transform_in)
420+
model_im.register_input_transform(func_module.transform_in)
421+
model_eps.register_input_transform(func_module.transform_in)
422+
423+
model_re.register_output_transform(func_module.transform_out_real_part)
424+
model_im.register_output_transform(func_module.transform_out_imaginary_part)
425+
model_eps.register_output_transform(func_module.transform_out_epsilon)
426+
427+
# wrap to a model_list
428+
model_list = ppsci.arch.ModelList((model_re, model_im, model_eps))
429+
430+
# initialize solver
431+
solver = ppsci.solver.Solver(
432+
model_list,
433+
pretrained_model_path=cfg.INFER.pretrained_model_path,
434+
)
435+
436+
# export model
437+
from paddle.static import InputSpec
438+
439+
input_spec = [
440+
{key: InputSpec([None, 1], "float32", name=key) for key in ["x", "y"]},
441+
]
442+
solver.export(input_spec, cfg.INFER.export_path)
443+
444+
445+
def inference(cfg: DictConfig):
446+
from deploy.python_infer import pinn_predictor
447+
448+
predictor = pinn_predictor.PINNPredictor(cfg)
449+
450+
valid_dict = ppsci.utils.reader.load_mat_file(
451+
cfg.DATASET_PATH_VALID, ("x_val", "y_val", "bound")
452+
)
453+
input_dict = {"x": valid_dict["x_val"], "y": valid_dict["y_val"]}
454+
455+
output_dict = predictor.predict(input_dict, cfg.INFER.batch_size)
456+
457+
# mapping data to cfg.INFER.output_keys
458+
output_dict = {
459+
store_key: output_dict[infer_key]
460+
for store_key, infer_key in zip(cfg.INFER.output_keys, output_dict.keys())
461+
}
462+
463+
ppsci.visualize.save_vtu_from_dict(
464+
"./hpinns_pred.vtu",
465+
{**input_dict, **output_dict},
466+
input_dict.keys(),
467+
cfg.INFER.output_keys,
468+
)
469+
470+
412471
@hydra.main(version_base=None, config_path="./conf", config_name="hpinns.yaml")
413472
def main(cfg: DictConfig):
414473
if cfg.mode == "train":
415474
train(cfg)
416475
elif cfg.mode == "eval":
417476
evaluate(cfg)
477+
elif cfg.mode == "export":
478+
export(cfg)
479+
elif cfg.mode == "infer":
480+
inference(cfg)
418481
else:
419-
raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'")
482+
raise ValueError(
483+
f"cfg.mode should in ['train', 'eval', 'export', 'infer'], but got '{cfg.mode}'"
484+
)
420485

421486

422487
if __name__ == "__main__":

0 commit comments

Comments
 (0)