Skip to content

Commit d39b52c

Browse files
add Extformer-MoE example by HKUST(GZ) (#933)
* add extformer-moe * fix 0624 * fix 0624 * Delete outputs_extformer_moe_pretrain directory * Delete docs/zh/examples/extformer_moe_figs directory * Update extformer_moe.md * Update extformer_moe_utils.py * Update extformer_moe.md * Update ppsci/data/dataset/__init__.py Co-authored-by: HydrogenSulfate <[email protected]> * Update ppsci/data/dataset/ext_moe_enso_dataset.py Co-authored-by: HydrogenSulfate <[email protected]> * fix merge review issues * fix merge 240629 * fix 240701 * fix 240701_ * fix 240701__ * fix 240701___ * Update ext_moe_enso_dataset.py * Update docs/zh/examples/extformer_moe.md Co-authored-by: HydrogenSulfate <[email protected]> * Update docs/zh/examples/extformer_moe.md Co-authored-by: HydrogenSulfate <[email protected]> * Update ppsci/data/dataset/ext_moe_enso_dataset.py --------- Co-authored-by: HydrogenSulfate <[email protected]>
1 parent 22fbd99 commit d39b52c

14 files changed

+6559
-0
lines changed

docs/zh/api/arch.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
- ChipDeepONets
3030
- AutoEncoder
3131
- CuboidTransformer
32+
- ExtFormerMoECuboid
3233
- SFNONet
3334
- UNONet
3435
- TFNO1dNet

docs/zh/api/data/dataset.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
- ContinuousNamedArrayDataset
1313
- ERA5Dataset
1414
- ERA5SampledDataset
15+
- ExtMoEENSODataset
1516
- IterableMatDataset
1617
- MatDataset
1718
- IterableNPZDataset

docs/zh/examples/extformer_moe.md

Lines changed: 224 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,224 @@
1+
# Extformer-MoE
2+
3+
开始训练、评估前,请先下载,并对应修改 yaml 配置文件中的 FILE_PATH
4+
5+
[ICAR-ENSO数据集](https://tianchi.aliyun.com/dataset/98942)
6+
7+
=== "模型训练命令"
8+
9+
``` sh
10+
# ICAR-ENSO 数据预训练模型: Extformer-MoE
11+
python extformer_moe_enso_train.py
12+
```
13+
14+
=== "模型评估命令"
15+
16+
``` sh
17+
# ICAR-ENSO 模型评估: Extformer-MoE
18+
python extformer_moe_enso_train.py mode=eval EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/extformer-moe/extformer_moe_pretrained.pdparams
19+
```
20+
21+
| 模型 | 变量名称 | C-Nino3.4-M | C-Nino3.4-WM | MSE(1E-4) | MAE(1E-1) | RMSE |
22+
| :-- | :-- | :-- | :-- | :-- | :-- | :-- |
23+
| [Extformer-MoE](https://paddle-org.bj.bcebos.com/paddlescience/models/extformer-moe/extformer_moe_pretrained.pdparams) | sst | 0.7651 | 2.39771 | 3.0000 | 0.1291 | 0.50243 |
24+
25+
## 1. 背景简介
26+
27+
地球是一个复杂的系统。地球系统的变化,从温度波动等常规事件到干旱、冰雹和厄尔尼诺/南方涛动 (ENSO) 等极端事件,影响着我们的日常生活。在所有后果中,地球系统的变化会影响农作物产量、航班延误、引发洪水和森林火灾。对这些变化进行准确及时的预测可以帮助人们采取必要的预防措施以避免危机,或者更好地利用风能和太阳能等自然资源。因此,改进地球变化(例如天气和气候)的预测模型具有巨大的社会经济影响。
28+
29+
近年来,深度学习模型在天气和气候预报任务中显示出了巨大的潜力。相较于传统的数值模拟方法,深度学习方法通过利用视觉神经网络 (ViT) 或图神经网络 (GNN) 等新兴技术直接从海量再分析数据中学习当前和未来天气或气候状态之间的复杂映射关系,在预测效率和精度方面均取得了显著的提升。然而,地球变化中发生的极端事件往往呈现出长距离时空同步关联、时空分布规律多样以及极值观测信号稀疏等特点,给基于深度学习的地球系统极端事件预测模型的构建带来了诸多新的技术挑战。
30+
31+
### 1.1 长距离时空同步关联
32+
33+
面对复杂耦合的地球变化系统,现有基于视觉和图深度学习的技术在建模极端天气呈现出的长距离时空关联性时存在诸多不足。具体而言,基于视觉深度学习的智能预报模型(例如华为的盘古气象大模型)仅限于计算局部区域内的信息交互,无法高效利用来自遥远区域的全局信息。相比之下,基于图神经网络的天气预报方法(例如谷歌的GraphCast)可以通过预定义的图结构进行远程信息传播,然而先验图结构难以有效识别影响极端天气的关键长距离信息且容易受到噪声影响,导致模型产生有偏甚至错误的预测结果。此外,地球系统的气象数据一般具有海量的网格点,在挖掘全局的长距离时空关联信息的同时,可能会导致模型复杂度的激增,如何高效建模时空数据中的长距离关联成为地球系统极端事件预测的重大挑战。
34+
35+
Earthformer,一种用于地球系统预测的时空转换器。为了更好地探索时空注意力的设计,其中设计了 Cuboid Attention ,它是高效时空注意力的通用构建块。这个想法是将输入张量分解为不重叠的长方体,并行应用长方体级自注意力。由于我们将 O(N<sup>2</sup>) 自注意力限制在局部长方体内,因此模型整体复杂度大大降低。不同类型的相关性可以通过不同的长方体分解来捕获。同时 Earthformer 引入了一组关注所有局部长方体的全局向量,从而收集系统的整体状态。通过关注全局向量,局部长方体可以掌握系统的总体动态并相互共享信息,从而捕获到地球系统的长距离关联信息。
36+
37+
### 1.2 时空分布规律多样
38+
39+
精准建模时空分布规律的多样性是提升地球系统极端事件预测的关键。现有方法在时域和空域均使用共享的参数,无法有效捕捉特定于时段和地理位置独特的的极端天气特征模式。
40+
41+
混合专家(MoE, Mixture-of-Experts)网络,它包含一组专家网络和门控网络。每个专家网络都是独立的神经网络,拥有独立的参数,门控网络自适应地为每个输入单元选择一个独特的专家网络子集。在训练和推理过程中,每个输入单元只需要利用一个很小的专家网络子集,因此可以扩大专家网络的总数,在增强模型表达能力的同时维持相对较小的计算复杂度。在地球系统中,MoE 可以通过学习与时间、地理位置、模型输入相关的独有参数集合,从而增强模型捕捉时空分布差异性的能力。
42+
43+
### 1.3 极值观测信号稀疏
44+
45+
气象数据的不均衡分布会导致模型偏向于预测频繁出现的正常气象状况,而低估了观测值稀少的极端状况,因为模型训练中常用的回归损失函数比如均方误差(MSE)损失会导致预测结果的过平滑现象。与具有离散标签空间的不平衡分类问题不同,不平衡回归问题具有连续的标签空间,为极端预测问题带来了更大的挑战。
46+
47+
Rank-N-Contrast(RNC)是一种表征学习方法,旨在学习一种回归感知的样本表征,该表征以连续标签空间中的距离为依据,对嵌入空间中的样本间距离进行排序,然后利用它来预测最终连续的标签。在地球系统极端预测问题中,RNC 可以对气象数据的表征进行规范,使其满足嵌入空间的连续性,和标签空间对齐,最终缓解极端事件的预测结果的过平滑问题。
48+
49+
50+
## 2. 模型原理
51+
52+
### 2.1 Earthformer
53+
54+
本章节仅对 EarthFormer 的模型原理进行简单地介绍,详细的理论推导请阅读 [Earthformer: Exploring Space-Time Transformers for Earth System Forecasting](https://arxiv.org/abs/2207.05833)
55+
56+
Earthformer 的网络模型使用了基于 Cuboid Attention 的分层 Encoder-Decoder 架构Transformer,它将数据分解为长方体并并行应用长方体级自注意力,这些长方体进一步与全局向量的集合交互以捕获全局信息。
57+
58+
Earthformer 的总体结构如图所示:
59+
60+
<center class ='img'>
61+
<img title="Earthformer" src="https://paddle-org.bj.bcebos.com/paddlescience/docs/extformer-moe/Earthformer.png" width="60%">
62+
</center>
63+
64+
### 2.2 Mixture-of-Experts
65+
66+
本章节仅对 Mixture-of-Experts 的原理进行简单地介绍,详细的理论推导请阅读 [Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer
67+
](https://arxiv.org/abs/1701.06538)
68+
69+
混合专家(MoE, Mixture-of-Experts)网络,它包含一组参数独立的专家网络 $E_1,E_2,...,E_n$ 和门控网络 $G$。给定输入 $x$,MoE 网络的输出为 $y=\sum_{i=1}^n G(x)_iE_i(x)$。
70+
71+
MoE 的总体结构如图所示:
72+
73+
<center class ='img'>
74+
<img title="MoE" src="https://paddle-org.bj.bcebos.com/paddlescience/docs/extformer-moe/MoE.png" width="60%">
75+
</center>
76+
77+
### 2.3 Rank-N-Contrast
78+
79+
Rank-N-Contrast(RNC)是一种根据样本在标签空间中的相互间的排序,通过对比来学习以学习连续性表征的的回归方法。RNC 的一个简单示例如图所示:
80+
81+
<center class ='img'>
82+
<img title="RNC" src="https://paddle-org.bj.bcebos.com/paddlescience/docs/extformer-moe/RNC.png" width="70%">
83+
</center>
84+
85+
### 2.4 Extformer-MoE 模型的训练、推理过程
86+
87+
模型预训练阶段是基于随机初始化的网络权重对模型进行训练,如下图所示,其中 $[x_{i}]_{i=1}^{T}$ 表示长度为 $T$ 时空序列的输入气象数据,$[y_{i}]_{i=1}^{K}$ 表示预测未来 $K$ 步的气象数据,$[y_{i_True}]_{i=1}^{K}$ 表示未来 $K$ 步的真实数据,如海面温度数据和云总降水量数据。最后网络模型预测的输出和真值计算 mse 损失函数。在推理阶段,给定长度序列为 $T$ 的数据,得到长度序列为 $K$ 的预测结果。
88+
89+
## 3. 海面温度模型实现
90+
91+
接下来开始讲解如何基于 PaddleScience 代码,实现 Extformer-MoE 模型的训练与推理。关于该案例中的其余细节请参考 [API文档](../api/arch.md)
92+
93+
### 3.1 数据集介绍
94+
95+
数据集采用了 [EarthFormer](https://github.com/amazon-science/earth-forecasting-transformer/tree/main) 处理好的 ICAR-ENSO 数据集。
96+
97+
本数据集由气候与应用前沿研究院 ICAR 提供。数据包括 CMIP5/6 模式的历史模拟数据和美国 SODA 模式重建的近100多年历史观测同化数据。每个样本包含以下气象及时空变量:海表温度异常 (SST) ,热含量异常 (T300),纬向风异常 (Ua),经向风异常 (Va),数据维度为 (year,month,lat,lon)。训练数据提供对应月份的 Nino3.4 index 标签数据。测试用的初始场数据为国际多个海洋资料同化结果提供的随机抽取的 n 段 12 个时间序列,数据格式采用 NPY 格式保存。
98+
99+
**训练数据:**
100+
101+
每个数据样本第一维度 (year) 表征数据所对应起始年份,对于 CMIP 数据共 291 年,其中 1-2265 为 CMIP6 中 15 个模式提供的 151 年的历史模拟数据 (总共:151年 *15 个模式=2265) ;2266-4645 为 CMIP5 中 17 个模式提供的 140 年的历史模拟数据 (总共:140 年*17 个模式=2380)。对于历史观测同化数据为美国提供的 SODA 数据。
102+
103+
**训练数据标签**
104+
105+
标签数据为 Nino3.4 SST 异常指数,数据维度为 (year,month)。
106+
107+
CMIP(SODA)_train.nc 对应的标签数据当前时刻 Nino3.4 SST 异常指数的三个月滑动平均值,因此数据维度与维度介绍同训练数据一致。
108+
109+
注:三个月滑动平均值为当前月与未来两个月的平均值。
110+
111+
**测试数据**
112+
113+
测试用的初始场 (输入) 数据为国际多个海洋资料同化结果提供的随机抽取的 n 段 12 个时间序列,数据格式采用NPY格式保存,维度为 (12,lat,lon, 4), 12 为 t 时刻及过去 11 个时刻,4 为预测因子,并按照 SST,T300,Ua,Va 的顺序存放。
114+
115+
EarthFFormer 模型对于 ICAR-ENSO 数据集的训练中,只对其中海面温度 (SST) 进行训练和预测。训练海温异常观测的 12 步 (一年) ,预测海温异常最多 14 步。
116+
117+
### 3.2 模型预训练
118+
119+
#### 3.2.1 约束构建
120+
121+
本案例基于数据驱动的方法求解问题,因此需要使用 PaddleScience 内置的 `SupervisedConstraint` 构建监督约束。在定义约束之前,需要首先指定监督约束中用于数据加载的各个参数。
122+
123+
数据加载的代码如下:
124+
125+
``` py linenums="35" title="examples/extformer_moe/extformer_moe_enso_train.py"
126+
--8<--
127+
examples/extformer_moe/extformer_moe_enso_train.py:35:56
128+
--8<--
129+
```
130+
131+
其中,"dataset" 字段定义了使用的 `Dataset` 类名为 `ExtMoEENSODataset`,"sampler" 字段定义了使用的 `Sampler` 类名为 `BatchSampler`,设置的 `batch_size` 为 16,`num_works` 为 8。
132+
133+
定义监督约束的代码如下:
134+
135+
``` py linenums="58" title="examples/extformer_moe/extformer_moe_enso_train.py"
136+
--8<--
137+
examples/extformer_moe/extformer_moe_enso_train.py:58:64
138+
--8<--
139+
```
140+
141+
`SupervisedConstraint` 的第一个参数是数据的加载方式,这里使用上文中定义的 `train_dataloader_cfg`
142+
143+
第二个参数是损失函数的定义,这里使用自定义的损失函数;
144+
145+
第三个参数是约束条件的名字,方便后续对其索引。此处命名为 `Sup`
146+
147+
#### 3.2.2 模型构建
148+
149+
在该案例中,海面温度模型基于 ExtFormerMoECuboid 网络模型实现,用 PaddleScience 代码表示如下:
150+
151+
``` py linenums="97" title="examples/extformer_moe/extformer_moe_enso_train.py"
152+
--8<--
153+
examples/extformer_moe/extformer_moe_enso_train.py:97:101
154+
--8<--
155+
```
156+
157+
网络模型的参数通过配置文件进行设置如下:
158+
159+
``` yaml linenums="47" title="examples/earthformer/conf/earthformer_enso_pretrain.yaml"
160+
--8<--
161+
examples/extformer_moe/conf/extformer_moe_enso_pretrain.yaml:47:129
162+
--8<--
163+
```
164+
165+
其中,`input_keys``output_keys` 分别代表网络模型输入、输出变量的名称。
166+
167+
#### 3.2.3 学习率与优化器构建
168+
169+
本案例中使用的学习率方法为 `Cosine`,学习率大小设置为 `2e-4`。优化器使用 `AdamW`,并将参数进行分组,使用不同的
170+
`weight_decay`,用 PaddleScience 代码表示如下:
171+
172+
``` py linenums="103" title="examples/extformer_moe/extformer_moe_enso_train.py"
173+
--8<--
174+
examples/extformer_moe/extformer_moe_enso_train.py:103:128
175+
--8<--
176+
```
177+
178+
#### 3.2.4 评估器构建
179+
180+
本案例训练过程中会按照一定的训练轮数间隔,使用验证集评估当前模型的训练情况,需要使用 `SupervisedValidator` 构建评估器。代码如下:
181+
182+
``` py linenums="68" title="examples/extformer_moe/extformer_moe_enso_train.py"
183+
--8<--
184+
examples/extformer_moe/extformer_moe_enso_train.py:68:95
185+
--8<--
186+
```
187+
188+
`SupervisedValidator` 评估器与 `SupervisedConstraint` 比较相似,不同的是评估器需要设置评价指标 `metric`,在这里使用了自定义的评价指标分别是 `MAE``MSE``RMSE``corr_nino3.4_epoch``corr_nino3.4_weighted_epoch`
189+
190+
#### 3.2.5 模型训练与评估
191+
192+
完成上述设置之后,只需要将上述实例化的对象按顺序传递给 `ppsci.solver.Solver`,然后启动训练、评估。
193+
194+
``` py linenums="130" title="examples/extformer_moe/extformer_moe_enso_train.py"
195+
--8<--
196+
examples/extformer_moe/extformer_moe_enso_train.py:130:151
197+
--8<--
198+
```
199+
200+
### 3.3 模型评估
201+
202+
构建模型的代码为:
203+
204+
``` py linenums="184" title="examples/extformer_moe/extformer_moe_enso_train.py"
205+
--8<--
206+
examples/extformer_moe/extformer_moe_enso_train.py:184:188
207+
--8<--
208+
```
209+
210+
构建评估器的代码为:
211+
212+
``` py linenums="155" title="examples/extformer_moe/extformer_moe_enso_train.py"
213+
--8<--
214+
examples/extformer_moe/extformer_moe_enso_train.py:155:182
215+
--8<--
216+
```
217+
218+
## 4. 完整代码
219+
220+
``` py linenums="1" title="examples/extformer_moe/extformer_moe_enso_train.py"
221+
--8<--
222+
examples/extformer_moe/extformer_moe_enso_train.py
223+
--8<--
224+
```
Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
defaults:
2+
- ppsci_default
3+
- TRAIN: train_default
4+
- TRAIN/ema: ema_default
5+
- TRAIN/swa: swa_default
6+
- EVAL: eval_default
7+
- INFER: infer_default
8+
- hydra/job/config/override_dirname/exclude_keys: exclude_keys_default
9+
- _self_
10+
11+
hydra:
12+
run:
13+
# dynamic output directory according to running time and override name
14+
dir: outputs_extformer_moe_pretrain
15+
job:
16+
name: ${mode} # name of logfile
17+
chdir: false # keep current working directory unchanged
18+
callbacks:
19+
init_callback:
20+
_target_: ppsci.utils.callbacks.InitCallback
21+
sweep:
22+
# output directory for multirun
23+
dir: ${hydra.run.dir}
24+
subdir: ./
25+
26+
# general settings
27+
mode: train # running mode: train/eval
28+
seed: 0
29+
output_dir: ${hydra:run.dir}
30+
log_freq: 20
31+
32+
# set train and evaluate data path
33+
FILE_PATH: /hpc2hdd/home/hni017/Workplace/data/weather_data/icar_enso_2021/enso_round1_train_20210201
34+
35+
# dataset setting
36+
DATASET:
37+
label_keys: ["sst_target","nino_target"]
38+
in_len: 12
39+
out_len: 14
40+
nino_window_t: 3
41+
in_stride: 1
42+
out_stride: 1
43+
train_samples_gap: 1
44+
eval_samples_gap: 1
45+
normalize_sst: true
46+
47+
# model settings
48+
MODEL:
49+
input_keys: ["sst_data"]
50+
output_keys: ["sst_target","nino_target","aux_loss","rank_loss"]
51+
input_shape: [12, 24, 48, 1]
52+
target_shape: [14, 24, 48, 1]
53+
base_units: 64
54+
scale_alpha: 1.0
55+
56+
enc_depth: [1, 1]
57+
dec_depth: [1, 1]
58+
enc_use_inter_ffn: true
59+
dec_use_inter_ffn: true
60+
dec_hierarchical_pos_embed: false
61+
62+
downsample: 2
63+
downsample_type: "patch_merge"
64+
upsample_type: "upsample"
65+
66+
num_global_vectors: 0
67+
use_dec_self_global: false
68+
dec_self_update_global: true
69+
use_dec_cross_global: false
70+
use_global_vector_ffn: false
71+
use_global_self_attn: false
72+
separate_global_qkv: false
73+
global_dim_ratio: 1
74+
75+
self_pattern: "axial"
76+
cross_self_pattern: "axial"
77+
cross_pattern: "cross_1x1"
78+
dec_cross_last_n_frames: null
79+
80+
attn_drop: 0.1
81+
proj_drop: 0.1
82+
ffn_drop: 0.1
83+
num_heads: 4
84+
85+
ffn_activation: "gelu"
86+
gated_ffn: false
87+
norm_layer: "layer_norm"
88+
padding_type: "zeros"
89+
pos_embed_type: "t+h+w"
90+
use_relative_pos: true
91+
self_attn_use_final_proj: true
92+
dec_use_first_self_attn: false
93+
94+
z_init_method: "zeros"
95+
initial_downsample_type: "conv"
96+
initial_downsample_activation: "leaky_relu"
97+
initial_downsample_scale: [1, 1, 2]
98+
initial_downsample_conv_layers: 2
99+
final_upsample_conv_layers: 1
100+
checkpoint_level: 0
101+
102+
attn_linear_init_mode: "0"
103+
ffn_linear_init_mode: "0"
104+
conv_init_mode: "0"
105+
down_up_linear_init_mode: "0"
106+
norm_init_mode: "0"
107+
108+
# moe settings
109+
MOE:
110+
use_linear_moe: false
111+
use_ffn_moe: true
112+
use_attn_moe: false
113+
num_experts: 10
114+
out_planes: 4
115+
importance_weight: 0.0
116+
load_weight: 0.0
117+
gate_style: "cuboid-latent" # linear, spatial-latent, cuboid-latent, spatial-latent-linear, cuboid-latent-linear
118+
dispatch_style: "dense" # sparse, dense
119+
aux_loss_style: "all" # all, cell
120+
121+
# rnc settings
122+
RNC:
123+
use_rnc: true
124+
rank_imbalance_style: "batch+T+H+W"
125+
feature_similarity_style: "l2"
126+
rank_imbalance_temp: 2
127+
label_difference_style: "l1"
128+
rank_reg_coeff: 0.01
129+
loss_cal_style: "computation-efficient" # computation-efficient, memory-efficient
130+
131+
# training settings
132+
TRAIN:
133+
epochs: 100
134+
save_freq: 20
135+
eval_during_train: true
136+
eval_freq: 10
137+
lr_scheduler:
138+
epochs: ${TRAIN.epochs}
139+
learning_rate: 0.0002
140+
by_epoch: true
141+
min_lr_ratio: 1.0e-3
142+
wd: 1.0e-5
143+
batch_size: 16
144+
pretrained_model_path: null
145+
checkpoint_path: null
146+
update_freq: 1
147+
148+
# evaluation settings
149+
EVAL:
150+
pretrained_model_path: ./checkpoint/enso/extformer_moe_enso.pdparams
151+
compute_metric_by_batch: false
152+
eval_with_no_grad: true
153+
batch_size: 1

0 commit comments

Comments
 (0)