Skip to content

Commit 7168b9a

Browse files
[Fea] Add PirateNet and update allen_cahn document (#907)
* Add PiraNet and update allen_cahn document * fix example code for mlp.py * rename pira to pirate * update AIStudio link for allen cahn
1 parent 970adf9 commit 7168b9a

File tree

10 files changed

+711
-60
lines changed

10 files changed

+711
-60
lines changed

docs/zh/api/arch.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
- AMGNet
99
- MLP
1010
- ModifiedMLP
11+
- PirateNet
1112
- DeepONet
1213
- DeepPhyLSTM
1314
- LorenzEmbedding

docs/zh/examples/allen_cahn.md

Lines changed: 42 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,46 @@
11
# Allen-Cahn
22

3-
<!-- <a href="TODO" class="md-button md-button--primary" style>AI Studio快速体验</a> -->
3+
<a href="https://aistudio.baidu.com/projectdetail/7927786" class="md-button md-button--primary" style>AI Studio快速体验</a>
44

55
=== "模型训练命令"
66

77
``` sh
8-
python allen_cahn_default.py
8+
# linux
9+
wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/AllenCahn/allen_cahn.mat -P ./dataset/
10+
# windows
11+
# curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/AllenCahn/allen_cahn.mat --output ./dataset/antiderivative_unaligned_train.npz
12+
python allen_cahn_piratenet.py
913
```
1014

1115
=== "模型评估命令"
1216

1317
``` sh
14-
python allen_cahn_default.py mode=eval EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/allen_cahn/allen_cahn_default_pretrained.pdparams
18+
# linux
19+
wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/AllenCahn/allen_cahn.mat -P ./dataset/
20+
# windows
21+
# curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/AllenCahn/allen_cahn.mat --output ./dataset/antiderivative_unaligned_train.npz
22+
python allen_cahn_piratenet.py mode=eval EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/AllenCahn/allen_cahn_piratenet_pretrained.pdparams
1523
```
1624

1725
=== "模型导出命令"
1826

1927
``` sh
20-
python allen_cahn_default.py mode=export
28+
python allen_cahn_piratenet.py mode=export
2129
```
2230

2331
=== "模型推理命令"
2432

2533
``` sh
26-
python allen_cahn_default.py mode=infer
34+
# linux
35+
wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/AllenCahn/allen_cahn.mat -P ./dataset/
36+
# windows
37+
# curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/AllenCahn/allen_cahn.mat --output ./dataset/antiderivative_unaligned_train.npz
38+
python allen_cahn_piratenet.py mode=infer
2739
```
2840

2941
| 预训练模型 | 指标 |
3042
|:--| :--|
31-
| [allen_cahn_default_pretrained.pdparams](TODO) | TODO |
43+
| [allen_cahn_piratenet_pretrained.pdparams](https://paddle-org.bj.bcebos.com/paddlescience/models/AllenCahn/allen_cahn_piratenet_pretrained.pdparams) | L2Rel.u: 8.32403e-06 |
3244

3345
## 1. 背景简介
3446

@@ -72,27 +84,27 @@ $$
7284
### 3.1 模型构建
7385

7486
在 Allen-Cahn 问题中,每一个已知的坐标点 $(t, x)$ 都有对应的待求解的未知量 $(u)$,
75-
在这里使用比较简单的 MLP(Multilayer Perceptron, 多层感知机) 来表示 $(t, x)$ 到 $(u)$ 的映射函数 $f: \mathbb{R}^2 \to \mathbb{R}^1$ ,即:
87+
在这里使用 PirateNet 来表示 $(t, x)$ 到 $(u)$ 的映射函数 $f: \mathbb{R}^2 \to \mathbb{R}^1$ ,即:
7688

7789
$$
7890
u = f(t, x)
7991
$$
8092

81-
上式中 $f$ 即为 MLP 模型本身,用 PaddleScience 代码表示如下
93+
上式中 $f$ 即为 PirateNet 模型本身,用 PaddleScience 代码表示如下
8294

8395
``` py linenums="63"
8496
--8<--
85-
examples/allen_cahn/allen_cahn_default.py:63:64
97+
examples/allen_cahn/allen_cahn_piratenet.py:63:64
8698
--8<--
8799
```
88100

89101
为了在计算时,准确快速地访问具体变量的值,在这里指定网络模型的输入变量名是 `("t", "x")`,输出变量名是 `("u")`,这些命名与后续代码保持一致。
90102

91-
接着通过指定 MLP 的层数、神经元个数,就实例化出了一个拥有 4 层隐藏神经元,每层神经元数为 256 的神经网络模型 `model`使用 `tanh` 作为激活函数。
103+
接着通过指定 PirateNet 的层数、神经元个数,就实例化出了一个拥有 3 个 PiraBlock,每个 PiraBlock 的隐层神经元个数为 256 的神经网络模型 `model` 并且使用 `tanh` 作为激活函数。
92104

93-
``` yaml linenums="35"
105+
``` yaml linenums="34"
94106
--8<--
95-
examples/allen_cahn/conf/allen_cahn_default.yaml:35:41
107+
examples/allen_cahn/conf/allen_cahn_piratenet.yaml:34:40
96108
--8<--
97109
```
98110

@@ -102,7 +114,7 @@ Allen-Cahn 微分方程可以用如下代码表示:
102114

103115
``` py linenums="66"
104116
--8<--
105-
examples/allen_cahn/allen_cahn_default.py:66:67
117+
examples/allen_cahn/allen_cahn_piratenet.py:66:67
106118
--8<--
107119
```
108120

@@ -112,7 +124,7 @@ examples/allen_cahn/allen_cahn_default.py:66:67
112124

113125
``` py linenums="69"
114126
--8<--
115-
examples/allen_cahn/allen_cahn_default.py:69:81
127+
examples/allen_cahn/allen_cahn_piratenet.py:69:81
116128
--8<--
117129
```
118130

@@ -124,7 +136,7 @@ examples/allen_cahn/allen_cahn_default.py:69:81
124136

125137
``` py linenums="94"
126138
--8<--
127-
examples/allen_cahn/allen_cahn_default.py:94:110
139+
examples/allen_cahn/allen_cahn_piratenet.py:94:110
128140
--8<--
129141
```
130142

@@ -139,11 +151,11 @@ examples/allen_cahn/allen_cahn_default.py:94:110
139151
#### 3.4.2 周期边界约束
140152

141153
此处我们采用 hard-constraint 的方式,在神经网络模型中,对输入数据使用cos、sin等周期函数进行周期化,从而让$u_{\theta}$在数学上直接满足方程的周期性质。
142-
根据方程可得函数$u(t, x)$在$x$轴上的周期为2,因此将该周期设置到模型配置里即可。
154+
根据方程可得函数$u(t, x)$在$x$轴上的周期为 2,因此将该周期设置到模型配置里即可。
143155

144-
``` yaml linenums="35"
156+
``` yaml linenums="41"
145157
--8<--
146-
examples/allen_cahn/conf/allen_cahn_default.yaml:35:43
158+
examples/allen_cahn/conf/allen_cahn_piratenet.yaml:41:42
147159
--8<--
148160
```
149161

@@ -153,25 +165,25 @@ examples/allen_cahn/conf/allen_cahn_default.yaml:35:43
153165

154166
``` py linenums="112"
155167
--8<--
156-
examples/allen_cahn/allen_cahn_default.py:112:125
168+
examples/allen_cahn/allen_cahn_piratenet.py:112:125
157169
--8<--
158170
```
159171

160172
在微分方程约束、初值约束构建完毕之后,以刚才的命名为关键字,封装到一个字典中,方便后续访问。
161173

162174
``` py linenums="126"
163175
--8<--
164-
examples/allen_cahn/allen_cahn_default.py:126:130
176+
examples/allen_cahn/allen_cahn_piratenet.py:126:130
165177
--8<--
166178
```
167179

168180
### 3.5 超参数设定
169181

170-
接下来需要指定训练轮数和学习率,此处按实验经验,使用 200 轮训练轮数,0.001 的初始学习率。
182+
接下来需要指定训练轮数和学习率,此处按实验经验,使用 300 轮训练轮数,0.001 的初始学习率。
171183

172-
``` yaml linenums="51"
184+
``` yaml linenums="50"
173185
--8<--
174-
examples/allen_cahn/conf/allen_cahn_default.yaml:51:73
186+
examples/allen_cahn/conf/allen_cahn_piratenet.yaml:50:63
175187
--8<--
176188
```
177189

@@ -181,7 +193,7 @@ examples/allen_cahn/conf/allen_cahn_default.yaml:51:73
181193

182194
``` py linenums="132"
183195
--8<--
184-
examples/allen_cahn/allen_cahn_default.py:132:136
196+
examples/allen_cahn/allen_cahn_piratenet.py:132:136
185197
--8<--
186198
```
187199

@@ -191,7 +203,7 @@ examples/allen_cahn/allen_cahn_default.py:132:136
191203

192204
``` py linenums="138"
193205
--8<--
194-
examples/allen_cahn/allen_cahn_default.py:138:156
206+
examples/allen_cahn/allen_cahn_piratenet.py:138:156
195207
--8<--
196208
```
197209

@@ -201,15 +213,15 @@ examples/allen_cahn/allen_cahn_default.py:138:156
201213

202214
``` py linenums="158"
203215
--8<--
204-
examples/allen_cahn/allen_cahn_default.py:158:194
216+
examples/allen_cahn/allen_cahn_piratenet.py:158:184
205217
--8<--
206218
```
207219

208220
## 4. 完整代码
209221

210-
``` py linenums="1" title="allen_cahn_default.py"
222+
``` py linenums="1" title="allen_cahn_piratenet.py"
211223
--8<--
212-
examples/allen_cahn/allen_cahn_default.py
224+
examples/allen_cahn/allen_cahn_piratenet.py
213225
--8<--
214226
```
215227

@@ -218,12 +230,13 @@ examples/allen_cahn/allen_cahn_default.py
218230
在计算域上均匀采样出 $201\times501$ 个点,其预测结果和解析解如下图所示。
219231

220232
<figure markdown>
221-
![allen_cahn_default.jpg](https://paddle-org.bj.bcebos.com/paddlescience/docs/AllenCahn/allen_cahn_default.png){ loading=lazy }
233+
![allen_cahn_piratenet.jpg](https://paddle-org.bj.bcebos.com/paddlescience/docs/AllenCahn/allen_cahn_piratenet_ac.png){ loading=lazy }
222234
<figcaption> 左侧为 PaddleScience 预测结果,中间为解析解结果,右侧为两者的差值</figcaption>
223235
</figure>
224236

225237
可以看到对于函数$u(t, x)$,模型的预测结果和解析解的结果基本一致。
226238

227239
## 6. 参考资料
228240

241+
- [PIRATENETS: PHYSICS-INFORMED DEEP LEARNING WITHRESIDUAL ADAPTIVE NETWORKS](https://arxiv.org/pdf/2402.00326.pdf)
229242
- [Allen-Cahn equation](https://github.com/PredictiveIntelligenceLab/jaxpi/blob/main/examples/allen_cahn/README.md)

examples/allen_cahn/allen_cahn_causal.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -271,12 +271,12 @@ def inference(cfg: DictConfig):
271271

272272
input_dict = {"t": tx_star[:, 0:1], "x": tx_star[:, 1:2]}
273273
output_dict = predictor.predict(input_dict, cfg.INFER.batch_size)
274+
# mapping data to cfg.INFER.output_keys
274275
output_dict = {
275276
store_key: output_dict[infer_key]
276277
for store_key, infer_key in zip(cfg.MODEL.output_keys, output_dict.keys())
277278
}
278279
u_pred = output_dict["u"].reshape([len(t_star), len(x_star)])
279-
# mapping data to cfg.INFER.output_keys
280280

281281
plot(t_star, x_star, u_ref, u_pred, cfg.output_dir)
282282

examples/allen_cahn/allen_cahn_default.py

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -159,19 +159,9 @@ def gen_label_batch(input_batch):
159159
solver = ppsci.solver.Solver(
160160
model,
161161
constraint,
162-
cfg.output_dir,
163-
optimizer,
164-
epochs=cfg.TRAIN.epochs,
165-
iters_per_epoch=cfg.TRAIN.iters_per_epoch,
166-
save_freq=cfg.TRAIN.save_freq,
167-
log_freq=cfg.log_freq,
168-
eval_during_train=True,
169-
eval_freq=cfg.TRAIN.eval_freq,
162+
optimizer=optimizer,
170163
equation=equation,
171164
validator=validator,
172-
pretrained_model_path=cfg.TRAIN.pretrained_model_path,
173-
checkpoint_path=cfg.TRAIN.checkpoint_path,
174-
eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
175165
loss_aggregator=mtl.GradNorm(
176166
model,
177167
len(constraint),
@@ -226,11 +216,9 @@ def evaluate(cfg: DictConfig):
226216
# initialize solver
227217
solver = ppsci.solver.Solver(
228218
model,
229-
output_dir=cfg.output_dir,
230219
log_freq=cfg.log_freq,
231220
validator=validator,
232-
pretrained_model_path=cfg.EVAL.pretrained_model_path,
233-
eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
221+
cfg=cfg,
234222
)
235223

236224
# evaluate after finished training
@@ -250,10 +238,7 @@ def export(cfg: DictConfig):
250238
model = ppsci.arch.MLP(**cfg.MODEL)
251239

252240
# initialize solver
253-
solver = ppsci.solver.Solver(
254-
model,
255-
pretrained_model_path=cfg.INFER.pretrained_model_path,
256-
)
241+
solver = ppsci.solver.Solver(model, cfg=cfg)
257242
# export model
258243
from paddle.static import InputSpec
259244

@@ -275,12 +260,12 @@ def inference(cfg: DictConfig):
275260

276261
input_dict = {"t": tx_star[:, 0:1], "x": tx_star[:, 1:2]}
277262
output_dict = predictor.predict(input_dict, cfg.INFER.batch_size)
263+
# mapping data to cfg.INFER.output_keys
278264
output_dict = {
279265
store_key: output_dict[infer_key]
280266
for store_key, infer_key in zip(cfg.MODEL.output_keys, output_dict.keys())
281267
}
282268
u_pred = output_dict["u"].reshape([len(t_star), len(x_star)])
283-
# mapping data to cfg.INFER.output_keys
284269

285270
plot(t_star, x_star, u_ref, u_pred, cfg.output_dir)
286271

0 commit comments

Comments
 (0)