1
1
# Allen-Cahn
2
2
3
- <!-- < a href="TODO " class="md-button md-button--primary" style>AI Studio快速体验</a> -- >
3
+ <a href =" https://aistudio.baidu.com/projectdetail/7927786 " class =" md-button md-button--primary " style >AI Studio快速体验</a >
4
4
5
5
=== "模型训练命令"
6
6
7
7
``` sh
8
- python allen_cahn_default.py
8
+ # linux
9
+ wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/AllenCahn/allen_cahn.mat -P ./dataset/
10
+ # windows
11
+ # curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/AllenCahn/allen_cahn.mat --output ./dataset/antiderivative_unaligned_train.npz
12
+ python allen_cahn_piratenet.py
9
13
```
10
14
11
15
=== "模型评估命令"
12
16
13
17
``` sh
14
- python allen_cahn_default.py mode=eval EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/allen_cahn/allen_cahn_default_pretrained.pdparams
18
+ # linux
19
+ wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/AllenCahn/allen_cahn.mat -P ./dataset/
20
+ # windows
21
+ # curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/AllenCahn/allen_cahn.mat --output ./dataset/antiderivative_unaligned_train.npz
22
+ python allen_cahn_piratenet.py mode=eval EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/AllenCahn/allen_cahn_piratenet_pretrained.pdparams
15
23
```
16
24
17
25
=== "模型导出命令"
18
26
19
27
``` sh
20
- python allen_cahn_default .py mode=export
28
+ python allen_cahn_piratenet .py mode=export
21
29
```
22
30
23
31
=== "模型推理命令"
24
32
25
33
``` sh
26
- python allen_cahn_default.py mode=infer
34
+ # linux
35
+ wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/AllenCahn/allen_cahn.mat -P ./dataset/
36
+ # windows
37
+ # curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/AllenCahn/allen_cahn.mat --output ./dataset/antiderivative_unaligned_train.npz
38
+ python allen_cahn_piratenet.py mode=infer
27
39
```
28
40
29
41
| 预训练模型 | 指标 |
30
42
| :--| :--|
31
- | [ allen_cahn_default_pretrained .pdparams] ( TODO ) | TODO |
43
+ | [ allen_cahn_piratenet_pretrained .pdparams] ( https://paddle-org.bj.bcebos.com/paddlescience/models/AllenCahn/allen_cahn_piratenet_pretrained.pdparams ) | L2Rel.u: 8.32403e-06 |
32
44
33
45
## 1. 背景简介
34
46
72
84
### 3.1 模型构建
73
85
74
86
在 Allen-Cahn 问题中,每一个已知的坐标点 $(t, x)$ 都有对应的待求解的未知量 $(u)$,
75
- ,在这里使用比较简单的 MLP(Multilayer Perceptron, 多层感知机) 来表示 $(t, x)$ 到 $(u)$ 的映射函数 $f: \mathbb{R}^2 \to \mathbb{R}^1$ ,即:
87
+ ,在这里使用 PirateNet 来表示 $(t, x)$ 到 $(u)$ 的映射函数 $f: \mathbb{R}^2 \to \mathbb{R}^1$ ,即:
76
88
77
89
$$
78
90
u = f(t, x)
79
91
$$
80
92
81
- 上式中 $f$ 即为 MLP 模型本身,用 PaddleScience 代码表示如下
93
+ 上式中 $f$ 即为 PirateNet 模型本身,用 PaddleScience 代码表示如下
82
94
83
95
``` py linenums="63"
84
96
-- 8 < --
85
- examples/ allen_cahn/ allen_cahn_default .py:63 :64
97
+ examples/ allen_cahn/ allen_cahn_piratenet .py:63 :64
86
98
-- 8 < --
87
99
```
88
100
89
101
为了在计算时,准确快速地访问具体变量的值,在这里指定网络模型的输入变量名是 ` ("t", "x") ` ,输出变量名是 ` ("u") ` ,这些命名与后续代码保持一致。
90
102
91
- 接着通过指定 MLP 的层数、神经元个数,就实例化出了一个拥有 4 层隐藏神经元,每层神经元数为 256 的神经网络模型 ` model ` ,使用 ` tanh ` 作为激活函数。
103
+ 接着通过指定 PirateNet 的层数、神经元个数,就实例化出了一个拥有 3 个 PiraBlock,每个 PiraBlock 的隐层神经元个数为 256 的神经网络模型 ` model ` , 并且使用 ` tanh ` 作为激活函数。
92
104
93
- ``` yaml linenums="35 "
105
+ ``` yaml linenums="34 "
94
106
--8<--
95
- examples/allen_cahn/conf/allen_cahn_default .yaml:35:41
107
+ examples/allen_cahn/conf/allen_cahn_piratenet .yaml:34:40
96
108
--8<--
97
109
```
98
110
@@ -102,7 +114,7 @@ Allen-Cahn 微分方程可以用如下代码表示:
102
114
103
115
``` py linenums="66"
104
116
-- 8 < --
105
- examples/ allen_cahn/ allen_cahn_default .py:66 :67
117
+ examples/ allen_cahn/ allen_cahn_piratenet .py:66 :67
106
118
-- 8 < --
107
119
```
108
120
@@ -112,7 +124,7 @@ examples/allen_cahn/allen_cahn_default.py:66:67
112
124
113
125
``` py linenums="69"
114
126
-- 8 < --
115
- examples/ allen_cahn/ allen_cahn_default .py:69 :81
127
+ examples/ allen_cahn/ allen_cahn_piratenet .py:69 :81
116
128
-- 8 < --
117
129
```
118
130
@@ -124,7 +136,7 @@ examples/allen_cahn/allen_cahn_default.py:69:81
124
136
125
137
``` py linenums="94"
126
138
-- 8 < --
127
- examples/ allen_cahn/ allen_cahn_default .py:94 :110
139
+ examples/ allen_cahn/ allen_cahn_piratenet .py:94 :110
128
140
-- 8 < --
129
141
```
130
142
@@ -139,11 +151,11 @@ examples/allen_cahn/allen_cahn_default.py:94:110
139
151
#### 3.4.2 周期边界约束
140
152
141
153
此处我们采用 hard-constraint 的方式,在神经网络模型中,对输入数据使用cos、sin等周期函数进行周期化,从而让$u_ {\theta}$在数学上直接满足方程的周期性质。
142
- 根据方程可得函数$u(t, x)$在$x$轴上的周期为2 ,因此将该周期设置到模型配置里即可。
154
+ 根据方程可得函数$u(t, x)$在$x$轴上的周期为 2 ,因此将该周期设置到模型配置里即可。
143
155
144
- ``` yaml linenums="35 "
156
+ ``` yaml linenums="41 "
145
157
--8<--
146
- examples/allen_cahn/conf/allen_cahn_default .yaml:35:43
158
+ examples/allen_cahn/conf/allen_cahn_piratenet .yaml:41:42
147
159
--8<--
148
160
```
149
161
@@ -153,25 +165,25 @@ examples/allen_cahn/conf/allen_cahn_default.yaml:35:43
153
165
154
166
``` py linenums="112"
155
167
-- 8 < --
156
- examples/ allen_cahn/ allen_cahn_default .py:112 :125
168
+ examples/ allen_cahn/ allen_cahn_piratenet .py:112 :125
157
169
-- 8 < --
158
170
```
159
171
160
172
在微分方程约束、初值约束构建完毕之后,以刚才的命名为关键字,封装到一个字典中,方便后续访问。
161
173
162
174
``` py linenums="126"
163
175
-- 8 < --
164
- examples/ allen_cahn/ allen_cahn_default .py:126 :130
176
+ examples/ allen_cahn/ allen_cahn_piratenet .py:126 :130
165
177
-- 8 < --
166
178
```
167
179
168
180
### 3.5 超参数设定
169
181
170
- 接下来需要指定训练轮数和学习率,此处按实验经验,使用 200 轮训练轮数,0.001 的初始学习率。
182
+ 接下来需要指定训练轮数和学习率,此处按实验经验,使用 300 轮训练轮数,0.001 的初始学习率。
171
183
172
- ``` yaml linenums="51 "
184
+ ``` yaml linenums="50 "
173
185
--8<--
174
- examples/allen_cahn/conf/allen_cahn_default .yaml:51:73
186
+ examples/allen_cahn/conf/allen_cahn_piratenet .yaml:50:63
175
187
--8<--
176
188
```
177
189
@@ -181,7 +193,7 @@ examples/allen_cahn/conf/allen_cahn_default.yaml:51:73
181
193
182
194
``` py linenums="132"
183
195
-- 8 < --
184
- examples/ allen_cahn/ allen_cahn_default .py:132 :136
196
+ examples/ allen_cahn/ allen_cahn_piratenet .py:132 :136
185
197
-- 8 < --
186
198
```
187
199
@@ -191,7 +203,7 @@ examples/allen_cahn/allen_cahn_default.py:132:136
191
203
192
204
``` py linenums="138"
193
205
-- 8 < --
194
- examples/ allen_cahn/ allen_cahn_default .py:138 :156
206
+ examples/ allen_cahn/ allen_cahn_piratenet .py:138 :156
195
207
-- 8 < --
196
208
```
197
209
@@ -201,15 +213,15 @@ examples/allen_cahn/allen_cahn_default.py:138:156
201
213
202
214
``` py linenums="158"
203
215
-- 8 < --
204
- examples/ allen_cahn/ allen_cahn_default .py:158 :194
216
+ examples/ allen_cahn/ allen_cahn_piratenet .py:158 :184
205
217
-- 8 < --
206
218
```
207
219
208
220
## 4. 完整代码
209
221
210
- ``` py linenums="1" title="allen_cahn_default .py"
222
+ ``` py linenums="1" title="allen_cahn_piratenet .py"
211
223
-- 8 < --
212
- examples/ allen_cahn/ allen_cahn_default .py
224
+ examples/ allen_cahn/ allen_cahn_piratenet .py
213
225
-- 8 < --
214
226
```
215
227
@@ -218,12 +230,13 @@ examples/allen_cahn/allen_cahn_default.py
218
230
在计算域上均匀采样出 $201\times501$ 个点,其预测结果和解析解如下图所示。
219
231
220
232
<figure markdown >
221
- ![ allen_cahn_default .jpg] ( https://paddle-org.bj.bcebos.com/paddlescience/docs/AllenCahn/allen_cahn_default .png ) { loading=lazy }
233
+ ![ allen_cahn_piratenet .jpg] ( https://paddle-org.bj.bcebos.com/paddlescience/docs/AllenCahn/allen_cahn_piratenet_ac .png ) { loading=lazy }
222
234
<figcaption > 左侧为 PaddleScience 预测结果,中间为解析解结果,右侧为两者的差值</figcaption >
223
235
</figure >
224
236
225
237
可以看到对于函数$u(t, x)$,模型的预测结果和解析解的结果基本一致。
226
238
227
239
## 6. 参考资料
228
240
241
+ - [ PIRATENETS: PHYSICS-INFORMED DEEP LEARNING WITHRESIDUAL ADAPTIVE NETWORKS] ( https://arxiv.org/pdf/2402.00326.pdf )
229
242
- [ Allen-Cahn equation] ( https://github.com/PredictiveIntelligenceLab/jaxpi/blob/main/examples/allen_cahn/README.md )
0 commit comments