Commit 58ad0b3

【PPSCI Export&Infer No.23】viv (#832)
* add export and inference for viv
* add doc
* fix viv export&infer
* Rewrite function
* fix viv export&infer

1 parent fb16798 · commit 58ad0b3

File tree: 4 files changed (+110, -2 lines)

docs/zh/examples/viv.md

Lines changed: 12 additions & 0 deletions

@@ -14,6 +14,18 @@
     python viv.py mode=eval EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/viv/viv_pretrained.pdparams
     ```
 
+=== "Model export command"
+
+    ``` sh
+    python viv.py mode=export
+    ```
+
+=== "Model inference command"
+
+    ``` sh
+    python viv.py mode=infer
+    ```
+
 | Pretrained model | Metrics |
 |:--| :--|
 | [viv_pretrained.pdparams](https://paddle-org.bj.bcebos.com/paddlescience/models/aneurysm/viv_pretrained.pdparams)<br>[viv_pretrained.pdeqn](https://paddle-org.bj.bcebos.com/paddlescience/models/aneurysm/viv_pretrained.pdeqn) | 'eta': 1.1416150300647132e-06<br>'f': 4.635014192899689e-06 |

examples/fsi/conf/viv.yaml

Lines changed: 22 additions & 0 deletions

@@ -11,6 +11,8 @@ hydra:
           - TRAIN.checkpoint_path
           - TRAIN.pretrained_model_path
           - EVAL.pretrained_model_path
+          - INFER.pretrained_model_path
+          - INFER.export_path
           - mode
           - output_dir
           - log_freq
@@ -60,3 +62,23 @@ TRAIN:
 EVAL:
   pretrained_model_path: null
   batch_size: 32
+
+# inference settings
+INFER:
+  pretrained_model_path: "https://paddle-org.bj.bcebos.com/paddlescience/models/viv/viv_pretrained.pdparams"
+  export_path: ./inference/viv
+  pdmodel_path: ${INFER.export_path}.pdmodel
+  pdpiparams_path: ${INFER.export_path}.pdiparams
+  input_keys: ${MODEL.input_keys}
+  output_keys: ["eta", "f"]
+  device: gpu
+  engine: native
+  precision: fp32
+  onnx_path: ${INFER.export_path}.onnx
+  ir_optim: true
+  min_subgraph_size: 10
+  gpu_mem: 4000
+  gpu_id: 0
+  max_batch_size: 64
+  num_cpu_threads: 4
+  batch_size: 16
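
The new INFER block is ordinary Hydra/OmegaConf configuration, so its interpolated values (for example `${INFER.export_path}.pdmodel`) can be checked without running the solver. A minimal sketch, assuming OmegaConf is installed and the script is run from the repository root; the printed values are illustrative only:

```python
# Minimal sketch (not part of the commit): load the config and resolve the
# INFER interpolations against the rest of the file.
from omegaconf import OmegaConf

cfg = OmegaConf.load("examples/fsi/conf/viv.yaml")
print(cfg.INFER.export_path)       # ./inference/viv
print(cfg.INFER.pdmodel_path)      # ./inference/viv.pdmodel, resolved from ${INFER.export_path}
print(list(cfg.INFER.input_keys))  # resolved from ${MODEL.input_keys}
```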

examples/fsi/viv.py

Lines changed: 75 additions & 1 deletion

@@ -200,14 +200,88 @@ def evaluate(cfg: DictConfig):
     solver.visualize()
 
 
+def export(cfg: DictConfig):
+    from paddle import nn
+    from paddle.static import InputSpec
+
+    # set model
+    model = ppsci.arch.MLP(**cfg.MODEL)
+    # initialize equation
+    equation = {"VIV": ppsci.equation.Vibration(2, -4, 0)}
+    # initialize solver
+    solver = ppsci.solver.Solver(
+        model,
+        equation=equation,
+        pretrained_model_path=cfg.INFER.pretrained_model_path,
+    )
+    # Convert equation to func
+    f_func = ppsci.lambdify(
+        solver.equation["VIV"].equations["f"],
+        solver.model,
+        list(solver.equation["VIV"].learnable_parameters),
+    )
+
+    class Wrapped_Model(nn.Layer):
+        def __init__(self, model, func):
+            super().__init__()
+            self.model = model
+            self.func = func
+
+        def forward(self, x):
+            model_out = self.model(x)
+            func_out = self.func(x)
+            return {**model_out, "f": func_out}
+
+    solver.model = Wrapped_Model(model, f_func)
+    # export models
+    input_spec = [
+        {key: InputSpec([None, 1], "float32", name=key) for key in model.input_keys},
+    ]
+    solver.export(input_spec, cfg.INFER.export_path, skip_prune_program=True)
+
+
+def inference(cfg: DictConfig):
+    from deploy.python_infer import pinn_predictor
+
+    # set model predictor
+    predictor = pinn_predictor.PINNPredictor(cfg)
+
+    infer_mat = ppsci.utils.reader.load_mat_file(
+        cfg.VIV_DATA_PATH,
+        ("t_f", "eta_gt", "f_gt"),
+        alias_dict={"eta_gt": "eta", "f_gt": "f"},
+    )
+
+    input_dict = {key: infer_mat[key] for key in cfg.INFER.input_keys}
+
+    output_dict = predictor.predict(input_dict, cfg.INFER.batch_size)
+
+    # mapping data to cfg.INFER.output_keys
+    output_dict = {
+        store_key: output_dict[infer_key]
+        for store_key, infer_key in zip(cfg.INFER.output_keys, output_dict.keys())
+    }
+    infer_mat.update(output_dict)
+
+    ppsci.visualize.plot.save_plot_from_1d_dict(
+        "./viv_pred", infer_mat, ("t_f",), ("eta", "eta_gt", "f", "f_gt")
+    )
+
+
 @hydra.main(version_base=None, config_path="./conf", config_name="viv.yaml")
 def main(cfg: DictConfig):
     if cfg.mode == "train":
         train(cfg)
     elif cfg.mode == "eval":
         evaluate(cfg)
+    elif cfg.mode == "export":
+        export(cfg)
+    elif cfg.mode == "infer":
+        inference(cfg)
     else:
-        raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'")
+        raise ValueError(
+            f"cfg.mode should in ['train', 'eval', 'export', 'infer'], but got '{cfg.mode}'"
+        )
 
 
 if __name__ == "__main__":
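
For context on what `solver.export` produces: the saved program can also be loaded directly with the plain Paddle Inference API, independently of `deploy.python_infer`. A hedged sketch, assuming `python viv.py mode=export` has already written `./inference/viv.pdmodel` and `./inference/viv.pdiparams`, and that the model's single input key is `t_f` as in the config; output tensor names are whatever the exporter assigned, so they are read back generically:

```python
# Sketch only (not part of the commit): run the exported VIV model with the
# Paddle Inference API. Paths assume `python viv.py mode=export` was run first.
import numpy as np
import paddle.inference as paddle_infer

config = paddle_infer.Config("./inference/viv.pdmodel", "./inference/viv.pdiparams")
predictor = paddle_infer.create_predictor(config)

# single input with shape [N, 1], matching InputSpec([None, 1], "float32")
input_name = predictor.get_input_names()[0]
input_handle = predictor.get_input_handle(input_name)
input_handle.copy_from_cpu(np.linspace(0.0, 1.0, 32, dtype="float32").reshape(-1, 1))

predictor.run()

# the wrapped model returns two outputs: the MLP prediction and the lambdified "f"
for name in predictor.get_output_names():
    print(name, predictor.get_output_handle(name).copy_to_cpu().shape)
```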

ppsci/utils/symbolic.py

Lines changed: 1 addition & 1 deletion

@@ -490,7 +490,7 @@ class ComposedNode(nn.Layer):
     def __init__(self, callable_nodes: List[Node]):
         super().__init__()
         assert len(callable_nodes)
-        self.callable_nodes = callable_nodes
+        self.callable_nodes = nn.LayerList(callable_nodes)
 
     def forward(self, data_dict: DATA_DICT) -> paddle.Tensor:
         # call all callable_nodes in order
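
This one-line change matters because Paddle only tracks sub-layers assigned through registering containers: wrapping the nodes in `nn.LayerList` makes their parameters visible to `parameters()`/`state_dict()`, which saving and export depend on. A small standalone illustration of that behavior (not from the commit):

```python
# Illustration (not from the commit): sub-layers held in a plain Python list
# are not registered, so their parameters are missing from parameters() and
# state_dict(); nn.LayerList registers them.
from paddle import nn


class PlainList(nn.Layer):
    def __init__(self):
        super().__init__()
        self.nodes = [nn.Linear(2, 2)]  # not registered as a sub-layer


class Registered(nn.Layer):
    def __init__(self):
        super().__init__()
        self.nodes = nn.LayerList([nn.Linear(2, 2)])  # registered as a sub-layer


print(len(PlainList().parameters()))   # 0
print(len(Registered().parameters()))  # 2 (weight and bias)
```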
