Skip to content

Commit be0802f

Browse files
[Fix] Adapt code to dy2st mode (#957)
* shallow copy input data in expression * shallow copy in ComposedNode.forward * update tempoGAN code linenumber * update code
1 parent d8e2053 commit be0802f

File tree

4 files changed

+23
-21
lines changed

4 files changed

+23
-21
lines changed

docs/zh/examples/tempoGAN.md

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -110,25 +110,25 @@ examples/tempoGAN/tempoGAN.py:57:76
110110

111111
Generator 的输入为低密度流体数据的插值,而数据集中保存的为原始的低密度流体数据,因此需要进行一个插值的 transform。
112112

113-
``` py linenums="269"
113+
``` py linenums="270"
114114
--8<--
115-
examples/tempoGAN/functions.py:269:274
115+
examples/tempoGAN/functions.py:270:275
116116
--8<--
117117
```
118118

119119
Discriminator 和 Discriminator_tempo 对输入的 transform 更为复杂,分别为:
120120

121-
``` py linenums="359"
121+
``` py linenums="360"
122122
--8<--
123-
examples/tempoGAN/functions.py:359:393
123+
examples/tempoGAN/functions.py:360:394
124124
--8<--
125125
```
126126

127127
其中:
128128

129-
``` py linenums="368"
129+
``` py linenums="369"
130130
--8<--
131-
examples/tempoGAN/functions.py:368:368
131+
examples/tempoGAN/functions.py:369:369
132132
--8<--
133133
```
134134

@@ -239,9 +239,9 @@ examples/tempoGAN/tempoGAN.py:205:244
239239

240240
因为 GAN 网络训练的特性,本问题不使用 PaddleScience 中内置的可视化器,而是自定义了一个用于实现推理的函数,该函数读取验证集数据,得到推理结果并将结果以图片形式保存下来,在训练过程中按照一定间隔调用该函数即可在训练过程中监控训练效果。
241241

242-
``` py linenums="153"
242+
``` py linenums="154"
243243
--8<--
244-
examples/tempoGAN/functions.py:153:229
244+
examples/tempoGAN/functions.py:154:230
245245
--8<--
246246
```
247247

@@ -253,39 +253,39 @@ examples/tempoGAN/functions.py:153:229
253253

254254
Generator 的 loss 提供了 l1 loss、l2 loss、输出经过 Discriminator 判断的 loss 和 输出经过 Discriminator_tempo 判断的 loss。这些 loss 是否存在根据权重参数控制,若某一项 loss 的权重参数为 0,则表示训练中不添加该 loss 项。
255255

256-
``` py linenums="276"
256+
``` py linenums="277"
257257
--8<--
258-
examples/tempoGAN/functions.py:276:345
258+
examples/tempoGAN/functions.py:277:346
259259
--8<--
260260
```
261261

262262
#### 3.8.2 Discriminator 的 loss
263263

264264
Discriminator 为判别器,它的作用是判断数据为真数据还是假数据,因此它的 loss 为 Generator 产生的数据应当判断为假而产生的 loss 和 目标值数据应当判断为真而产生的 loss。
265265

266-
``` py linenums="395"
266+
``` py linenums="396"
267267
--8<--
268-
examples/tempoGAN/functions.py:395:409
268+
examples/tempoGAN/functions.py:396:410
269269
--8<--
270270
```
271271

272272
#### 3.8.3 Discriminator_tempo 的 loss
273273

274274
Discriminator_tempo 的 loss 构成 与 Discriminator 相同,只是所需数据不同。
275275

276-
``` py linenums="411"
276+
``` py linenums="412"
277277
--8<--
278-
examples/tempoGAN/functions.py:411:427
278+
examples/tempoGAN/functions.py:412:428
279279
--8<--
280280
```
281281

282282
#### 3.8.4 自定义 data transform
283283

284284
本问题提供了一种输入数据处理方法,将输入的流体密度数据随机裁剪一块,然后进行密度值判断,若裁剪下来的块密度值低于阈值则重新裁剪,直到密度满足条件或裁剪次数达到阈值。这样做主要是为了减少训练所需的显存,同时对裁剪下来的块密度值的判断保证了块中信息的丰富程度。[参数和超参数设定](#34)`tile_ratio` 表示原始尺寸是块的尺寸的几倍,即若`tile_ratio` 为 2,裁剪下来的块的大小为整张原始图片的四分之一。
285285

286-
``` py linenums="430"
286+
``` py linenums="431"
287287
--8<--
288-
examples/tempoGAN/functions.py:430:488
288+
examples/tempoGAN/functions.py:431:489
289289
--8<--
290290
```
291291

examples/fsi/viv.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,7 @@ def __init__(self, model, func):
201201
self.func = func
202202

203203
def forward(self, x):
204+
x = {**x}
204205
model_out = self.model(x)
205206
func_out = self.func(x)
206207
return {**model_out, "f": func_out}

examples/tempoGAN/functions.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,11 +64,12 @@ def reshape_input(input_dict: Dict[str, paddle.Tensor]) -> Dict[str, paddle.Tens
6464
Returns:
6565
Dict[str, paddle.Tensor]: reshaped data dict.
6666
"""
67+
out_dict = {}
6768
for key in input_dict:
6869
input = input_dict[key]
6970
N, C, H, W = input.shape
70-
input_dict[key] = paddle.reshape(input, [N * C, 1, H, W])
71-
return input_dict
71+
out_dict[key] = paddle.reshape(input, [N * C, 1, H, W])
72+
return out_dict
7273

7374

7475
def dereshape_input(

ppsci/solver/solver.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -858,7 +858,7 @@ def predict(
858858
@misc.run_on_eval_mode
859859
def export(
860860
self,
861-
input_spec: List["InputSpec"],
861+
input_spec: List[Dict[str, InputSpec]],
862862
export_path: str,
863863
with_onnx: bool = False,
864864
skip_prune_program: bool = False,
@@ -870,8 +870,8 @@ def export(
870870
Convert model to static graph model and export to files.
871871
872872
Args:
873-
input_spec (List[InputSpec]): InputSpec describes the signature information
874-
of the model input.
873+
input_spec (List[Dict[str, InputSpec]]): InputSpec describes the signature
874+
information of the model input.
875875
export_path (str): The path prefix to save model.
876876
with_onnx (bool, optional): Whether to export model into onnx after
877877
paddle inference models are exported. Defaults to False.

0 commit comments

Comments
 (0)