Skip to content

Commit c3c239d

Browse files
[Fix] Fix extformer-moe (#940)
* fix&refine extformer code and docs * replace paddle.nn. with nn. * update Extformer-MoE in docs
1 parent d39b52c commit c3c239d

20 files changed

+361
-391
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ PaddleScience 是一个基于深度学习框架 PaddlePaddle 开发的科学计
9292

9393
| 问题类型 | 案例名称 | 优化算法 | 模型类型 | 训练方式 | 数据集 | 参考资料 |
9494
|-----|---------|-----|---------|----|---------|---------|
95+
| 天气预报 | [Extformer-MoE 气象预报](https://paddlescience-docs.readthedocs.io/zh/latest/zh/examples/extformer_moe.md) | 数据驱动 | FourCastNet | 监督学习 | [enso](https://tianchi.aliyun.com/dataset/98942) | - |
9596
| 天气预报 | [FourCastNet 气象预报](https://paddlescience-docs.readthedocs.io/zh/latest/zh/examples/fourcastnet) | 数据驱动 | FourCastNet | 监督学习 | [ERA5](https://app.globus.org/file-manager?origin_id=945b3c9e-0f8c-11ed-8daf-9f359c660fbd&origin_path=%2F~%2Fdata%2F) | [Paper](https://arxiv.org/pdf/2202.11214.pdf) |
9697
| 天气预报 | [NowCastNet 气象预报](https://paddlescience-docs.readthedocs.io/zh/latest/zh/examples/nowcastnet) | 数据驱动 | NowCastNet | 监督学习 | [MRMS](https://app.globus.org/file-manager?origin_id=945b3c9e-0f8c-11ed-8daf-9f359c660fbd&origin_path=%2F~%2Fdata%2F) | [Paper](https://www.nature.com/articles/s41586-023-06184-4) |
9798
| 天气预报 | [GraphCast 气象预报](https://paddlescience-docs.readthedocs.io/zh/latest/zh/examples/graphcast) | 数据驱动 | GraphCastNet | 监督学习 | - | [Paper](https://arxiv.org/abs/2212.12794) |

docs/index.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@
137137

138138
| 问题类型 | 案例名称 | 优化算法 | 模型类型 | 训练方式 | 数据集 | 参考资料 |
139139
|-----|---------|-----|---------|----|---------|---------|
140+
| 天气预报 | [Extformer-MoE 气象预报](./zh/examples/extformer_moe.md) | 数据驱动 | FourCastNet | 监督学习 | [enso](https://tianchi.aliyun.com/dataset/98942) | - |
140141
| 天气预报 | [FourCastNet 气象预报](./zh/examples/fourcastnet.md) | 数据驱动 | FourCastNet | 监督学习 | [ERA5](https://app.globus.org/file-manager?origin_id=945b3c9e-0f8c-11ed-8daf-9f359c660fbd&origin_path=%2F~%2Fdata%2F) | [Paper](https://arxiv.org/pdf/2202.11214.pdf) |
141142
| 天气预报 | [NowCastNet 气象预报](./zh/examples/nowcastnet.md) | 数据驱动 | NowCastNet | 监督学习 | [MRMS](https://app.globus.org/file-manager?origin_id=945b3c9e-0f8c-11ed-8daf-9f359c660fbd&origin_path=%2F~%2Fdata%2F) | [Paper](https://www.nature.com/articles/s41586-023-06184-4) |
142143
| 天气预报 | [GraphCast 气象预报](./zh/examples/graphcast.md) | 数据驱动 | GraphCastNet | 监督学习 | - | [Paper](https://arxiv.org/abs/2212.12794) |

docs/zh/examples/extformer_moe.md

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,18 @@
11
# Extformer-MoE
22

3-
开始训练、评估前,请先下载,并对应修改 yaml 配置文件中的 FILE_PATH
3+
!!! note
44

5-
[ICAR-ENSO数据集](https://tianchi.aliyun.com/dataset/98942)
5+
1. 开始训练、评估前,请先下载 [ICAR-ENSO数据集](https://tianchi.aliyun.com/dataset/98942),并对应修改 yaml 配置文件中的 `FILE_PATH` 为解压后的数据集路径。
6+
2. 开始训练、评估前,请安装 `xarray` 和 `h5netcdf`:`pip install requirements.txt`
7+
3. 若训练时显存不足,可指定 `MODEL.checkpoint_level` 为 `1` 或 `2`,此时使用 recompute 模式运行,以训练时间换取显存。
68

79
=== "模型训练命令"
810

911
``` sh
1012
# ICAR-ENSO 数据预训练模型: Extformer-MoE
1113
python extformer_moe_enso_train.py
14+
# python extformer_moe_enso_train.py MODEL.checkpoint_level=1 # using recompute to run in device with small GPU memory
15+
# python extformer_moe_enso_train.py MODEL.checkpoint_level=2 # using recompute to run in device with small GPU memory
1216
```
1317

1418
=== "模型评估命令"
@@ -46,7 +50,6 @@ Earthformer,一种用于地球系统预测的时空转换器。为了更好地
4650

4751
Rank-N-Contrast(RNC)是一种表征学习方法,旨在学习一种回归感知的样本表征,该表征以连续标签空间中的距离为依据,对嵌入空间中的样本间距离进行排序,然后利用它来预测最终连续的标签。在地球系统极端预测问题中,RNC 可以对气象数据的表征进行规范,使其满足嵌入空间的连续性,和标签空间对齐,最终缓解极端事件的预测结果的过平滑问题。
4852

49-
5053
## 2. 模型原理
5154

5255
### 2.1 Earthformer

examples/extformer_moe/extformer_moe_enso_train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1+
import enso_metric
12
import hydra
23
import paddle
34
from omegaconf import DictConfig
45
from omegaconf import OmegaConf
56
from paddle import nn
67

7-
import examples.extformer_moe.enso_metric as enso_metric
88
import ppsci
99

1010

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
h5netcdf
2+
xarray==2024.2.0

mkdocs.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ nav:
8484
- 材料科学(AI for Material):
8585
- hPINNs: zh/examples/hpinns.md
8686
- 地球科学(AI for Earth Science):
87+
- Extformer-MoE: zh/examples/extformer_moe.md
8788
- FourCastNet: zh/examples/fourcastnet.md
8889
- NowcastNet: zh/examples/nowcastnet.md
8990
- DGMR: zh/examples/dgmr.md

ppsci/arch/activation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def __init__(self, beta: float = 1.0):
5151
super().__init__()
5252
self.beta = self.create_parameter(
5353
shape=[],
54-
default_initializer=paddle.nn.initializer.Constant(beta),
54+
default_initializer=nn.initializer.Constant(beta),
5555
)
5656

5757
def forward(self, x):

ppsci/arch/amgnet.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -238,21 +238,21 @@ def faster_graph_connectivity(perm, edge_index, edge_weight, score, pos, N, norm
238238
value_A = edge_weight.clone()
239239

240240
value_A = paddle.squeeze(value_A)
241-
model_1 = paddle.nn.Sequential(
242-
("l1", paddle.nn.Linear(128, 256)),
243-
("act1", paddle.nn.ReLU()),
244-
("l2", paddle.nn.Linear(256, 256)),
245-
("act2", paddle.nn.ReLU()),
246-
("l4", paddle.nn.Linear(256, 128)),
247-
("act4", paddle.nn.ReLU()),
248-
("l5", paddle.nn.Linear(128, 1)),
241+
model_1 = nn.Sequential(
242+
("l1", nn.Linear(128, 256)),
243+
("act1", nn.ReLU()),
244+
("l2", nn.Linear(256, 256)),
245+
("act2", nn.ReLU()),
246+
("l4", nn.Linear(256, 128)),
247+
("act4", nn.ReLU()),
248+
("l5", nn.Linear(128, 1)),
249249
)
250-
model_2 = paddle.nn.Sequential(
251-
("l1", paddle.nn.Linear(1, 64)),
252-
("act1", paddle.nn.ReLU()),
253-
("l2", paddle.nn.Linear(64, 128)),
254-
("act2", paddle.nn.ReLU()),
255-
("l4", paddle.nn.Linear(128, 128)),
250+
model_2 = nn.Sequential(
251+
("l1", nn.Linear(1, 64)),
252+
("act1", nn.ReLU()),
253+
("l2", nn.Linear(64, 128)),
254+
("act2", nn.ReLU()),
255+
("l4", nn.Linear(128, 128)),
256256
)
257257

258258
val_A = model_1(value_A)

ppsci/arch/cuboid_transformer.py

Lines changed: 26 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
"""A space-time Transformer with Cuboid Attention"""
1717

1818

19-
class InitialEncoder(paddle.nn.Layer):
19+
class InitialEncoder(nn.Layer):
2020
def __init__(
2121
self,
2222
dim,
@@ -38,39 +38,35 @@ def __init__(
3838
for i in range(num_conv_layers):
3939
if i == 0:
4040
conv_block.append(
41-
paddle.nn.Conv2D(
41+
nn.Conv2D(
4242
kernel_size=(3, 3),
4343
padding=(1, 1),
4444
in_channels=dim,
4545
out_channels=out_dim,
4646
)
4747
)
48-
conv_block.append(
49-
paddle.nn.GroupNorm(num_groups=16, num_channels=out_dim)
50-
)
48+
conv_block.append(nn.GroupNorm(num_groups=16, num_channels=out_dim))
5149
conv_block.append(
5250
act_mod.get_activation(activation)
5351
if activation != "leaky_relu"
5452
else nn.LeakyReLU(NEGATIVE_SLOPE)
5553
)
5654
else:
5755
conv_block.append(
58-
paddle.nn.Conv2D(
56+
nn.Conv2D(
5957
kernel_size=(3, 3),
6058
padding=(1, 1),
6159
in_channels=out_dim,
6260
out_channels=out_dim,
6361
)
6462
)
65-
conv_block.append(
66-
paddle.nn.GroupNorm(num_groups=16, num_channels=out_dim)
67-
)
63+
conv_block.append(nn.GroupNorm(num_groups=16, num_channels=out_dim))
6864
conv_block.append(
6965
act_mod.get_activation(activation)
7066
if activation != "leaky_relu"
7167
else nn.LeakyReLU(NEGATIVE_SLOPE)
7268
)
73-
self.conv_block = paddle.nn.Sequential(*conv_block)
69+
self.conv_block = nn.Sequential(*conv_block)
7470
if isinstance(downsample_scale, int):
7571
patch_merge_downsample = (1, downsample_scale, downsample_scale)
7672
elif len(downsample_scale) == 2:
@@ -121,7 +117,7 @@ def forward(self, x):
121117
return x
122118

123119

124-
class FinalDecoder(paddle.nn.Layer):
120+
class FinalDecoder(nn.Layer):
125121
def __init__(
126122
self,
127123
target_thw: Tuple[int, ...],
@@ -142,20 +138,20 @@ def __init__(
142138
conv_block = []
143139
for i in range(num_conv_layers):
144140
conv_block.append(
145-
paddle.nn.Conv2D(
141+
nn.Conv2D(
146142
kernel_size=(3, 3),
147143
padding=(1, 1),
148144
in_channels=dim,
149145
out_channels=dim,
150146
)
151147
)
152-
conv_block.append(paddle.nn.GroupNorm(num_groups=16, num_channels=dim))
148+
conv_block.append(nn.GroupNorm(num_groups=16, num_channels=dim))
153149
conv_block.append(
154150
act_mod.get_activation(activation)
155151
if activation != "leaky_relu"
156152
else nn.LeakyReLU(NEGATIVE_SLOPE)
157153
)
158-
self.conv_block = paddle.nn.Sequential(*conv_block)
154+
self.conv_block = nn.Sequential(*conv_block)
159155
self.upsample = cuboid_decoder.Upsample3DLayer(
160156
dim=dim,
161157
out_dim=dim,
@@ -196,7 +192,7 @@ def forward(self, x):
196192
return x
197193

198194

199-
class InitialStackPatchMergingEncoder(paddle.nn.Layer):
195+
class InitialStackPatchMergingEncoder(nn.Layer):
200196
def __init__(
201197
self,
202198
num_merge: int,
@@ -220,8 +216,8 @@ def __init__(
220216
self.downsample_scale_list = downsample_scale_list[:num_merge]
221217
self.num_conv_per_merge_list = num_conv_per_merge_list
222218
self.num_group_list = [max(1, out_dim // 4) for out_dim in self.out_dim_list]
223-
self.conv_block_list = paddle.nn.LayerList()
224-
self.patch_merge_list = paddle.nn.LayerList()
219+
self.conv_block_list = nn.LayerList()
220+
self.patch_merge_list = nn.LayerList()
225221
for i in range(num_merge):
226222
if i == 0:
227223
in_dim = in_dim
@@ -236,15 +232,15 @@ def __init__(
236232
else:
237233
conv_in_dim = out_dim
238234
conv_block.append(
239-
paddle.nn.Conv2D(
235+
nn.Conv2D(
240236
kernel_size=(3, 3),
241237
padding=(1, 1),
242238
in_channels=conv_in_dim,
243239
out_channels=out_dim,
244240
)
245241
)
246242
conv_block.append(
247-
paddle.nn.GroupNorm(
243+
nn.GroupNorm(
248244
num_groups=self.num_group_list[i], num_channels=out_dim
249245
)
250246
)
@@ -253,7 +249,7 @@ def __init__(
253249
if activation != "leaky_relu"
254250
else nn.LeakyReLU(NEGATIVE_SLOPE)
255251
)
256-
conv_block = paddle.nn.Sequential(*conv_block)
252+
conv_block = nn.Sequential(*conv_block)
257253
self.conv_block_list.append(conv_block)
258254
patch_merge = cuboid_encoder.PatchMerging3D(
259255
dim=out_dim,
@@ -303,7 +299,7 @@ def forward(self, x):
303299
return x
304300

305301

306-
class FinalStackUpsamplingDecoder(paddle.nn.Layer):
302+
class FinalStackUpsamplingDecoder(nn.Layer):
307303
def __init__(
308304
self,
309305
target_shape_list: Tuple[Tuple[int, ...]],
@@ -326,8 +322,8 @@ def __init__(
326322
self.in_dim = in_dim
327323
self.num_conv_per_up_list = num_conv_per_up_list
328324
self.num_group_list = [max(1, out_dim // 4) for out_dim in self.out_dim_list]
329-
self.conv_block_list = paddle.nn.LayerList()
330-
self.upsample_list = paddle.nn.LayerList()
325+
self.conv_block_list = nn.LayerList()
326+
self.upsample_list = nn.LayerList()
331327
for i in range(self.num_upsample):
332328
if i == 0:
333329
in_dim = in_dim
@@ -349,15 +345,15 @@ def __init__(
349345
else:
350346
conv_in_dim = out_dim
351347
conv_block.append(
352-
paddle.nn.Conv2D(
348+
nn.Conv2D(
353349
kernel_size=(3, 3),
354350
padding=(1, 1),
355351
in_channels=conv_in_dim,
356352
out_channels=out_dim,
357353
)
358354
)
359355
conv_block.append(
360-
paddle.nn.GroupNorm(
356+
nn.GroupNorm(
361357
num_groups=self.num_group_list[i], num_channels=out_dim
362358
)
363359
)
@@ -366,7 +362,7 @@ def __init__(
366362
if activation != "leaky_relu"
367363
else nn.LeakyReLU(NEGATIVE_SLOPE)
368364
)
369-
conv_block = paddle.nn.Sequential(*conv_block)
365+
conv_block = nn.Sequential(*conv_block)
370366
self.conv_block_list.append(conv_block)
371367
self.reset_parameters()
372368

@@ -686,7 +682,7 @@ def __init__(
686682
embed_dim=base_units, typ=pos_embed_type, maxH=H_in, maxW=W_in, maxT=T_in
687683
)
688684
mem_shapes = self.encoder.get_mem_shapes()
689-
self.z_proj = paddle.nn.Linear(
685+
self.z_proj = nn.Linear(
690686
in_features=mem_shapes[-1][-1], out_features=mem_shapes[-1][-1]
691687
)
692688
self.dec_pos_embed = cuboid_decoder.PosEmbed(
@@ -799,7 +795,7 @@ def get_initial_encoder_final_decoder(
799795
new_input_shape = self.initial_encoder.patch_merge.get_out_shape(
800796
self.input_shape
801797
)
802-
self.dec_final_proj = paddle.nn.Linear(
798+
self.dec_final_proj = nn.Linear(
803799
in_features=self.base_units, out_features=C_out
804800
)
805801
elif self.initial_downsample_type == "stack_conv":
@@ -839,7 +835,7 @@ def get_initial_encoder_final_decoder(
839835
linear_init_mode=self.down_up_linear_init_mode,
840836
norm_init_mode=self.norm_init_mode,
841837
)
842-
self.dec_final_proj = paddle.nn.Linear(
838+
self.dec_final_proj = nn.Linear(
843839
in_features=dec_target_shape_list[-1][-1], out_features=C_out
844840
)
845841
new_input_shape = self.initial_encoder.get_out_shape_list(self.input_shape)[
@@ -892,7 +888,7 @@ def get_initial_z(self, final_mem, T_out):
892888
shape=[B, -1, -1, -1, -1]
893889
)
894890
elif self.z_init_method == "nearest_interp":
895-
initial_z = paddle.nn.functional.interpolate(
891+
initial_z = nn.functional.interpolate(
896892
x=final_mem.transpose(perm=[0, 4, 1, 2, 3]),
897893
size=(T_out, final_mem.shape[2], final_mem.shape[3]),
898894
).transpose(perm=[0, 2, 3, 4, 1])

0 commit comments

Comments
 (0)