Skip to content

Commit 1701459

Browse files
【PPSCI Export&Infer No.13】 darcy2d (#900)
* 【PPSCI Export&Infer No.13】 darcy2d * Update examples/darcy/darcy2d.py --------- Co-authored-by: HydrogenSulfate <[email protected]>
1 parent f17d1b3 commit 1701459

File tree

3 files changed

+83
-1
lines changed

3 files changed

+83
-1
lines changed

docs/zh/examples/darcy2d.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,18 @@
1414
python darcy2d.py mode=eval EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/darcy2d/darcy2d_pretrained.pdparams
1515
```
1616

17+
=== "模型导出命令"
18+
19+
``` sh
20+
python darcy2d.py mode=export
21+
```
22+
23+
=== "模型推理命令"
24+
25+
``` sh
26+
python darcy2d.py mode=infer
27+
```
28+
1729
| 预训练模型 | 指标 |
1830
|:--| :--|
1931
| [darcy2d_pretrained.pdparams](https://paddle-org.bj.bcebos.com/paddlescience/models/darcy2d/darcy2d_pretrained.pdparams) | loss(Residual): 0.36500<br>MSE.poisson(Residual): 0.00006 |

examples/darcy/conf/darcy2d.yaml

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ hydra:
2323
mode: train # running mode: train/eval
2424
seed: 42
2525
output_dir: ${hydra:run.dir}
26+
log_freq: 20
2627

2728
# set working condition
2829
NPOINT_PDE: 9801 # 99 ** 2
@@ -62,3 +63,20 @@ EVAL:
6263
batch_size:
6364
residual_validator: 8192
6465
pretrained_model_path: null
66+
67+
INFER:
68+
pretrained_model_path: https://paddle-org.bj.bcebos.com/paddlescience/models/darcy2d/darcy2d_pretrained.pdparams
69+
export_path: ./inference/darcy2d
70+
pdmodel_path: ${INFER.export_path}.pdmodel
71+
pdiparams_path: ${INFER.export_path}.pdiparams
72+
onnx_path: ${INFER.export_path}.onnx
73+
device: gpu
74+
engine: native
75+
precision: fp32
76+
ir_optim: true
77+
min_subgraph_size: 5
78+
gpu_mem: 2000
79+
gpu_id: 0
80+
max_batch_size: 8192
81+
num_cpu_threads: 10
82+
batch_size: 8192

examples/darcy/darcy2d.py

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -296,14 +296,66 @@ def poisson_ref_compute_func(_in):
296296
solver.visualize()
297297

298298

299+
def export(cfg: DictConfig):
300+
# set model
301+
model = ppsci.arch.MLP(**cfg.MODEL)
302+
303+
# initialize solver
304+
solver = ppsci.solver.Solver(
305+
model,
306+
pretrained_model_path=cfg.INFER.pretrained_model_path,
307+
)
308+
# export model
309+
from paddle.static import InputSpec
310+
311+
input_spec = [
312+
{key: InputSpec([None, 1], "float32", name=key) for key in model.input_keys},
313+
]
314+
315+
solver.export(input_spec, cfg.INFER.export_path)
316+
317+
318+
def inference(cfg: DictConfig):
319+
from deploy.python_infer import pinn_predictor
320+
321+
predictor = pinn_predictor.PINNPredictor(cfg)
322+
323+
# set geometry
324+
geom = {"rect": ppsci.geometry.Rectangle((0.0, 0.0), (1.0, 1.0))}
325+
# manually collate input data for visualization,
326+
input_dict = geom["rect"].sample_interior(
327+
cfg.NPOINT_PDE + cfg.NPOINT_BC, evenly=True
328+
)
329+
output_dict = predictor.predict(
330+
{key: input_dict[key] for key in cfg.MODEL.input_keys}, cfg.INFER.batch_size
331+
)
332+
# mapping data to cfg.INFER.output_keys
333+
output_dict = {
334+
store_key: output_dict[infer_key]
335+
for store_key, infer_key in zip(cfg.MODEL.output_keys, output_dict.keys())
336+
}
337+
ppsci.visualize.save_vtu_from_dict(
338+
"./visual/darcy2d.vtu",
339+
{**input_dict, **output_dict},
340+
input_dict.keys(),
341+
cfg.MODEL.output_keys,
342+
)
343+
344+
299345
@hydra.main(version_base=None, config_path="./conf", config_name="darcy2d.yaml")
300346
def main(cfg: DictConfig):
301347
if cfg.mode == "train":
302348
train(cfg)
303349
elif cfg.mode == "eval":
304350
evaluate(cfg)
351+
elif cfg.mode == "export":
352+
export(cfg)
353+
elif cfg.mode == "infer":
354+
inference(cfg)
305355
else:
306-
raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'")
356+
raise ValueError(
357+
f"cfg.mode should in ['train', 'eval', 'export', 'infer'], but got '{cfg.mode}'"
358+
)
307359

308360

309361
if __name__ == "__main__":

0 commit comments

Comments
 (0)