Skip to content

Commit dbdb813

Browse files
[Fix] Fix phycrnet bug (#894)
* fix phycrnet bug * Update examples/phycrnet/functions.py Co-authored-by: zzm <[email protected]> * Update examples/phycrnet/functions.py Co-authored-by: zzm <[email protected]> --------- Co-authored-by: zzm <[email protected]>
1 parent 48a91f9 commit dbdb813

File tree

4 files changed

+17
-9
lines changed

4 files changed

+17
-9
lines changed

docs/zh/examples/phycrnet.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ $$u_t+u\cdot \nabla u -\nu u =0$$
5454
## 3. 问题求解
5555

5656
### 3.1 模型构建
57+
5758
在这一部分中,我们介绍 PhyCRNet 的架构,包括编码器-解码器模块、残差连接、自回归(AR)过程和基于过滤的微分。网络架构如图所示。编码器(黄色Encoder,包含3个卷积层),用于从输入状态变量 $u(t=i),i = 0,1,2,..,T-1$ 学习低维潜在特征,其中 $T$ 表示总时间步。我们应用 ReLU 作为卷积层的激活函数。然后,我们将ConvLSTM层的输出(Encoder得到的低分辨率),潜在特征的时间传播器(绿色部分),其中,输出的LSTM的记忆单元 $C_i$ 和LSTM的隐藏变量单元 $h_i$ 会作为下一个时间步的输入。这样做的好处是对低维变量的基本动态进行建模,能够准确地捕获时间依赖性,同时有助于减轻记忆负担。 使用 LSTM 的另一个优势来自输出状态的双曲正切函数,它可以保持平滑的梯度曲线,并将值控制在 -1 和 1 之间。在建立低分辨率LSTM卷积循环方案后,我们基于上采样操作Decoder(蓝色部分)直接将低分辨率潜在空间重建为高分辨率量。特别注明,应用了子像素卷积层(即像素shuffle),因为与反卷积相比,它具有更好的效率和重建精度,且伪像更少。 最后,我们添加另一个卷积层,用于将有界潜变量空间输出,缩放回原始的物理空间。该Decoder后面没有激活函数。 此外,值得一提的是,鉴于输入变量数量有限及其对超分辨率的缺陷,我们在 PhyCRNet 中没有考虑 batch normalization。 作为替代,我们使用 batch normalization 来训练网络,以实现训练加速和更好的收敛性。受到动力学中,Forward Eular Scheme 的启发,我们在输入状态变量 $u_i$ 和输出变量 $u_{i+1}$ 之间附加全局残差连接。具体网络结构如下图所示:
5859

5960
![image](https://paddle-org.bj.bcebos.com/paddlescience/docs/phycrnet/PhyCRnet.png)
@@ -94,6 +95,7 @@ examples/phycrnet/conf/burgers_equations.yaml:34:42
9495

9596
### 3.2 数据载入
9697
我们使用RK4或者谱方法生成的数据(初值为使用正态分布生成),需要从.mat文件中将其读入,:
98+
9799
``` py linenums="54"
98100
--8<--
99101
examples/phycrnet/main.py:54:72
@@ -139,17 +141,21 @@ $$
139141
$$
140142

141143
这一步需要通过设置外界函数来进行,因此在训练过程中,我们使用`function.transform_out`来进行训练
144+
142145
``` py linenums="47"
143146
--8<--
144147
examples/phycrnet/main.py:47:51
145148
--8<--
146149
```
150+
147151
而在评估过程中,我们使用`function.tranform_output_val`来进行评估,并生成累计均方根误差。
152+
148153
``` py linenums="142"
149154
--8<--
150155
examples/phycrnet/main.py:142:142
151156
--8<--
152157
```
158+
153159
完成上述设置之后,只需要将上述实例化的对象按顺序传递给 `ppsci.solver.Solver`
154160

155161
``` py linenums="117"

examples/phycrnet/functions.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import math
16+
import os
1617
from typing import Dict
1718

1819
import matplotlib.pyplot as plt
@@ -337,7 +338,8 @@ def get(self, epochs=1):
337338

338339

339340
def output_graph(model, input_dataset, fig_save_path, case_name):
340-
output_dataset = model(input_dataset)
341+
with paddle.no_grad():
342+
output_dataset = model(input_dataset)
341343
output = output_dataset["outputs"]
342344
input = input_dataset["input"][0]
343345
output = paddle.concat(tuple(output), axis=0)
@@ -378,7 +380,7 @@ def output_graph(model, input_dataset, fig_save_path, case_name):
378380
M[0, :] = 0
379381

380382
M = paddle.to_tensor(M)
381-
aRMSE = paddle.sqrt(M.T @ error)
383+
aRMSE = paddle.sqrt(M.T @ error).numpy()
382384
t = np.linspace(0, 4, N)
383385
plt.plot(t, aRMSE, color="r")
384386
plt.yscale("log")
@@ -393,7 +395,7 @@ def output_graph(model, input_dataset, fig_save_path, case_name):
393395
loc="upper left",
394396
)
395397
plt.title(case_name)
396-
plt.savefig(fig_save_path + "/error.jpg")
398+
plt.savefig(os.path.join(fig_save_path, "error.jpg"))
397399

398400
_, ax = plt.subplots(3, 4, figsize=(18, 12))
399401
ax[0, 0].contourf(ten_true[25, 0])
@@ -416,5 +418,5 @@ def output_graph(model, input_dataset, fig_save_path, case_name):
416418
ax[1, 3].contourf(ten_pred[99, 0])
417419
ax[2, 3].contourf(ten_true[99, 0] - ten_pred[99, 0])
418420
plt.title(case_name)
419-
plt.savefig(fig_save_path + "/Burgers.jpg")
421+
plt.savefig(os.path.join(fig_save_path, "Burgers.jpg"))
420422
plt.close()

ppsci/arch/phycrnet.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def __init__(
147147
)
148148

149149
# ConvLSTM
150-
self.ConvLSTM = paddle.nn.LayerList(
150+
self.convlstm = paddle.nn.LayerList(
151151
[
152152
ConvLSTMCell(
153153
input_channels=self.input_channels[i],
@@ -194,16 +194,16 @@ def forward(self, x):
194194
x = encoder(x)
195195

196196
# convlstm
197-
for i, LSTM in enumerate(self.ConvLSTM):
197+
for i, lstm in enumerate(self.convlstm, self.num_encoder):
198198
if step == 0:
199-
(h, c) = LSTM.init_hidden_tensor(
199+
(h, c) = lstm.init_hidden_tensor(
200200
prev_state=self.initial_state[i - self.num_encoder]
201201
)
202202
internal_state.append((h, c))
203203

204204
# one-step forward
205205
(h, c) = internal_state[i - self.num_encoder]
206-
x, new_c = LSTM(x, h, c)
206+
x, new_c = lstm(x, h, c)
207207
internal_state[i - self.num_encoder] = (x, new_c)
208208

209209
# output

ppsci/utils/download.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ def _download(url, path, md5sum=None):
157157
if chunk:
158158
f.write(chunk)
159159
shutil.move(tmp_fullname, fullname)
160-
logger.message(f"Finished downloading pretrained model and saved to {fullname}")
160+
logger.message(f"Finish downloading pretrained model and saved to {fullname}")
161161

162162
return fullname
163163

0 commit comments

Comments
 (0)