Skip to content

Commit 2a1e85a

Browse files
1want2sleepUnityLikerHydrogenSulfate
authored
【PPSCI Export&Infer No.25】bracket (#878)
* ppsci.equation.PDE.parameters/state_dict/set_state_dict api fix * ppsci.equation.PDE.parameters/state_dict/set_state_dict api fix * fix api docs in the timedomain * fix api docs of timedomain * fix api docs of timedomain * ppsci api docs fixed * ppsci api docs fixed * ppsci api docs fixed * add export and infer for bracket * updata bracket doc * solve conflict according to the branch named develop * Update examples/bracket/conf/bracket.yaml * Update examples/bracket/conf/bracket.yaml * Update examples/bracket/conf/bracket.yaml * add export&inference for bracket --------- Co-authored-by: krp <[email protected]> Co-authored-by: HydrogenSulfate <[email protected]>
1 parent a5f69e4 commit 2a1e85a

File tree

3 files changed

+99
-0
lines changed

3 files changed

+99
-0
lines changed

docs/zh/examples/bracket.md

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,24 @@
2626
python bracket.py mode=eval EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/bracket/bracket_pretrained.pdparams
2727
```
2828

29+
=== "模型导出命令"
30+
31+
``` sh
32+
python bracket.py mode=export
33+
```
34+
35+
=== "模型推理命令"
36+
37+
``` sh
38+
# linux
39+
wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/bracket/bracket_dataset.tar
40+
# windows
41+
# curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/bracket/bracket_dataset.tar --output bracket_dataset.tar
42+
# unzip it
43+
tar -xvf bracket_dataset.tar
44+
python bracket.py mode=infer
45+
```
46+
2947
| 预训练模型 | 指标 |
3048
|:--| :--|
3149
| [bracket_pretrained.pdparams](https://paddle-org.bj.bcebos.com/paddlescience/models/bracket/bracket_pretrained.pdparams) | loss(commercial_ref_u_v_w_sigmas): 32.28704<br>MSE.u(commercial_ref_u_v_w_sigmas): 0.00005<br>MSE.v(commercial_ref_u_v_w_sigmas): 0.00000<br>MSE.w(commercial_ref_u_v_w_sigmas): 0.00734<br>MSE.sigma_xx(commercial_ref_u_v_w_sigmas): 27.64751<br>MSE.sigma_yy(commercial_ref_u_v_w_sigmas): 1.23101<br>MSE.sigma_zz(commercial_ref_u_v_w_sigmas): 0.89106<br>MSE.sigma_xy(commercial_ref_u_v_w_sigmas): 0.84370<br>MSE.sigma_xz(commercial_ref_u_v_w_sigmas): 1.42126<br>MSE.sigma_yz(commercial_ref_u_v_w_sigmas): 0.24510 |

examples/bracket/bracket.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -514,12 +514,75 @@ def evaluate(cfg: DictConfig):
514514
solver.visualize()
515515

516516

517+
def export(cfg: DictConfig):
518+
# set model
519+
disp_net = ppsci.arch.MLP(**cfg.MODEL.disp_net)
520+
stress_net = ppsci.arch.MLP(**cfg.MODEL.stress_net)
521+
# wrap to a model_list
522+
model = ppsci.arch.ModelList((disp_net, stress_net))
523+
524+
# initialize solver
525+
solver = ppsci.solver.Solver(
526+
model,
527+
pretrained_model_path=cfg.INFER.pretrained_model_path,
528+
)
529+
530+
# export model
531+
from paddle.static import InputSpec
532+
533+
input_spec = [
534+
{key: InputSpec([None, 1], "float32", name=key) for key in model.input_keys},
535+
]
536+
solver.export(input_spec, cfg.INFER.export_path)
537+
538+
539+
def inference(cfg: DictConfig):
540+
from deploy.python_infer import pinn_predictor
541+
542+
predictor = pinn_predictor.PINNPredictor(cfg)
543+
ref_xyzu = ppsci.utils.reader.load_csv_file(
544+
cfg.DEFORMATION_X_PATH,
545+
("x", "y", "z", "u"),
546+
{
547+
"x": "X Location (m)",
548+
"y": "Y Location (m)",
549+
"z": "Z Location (m)",
550+
"u": "Directional Deformation (m)",
551+
},
552+
"\t",
553+
)
554+
input_dict = {
555+
"x": ref_xyzu["x"],
556+
"y": ref_xyzu["y"],
557+
"z": ref_xyzu["z"],
558+
}
559+
output_dict = predictor.predict(input_dict, cfg.INFER.batch_size)
560+
561+
# mapping data to cfg.INFER.output_keys
562+
output_keys = cfg.MODEL.disp_net.output_keys + cfg.MODEL.stress_net.output_keys
563+
output_dict = {
564+
store_key: output_dict[infer_key]
565+
for store_key, infer_key in zip(output_keys, output_dict.keys())
566+
}
567+
568+
ppsci.visualize.save_vtu_from_dict(
569+
"./bracket_pred",
570+
{**input_dict, **output_dict},
571+
input_dict.keys(),
572+
output_keys,
573+
)
574+
575+
517576
@hydra.main(version_base=None, config_path="./conf", config_name="bracket.yaml")
518577
def main(cfg: DictConfig):
519578
if cfg.mode == "train":
520579
train(cfg)
521580
elif cfg.mode == "eval":
522581
evaluate(cfg)
582+
elif cfg.mode == "export":
583+
export(cfg)
584+
elif cfg.mode == "infer":
585+
inference(cfg)
523586
else:
524587
raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'")
525588

examples/bracket/conf/bracket.yaml

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,3 +102,21 @@ EVAL:
102102
eval_with_no_grad: true
103103
batch_size:
104104
sup_validator: 128
105+
106+
# inference settings
107+
INFER:
108+
pretrained_model_path: "https://paddle-org.bj.bcebos.com/paddlescience/models/bracket/bracket_pretrained.pdparams"
109+
export_path: ./inference/bracket
110+
pdmodel_path: ${INFER.export_path}.pdmodel
111+
pdiparams_path: ${INFER.export_path}.pdiparams
112+
device: gpu
113+
engine: native
114+
precision: fp32
115+
onnx_path: ${INFER.export_path}.onnx
116+
ir_optim: true
117+
min_subgraph_size: 10
118+
gpu_mem: 4000
119+
gpu_id: 0
120+
max_batch_size: 128
121+
num_cpu_threads: 4
122+
batch_size: 128

0 commit comments

Comments
 (0)