Skip to content

Commit 0717f06

Browse files
【PPSCI Export&Infer No.22】VP_NSFNet4 (#864)
* [SCI Export&Infer No.24] biharmonic2d * P[PSCI Export&Infer No.724] biharmonic2d fix * add export&infer nsfnet4 * add export&infer nsfnet4
1 parent c0fa8ca commit 0717f06

File tree

3 files changed

+205
-3
lines changed

3 files changed

+205
-3
lines changed

docs/zh/examples/nsfnet4.md

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,25 @@
2929

3030
```
3131

32+
=== "模型导出命令"
33+
34+
``` sh
35+
python VP_NSFNet4.py mode=export
36+
```
37+
38+
=== "模型推理命令"
39+
40+
``` sh
41+
# VP_NSFNet4
42+
# linux
43+
wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/NSFNet/NSF4_data.zip -P ./data/
44+
unzip ./data/NSF4_data.zip
45+
# windows
46+
# curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/NSFNet/NSF4_data.zip --output ./data/NSF4_data.zip
47+
# unzip ./data/NSF4_data.zip
48+
python VP_NSFNet4.py mode=infer
49+
```
50+
3251
## 1. 背景简介
3352

3453
最近几年, 深度学习在很多领域取得了非凡的成就, 尤其是计算机视觉和自然语言处理方面, 而受启发于深度学习的快速发展, 基于深度学习强大的函数逼近能力, 神经网络在科学计算领域也取得了成功, 现阶段的研究主要分为两大类, 一类是将物理信息以及物理限制加入损失函数来对神经网络进行训练, 其代表有 PINN 以及 Deep Ritz Net, 另一类是通过数据驱动的深度神经网络算子, 其代表有 FNO 以及 DeepONet。这些方法都在科学实践中获得了广泛应用, 比如天气预测, 量子化学, 生物工程, 以及计算流体等领域。而为充分探索PINN对流体方程的求解能力, 本次复现[论文](https://arxiv.org/abs/2003.06496)作者设计了NSFNets, 并且先后使用具有解析解或数值解的二维、三维纳韦斯托克方程以及使用DNS方法进行高精度求解的数据集作为参考, 进行正问题求解训练。论文实验表明PINN对不可压纳韦斯托克方程具有优秀的数值求解能力, 本项目主要目标是使用PaddleScience复现论文所实现的高精度求解纳韦斯托克方程的代码。

examples/nsfnet/VP_NSFNet4.py

Lines changed: 167 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,7 @@ def evaluate(cfg: DictConfig):
389389
t_plot = paddle.to_tensor((t[-1]) * np.ones(x_plot.shape), paddle.float32)
390390
sol = model({"x": x_plot, "y": y_plot, "z": z_plot, "t": t_plot})
391391
fig, ax = plt.subplots(1, 4, figsize=(16, 4))
392-
cmap = plt.cm.get_cmap("jet")
392+
cmap = matplotlib.colormaps.get_cmap("jet")
393393

394394
ax[0].contourf(grid_x, grid_y, sol["u"].reshape(grid_x.shape), levels=50, cmap=cmap)
395395
ax[0].set_title("u prediction")
@@ -422,7 +422,167 @@ def evaluate(cfg: DictConfig):
422422
t_plot = paddle.to_tensor((t[-1]) * np.ones(x_plot.shape), paddle.float32)
423423
sol = model({"x": x_plot, "y": y_plot, "z": z_plot, "t": t_plot})
424424
fig, ax = plt.subplots(1, 4, figsize=(16, 4))
425-
cmap = plt.cm.get_cmap("jet")
425+
cmap = matplotlib.colormaps.get_cmap("jet")
426+
427+
ax[0].contourf(grid_y, grid_z, sol["u"].reshape(grid_x.shape), levels=50, cmap=cmap)
428+
ax[0].set_title("u prediction")
429+
ax[1].contourf(grid_y, grid_z, sol["v"].reshape(grid_x.shape), levels=50, cmap=cmap)
430+
ax[1].set_title("v prediction")
431+
ax[2].contourf(grid_y, grid_z, sol["w"].reshape(grid_x.shape), levels=50, cmap=cmap)
432+
ax[2].set_title("w prediction")
433+
ax[3].contourf(grid_y, grid_z, sol["p"].reshape(grid_x.shape), levels=50, cmap=cmap)
434+
ax[3].set_title("p prediction")
435+
norm = matplotlib.colors.Normalize(
436+
vmin=sol["u"].min(), vmax=sol["u"].max()
437+
) # set maximum and minimum
438+
im = plt.cm.ScalarMappable(norm=norm, cmap=cmap)
439+
ax13 = fig.add_axes([0.125, 0.0, 0.175, 0.02])
440+
plt.colorbar(im, cax=ax13, orientation="horizontal")
441+
ax13 = fig.add_axes([0.325, 0.0, 0.175, 0.02])
442+
plt.colorbar(im, cax=ax13, orientation="horizontal")
443+
ax13 = fig.add_axes([0.525, 0.0, 0.175, 0.02])
444+
plt.colorbar(im, cax=ax13, orientation="horizontal")
445+
ax13 = fig.add_axes([0.725, 0.0, 0.175, 0.02])
446+
plt.colorbar(im, cax=ax13, orientation="horizontal")
447+
plt.savefig(osp.join(cfg.output_dir, "x=0 plane"))
448+
449+
450+
def export(cfg: DictConfig):
451+
from paddle.static import InputSpec
452+
453+
# set models
454+
model = ppsci.arch.MLP(**cfg.MODEL)
455+
456+
# load pretrained model
457+
solver = ppsci.solver.Solver(
458+
model=model, pretrained_model_path=cfg.INFER.pretrained_model_path
459+
)
460+
461+
# export models
462+
input_spec = [
463+
{key: InputSpec([None, 1], "float32", name=key) for key in model.input_keys},
464+
]
465+
solver.export(input_spec, cfg.INFER.export_path)
466+
467+
468+
def inference(cfg: DictConfig):
469+
from deploy.python_infer import pinn_predictor
470+
471+
# set model predictor
472+
predictor = pinn_predictor.PINNPredictor(cfg)
473+
474+
# infer Data
475+
test_x = np.load(osp.join(cfg.data_dir, "test43_l.npy")).astype(np.float32)
476+
test_v = np.load(osp.join(cfg.data_dir, "test43_vp.npy")).astype(np.float32)
477+
t = np.array([0.0065, 4 * 0.0065, 7 * 0.0065, 10 * 0.0065, 13 * 0.0065]).astype(
478+
np.float32
479+
)
480+
t_star = np.tile(t.reshape(5, 1), (1, 3000)).reshape(-1, 1)
481+
x_star = np.tile(test_x[:, 0:1], (5, 1)).reshape(-1, 1)
482+
y_star = np.tile(test_x[:, 1:2], (5, 1)).reshape(-1, 1)
483+
z_star = np.tile(test_x[:, 2:3], (5, 1)).reshape(-1, 1)
484+
u_star = test_v[:, 0:1]
485+
v_star = test_v[:, 1:2]
486+
w_star = test_v[:, 2:3]
487+
p_star = test_v[:, 3:4]
488+
489+
pred = predictor.predict(
490+
{
491+
"x": x_star,
492+
"y": y_star,
493+
"z": z_star,
494+
"t": t_star,
495+
},
496+
cfg.INFER.batch_size,
497+
)
498+
499+
pred = {
500+
store_key: pred[infer_key]
501+
for store_key, infer_key in zip(cfg.INFER.output_keys, pred.keys())
502+
}
503+
504+
u_pred = pred["u"].reshape((5, -1))
505+
v_pred = pred["v"].reshape((5, -1))
506+
w_pred = pred["w"].reshape((5, -1))
507+
p_pred = pred["p"].reshape((5, -1))
508+
u_star = u_star.reshape((5, -1))
509+
v_star = v_star.reshape((5, -1))
510+
w_star = w_star.reshape((5, -1))
511+
p_star = p_star.reshape((5, -1))
512+
513+
# NS equation can figure out pressure drop, need background pressure p_star.mean()
514+
p_pred = p_pred - p_pred.mean() + p_star.mean()
515+
516+
u_error = np.linalg.norm(u_pred - u_star, axis=1) / np.linalg.norm(u_star, axis=1)
517+
v_error = np.linalg.norm(v_pred - v_star, axis=1) / np.linalg.norm(v_star, axis=1)
518+
w_error = np.linalg.norm(w_pred - w_star, axis=1) / np.linalg.norm(w_star, axis=1)
519+
p_error = np.linalg.norm(p_pred - p_star, axis=1) / np.linalg.norm(w_star, axis=1)
520+
t = np.array([0.0065, 4 * 0.0065, 7 * 0.0065, 10 * 0.0065, 13 * 0.0065])
521+
plt.plot(t, np.array(u_error))
522+
plt.plot(t, np.array(v_error))
523+
plt.plot(t, np.array(w_error))
524+
plt.plot(t, np.array(p_error))
525+
plt.legend(["u_error", "v_error", "w_error", "p_error"])
526+
plt.xlabel("t")
527+
plt.ylabel("Relative l2 Error")
528+
plt.title("Relative l2 Error, on test dataset")
529+
plt.savefig(osp.join(cfg.output_dir, "error.jpg"))
530+
531+
grid_x, grid_y = np.mgrid[
532+
x_star.min() : x_star.max() : 100j, y_star.min() : y_star.max() : 100j
533+
].astype(np.float32)
534+
x_plot = grid_x.reshape(-1, 1)
535+
y_plot = grid_y.reshape(-1, 1)
536+
z_plot = (z_star.min() * np.ones(y_plot.shape)).astype(np.float32)
537+
t_plot = ((t[-1]) * np.ones(x_plot.shape)).astype(np.float32)
538+
sol = predictor.predict(
539+
{"x": x_plot, "y": y_plot, "z": z_plot, "t": t_plot}, cfg.INFER.batch_size
540+
)
541+
sol = {
542+
store_key: sol[infer_key]
543+
for store_key, infer_key in zip(cfg.INFER.output_keys, sol.keys())
544+
}
545+
fig, ax = plt.subplots(1, 4, figsize=(16, 4))
546+
cmap = matplotlib.colormaps.get_cmap("jet")
547+
548+
ax[0].contourf(grid_x, grid_y, sol["u"].reshape(grid_x.shape), levels=50, cmap=cmap)
549+
ax[0].set_title("u prediction")
550+
ax[1].contourf(grid_x, grid_y, sol["v"].reshape(grid_x.shape), levels=50, cmap=cmap)
551+
ax[1].set_title("v prediction")
552+
ax[2].contourf(grid_x, grid_y, sol["w"].reshape(grid_x.shape), levels=50, cmap=cmap)
553+
ax[2].set_title("w prediction")
554+
ax[3].contourf(grid_x, grid_y, sol["p"].reshape(grid_x.shape), levels=50, cmap=cmap)
555+
ax[3].set_title("p prediction")
556+
norm = matplotlib.colors.Normalize(
557+
vmin=sol["u"].min(), vmax=sol["u"].max()
558+
) # set maximum and minimum
559+
im = plt.cm.ScalarMappable(norm=norm, cmap=cmap)
560+
ax13 = fig.add_axes([0.125, 0.0, 0.175, 0.02])
561+
plt.colorbar(im, cax=ax13, orientation="horizontal")
562+
ax13 = fig.add_axes([0.325, 0.0, 0.175, 0.02])
563+
plt.colorbar(im, cax=ax13, orientation="horizontal")
564+
ax13 = fig.add_axes([0.525, 0.0, 0.175, 0.02])
565+
plt.colorbar(im, cax=ax13, orientation="horizontal")
566+
ax13 = fig.add_axes([0.725, 0.0, 0.175, 0.02])
567+
plt.colorbar(im, cax=ax13, orientation="horizontal")
568+
plt.savefig(osp.join(cfg.output_dir, "z=0 plane"))
569+
570+
grid_y, grid_z = np.mgrid[
571+
y_star.min() : y_star.max() : 100j, z_star.min() : z_star.max() : 100j
572+
].astype(np.float32)
573+
z_plot = grid_z.reshape(-1, 1)
574+
y_plot = grid_y.reshape(-1, 1)
575+
x_plot = (x_star.min() * np.ones(y_plot.shape)).astype(np.float32)
576+
t_plot = ((t[-1]) * np.ones(x_plot.shape)).astype(np.float32)
577+
sol = predictor.predict(
578+
{"x": x_plot, "y": y_plot, "z": z_plot, "t": t_plot}, cfg.INFER.batch_size
579+
)
580+
sol = {
581+
store_key: sol[infer_key]
582+
for store_key, infer_key in zip(cfg.INFER.output_keys, sol.keys())
583+
}
584+
fig, ax = plt.subplots(1, 4, figsize=(16, 4))
585+
cmap = matplotlib.colormaps.get_cmap("jet")
426586

427587
ax[0].contourf(grid_y, grid_z, sol["u"].reshape(grid_x.shape), levels=50, cmap=cmap)
428588
ax[0].set_title("u prediction")
@@ -453,9 +613,13 @@ def main(cfg: DictConfig):
453613
train(cfg)
454614
elif cfg.mode == "eval":
455615
evaluate(cfg)
616+
elif cfg.mode == "export":
617+
export(cfg)
618+
elif cfg.mode == "infer":
619+
inference(cfg)
456620
else:
457621
raise ValueError(
458-
osp.join("cfg.mode should in ['train', 'eval'], but got", cfg.mode)
622+
f"cfg.mode should in ['train', 'eval', 'export', 'infer'], but got '{cfg.mode}'"
459623
)
460624

461625

examples/nsfnet/conf/VP_NSFNet4.yaml

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ hydra:
2323
seed: 1234
2424
output_dir: ${hydra:run.dir}
2525
data_dir: ./data/
26+
log_freq: 20
2627
MODEL:
2728
input_keys: ["x", "y","z","t"]
2829
output_keys: ["u", "v", "w","p"]
@@ -52,3 +53,21 @@ EVAL:
5253
pretrained_model_path: null
5354
eval_with_no_grad: true
5455

56+
57+
INFER:
58+
pretrained_model_path: https://paddle-org.bj.bcebos.com/paddlescience/models/nsfnet/nsfnet4.pdparams
59+
export_path: ./inference/VP_NSFNet4
60+
pdmodel_path: ${INFER.export_path}.pdmodel
61+
pdpiparams_path: ${INFER.export_path}.pdiparams
62+
output_keys: ['p', 'u', 'v', 'w']
63+
device: gpu
64+
engine: native
65+
precision: fp32
66+
onnx_path: ${INFER.export_path}.onnx
67+
ir_optim: true
68+
min_subgraph_size: 10
69+
gpu_mem: 4000
70+
gpu_id: 0
71+
max_batch_size: 64
72+
num_cpu_threads: 4
73+
batch_size: 16

0 commit comments

Comments
 (0)