Commit e476ef4

gencast (#1130)

* Modify the document of heat_pinn
* add gencast code
* modify gencast code

1 parent 9b9246b

30 files changed: +3557 −51 lines

README.md

Lines changed: 1 addition & 0 deletions

@@ -117,6 +117,7 @@ PaddleScience is a scientific computing suite developed on the deep learning framework PaddlePaddle
 | Weather forecasting | [FourCastNet weather forecast](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/fourcastnet) | Data-driven | FourCastNet | Supervised learning | [ERA5](https://app.globus.org/file-manager?origin_id=945b3c9e-0f8c-11ed-8daf-9f359c660fbd&origin_path=%2F~%2Fdata%2F) | [Paper](https://arxiv.org/pdf/2202.11214.pdf) |
 | Weather forecasting | [NowCastNet weather forecast](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/nowcastnet) | Data-driven | NowCastNet | Supervised learning | [MRMS](https://app.globus.org/file-manager?origin_id=945b3c9e-0f8c-11ed-8daf-9f359c660fbd&origin_path=%2F~%2Fdata%2F) | [Paper](https://www.nature.com/articles/s41586-023-06184-4) |
 | Weather forecasting | [GraphCast weather forecast](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/graphcast) | Data-driven | GraphCastNet | Supervised learning | - | [Paper](https://arxiv.org/abs/2212.12794) |
+| Weather forecasting | [GenCast weather forecast](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/gencast) | Data-driven | Diffusion | Supervised learning | [GenCast](https://console.cloud.google.com/storage/browser/dm_graphcast) | [Paper](https://arxiv.org/abs/2312.15796) |
 | Weather forecasting | [FengWu weather forecast](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/fengwu) | Data-driven | Transformer | Supervised learning | - | [Paper](https://arxiv.org/pdf/2304.02948) |
 | Weather forecasting | [Pangu-Weather weather forecast](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/pangu_weather) | Data-driven | Transformer | Supervised learning | - | [Paper](https://arxiv.org/pdf/2211.02556) |
 | Atmospheric pollutants | [UNet pollutant dispersion](https://aistudio.baidu.com/projectdetail/5663515?channel=0&channelType=0&sUid=438690&shared=1&ts=1698221963752) | Data-driven | UNet | Supervised learning | [Data](https://aistudio.baidu.com/datasetdetail/198102) | - |

docs/index.md

Lines changed: 1 addition & 0 deletions

@@ -150,6 +150,7 @@
 | Weather forecasting | [FourCastNet weather forecast](./zh/examples/fourcastnet.md) | Data-driven | FourCastNet | Supervised learning | [ERA5](https://app.globus.org/file-manager?origin_id=945b3c9e-0f8c-11ed-8daf-9f359c660fbd&origin_path=%2F~%2Fdata%2F) | [Paper](https://arxiv.org/pdf/2202.11214.pdf) |
 | Weather forecasting | [NowCastNet weather forecast](./zh/examples/nowcastnet.md) | Data-driven | NowCastNet | Supervised learning | [MRMS](https://app.globus.org/file-manager?origin_id=945b3c9e-0f8c-11ed-8daf-9f359c660fbd&origin_path=%2F~%2Fdata%2F) | [Paper](https://www.nature.com/articles/s41586-023-06184-4) |
 | Weather forecasting | [GraphCast weather forecast](./zh/examples/graphcast.md) | Data-driven | GraphCastNet | Supervised learning | - | [Paper](https://arxiv.org/abs/2212.12794) |
+| Weather forecasting | [GenCast weather forecast](./zh/examples/gencast.md) | Data-driven | Diffusion | Supervised learning | [GenCast](https://console.cloud.google.com/storage/browser/dm_graphcast) | [Paper](https://arxiv.org/abs/2312.15796) |
 | Weather forecasting | [FengWu weather forecast](./zh/examples/fengwu.md) | Data-driven | Transformer | Supervised learning | - | [Paper](https://arxiv.org/pdf/2304.02948) |
 | Weather forecasting | [Pangu-Weather weather forecast](./zh/examples/pangu_weather.md) | Data-driven | Transformer | Supervised learning | - | [Paper](https://arxiv.org/pdf/2211.02556) |
 | Atmospheric pollutants | [UNet pollutant dispersion](https://aistudio.baidu.com/projectdetail/5663515?channel=0&channelType=0&sUid=438690&shared=1&ts=1698221963752) | Data-driven | UNet | Supervised learning | [Data](https://aistudio.baidu.com/datasetdetail/198102) | - |

docs/zh/examples/gencast.md

Lines changed: 99 additions & 0 deletions

# GenCast

Before running the evaluation, fetch the required data from the [Google Cloud Bucket](https://console.cloud.google.com/storage/browser/dm_graphcast) and place it under the data paths configured in `gencast.yaml`.

- Download all files under the `dm_graphcast/gencast/stats` directory into `./data/stats/`.
- Download any or all files under the `dm_graphcast/gencast/dataset` directory (for example, source-era5_date-2019-03-29_res-1.0_levels-13_steps-12.nc) into `./data/dataset/`.

=== "Model evaluation commands"

    ``` sh
    # change into the PaddleScience/jointContribution directory
    cd PaddleScience/jointContribution
    export PYTHONPATH=$PWD:$PYTHONPATH
    # download the model parameters
    cd gencast/
    wget -nc https://paddle-org.bj.bcebos.com/paddlescience/models/gencast/gencast_params_GenCast-1p0deg-Mini-_2019.pdparams -P ./data/params/
    # run the evaluation script
    python run_gencast.py
    ```
## 1. Background

Weather forecasts are inherently uncertain, so predicting the range of possible weather scenarios is crucial for many important decisions, from warning the public about hazardous weather to planning the use of renewable energy. Here we introduce GenCast, a probabilistic weather model whose skill and speed surpass the world's top medium-range forecast, the ensemble forecast ENS of the European Centre for Medium-Range Weather Forecasts (ECMWF). Unlike traditional approaches based on numerical weather prediction (NWP), GenCast is a machine-learning weather prediction (MLWP) method trained on decades of reanalysis data. GenCast generates a stochastic ensemble of 15-day global forecasts in 8 minutes, at 12-hour steps and 0.25° latitude-longitude resolution, covering more than 80 surface and atmospheric variables. GenCast outperforms ENS on 97.4% of the 1320 targets we evaluated, and better predicts extreme weather, tropical cyclones, and wind power production. This work helps open the next chapter in operational weather forecasting, enabling critical weather-dependent decisions to be made with greater accuracy and efficiency.

## 2. Model Principles

GenCast is a probabilistic weather model that produces global 15-day ensemble forecasts at 0.25° resolution, and it is the first to achieve higher accuracy than the top operational ensemble system, ECMWF's ENS. Generating a single 15-day GenCast forecast takes about 8 minutes on a Cloud TPUv5 device, and multiple ensemble members can be generated in parallel.

GenCast models the conditional probability distribution $p(X^{t+1} \mid X^t, X^{t-1})$ of the next weather state $X^{t+1}$, conditioned on the current and previous states. A forecast trajectory $X^{1:T}$ of length $T$ is modeled by conditioning on the initial and previous states $(X^0, X^{-1})$ and factoring the joint distribution over successive states:

$$
p(X^{1:T} \mid X^0, X^{-1}) = \prod_{t=0}^{T-1} p(X^{t+1} \mid X^t, X^{t-1})
$$

Each state is obtained by autoregressive sampling.
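The factorization above amounts to a loop that repeatedly feeds the two most recent states back into a single-step sampler. The sketch below is illustrative only: `sample_next` is a hypothetical stand-in for the diffusion sampler, not the PaddleScience API.

``` py
import numpy as np

def sample_next(x_prev: np.ndarray, x_cur: np.ndarray) -> np.ndarray:
    # Hypothetical single-step sampler standing in for the diffusion model:
    # it persists the current state and adds small noise, so the
    # autoregressive wiring can be demonstrated end to end.
    return x_cur + 0.01 * np.random.randn(*x_cur.shape)

def rollout(x_init: np.ndarray, x_prev: np.ndarray, num_steps: int) -> list:
    """Autoregressively sample a trajectory X^{1:T} given (X^0, X^{-1})."""
    prev, cur = x_prev, x_init
    trajectory = []
    for _ in range(num_steps):
        nxt = sample_next(prev, cur)  # draw from p(X^{t+1} | X^t, X^{t-1})
        trajectory.append(nxt)
        prev, cur = cur, nxt          # slide the two-state conditioning window
    return trajectory

# A 15-day forecast at 12-hour steps corresponds to T = 30.
traj = rollout(np.zeros((4, 4)), np.zeros((4, 4)), num_steps=30)
```

Running `rollout` several times with different random draws yields an ensemble, since each step starts from fresh noise.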
The global weather state $X$ comprises 6 surface variables and 6 atmospheric variables on 13 vertical pressure levels, on a 0.25° latitude-longitude grid (see Table B1 of the paper). The forecast horizon is 15 days, with a 12-hour interval between successive steps $t$ and $t+1$, so $T = 30$.

GenCast is implemented as a conditional diffusion model, a class of generative machine learning model that draws new samples from a given data distribution, and the basis of many recent advances in modeling natural images, sound, and video often referred to as "generative AI". Diffusion models operate through a process of iterative refinement. The future atmospheric state $X^{t+1}$ is produced by iteratively refining a candidate state $Z_0^{t+1}$ that is initialized from pure noise, conditioned on the two previous atmospheric states $(X^t, X^{t-1})$. The blue box in the figure shows how the first forecast step is generated from the initial conditions, and how the full trajectory $X^{1:T}$ is generated autoregressively. Because each forecast step is initialized with noise (i.e. $Z_0^{t+1}$), the process can be repeated with different noise samples to generate an ensemble of trajectories.

<figure markdown>
![gencast.png](https://paddle-org.bj.bcebos.com/paddlescience/docs/gencast/gencast.png){ loading=lazy }
</figure>

At each stage of the iterative refinement process, GenCast applies a neural network composed of an encoder, a processor, and a decoder. The encoder maps the input $Z_n^{t+1}$ and conditioning $(X^t, X^{t-1})$ from the original latitude-longitude grid to an internal learned representation on a six-times-refined icosahedral mesh. The processor is a graph transformer in which each node attends to its k-hop neighbors on the internal mesh. The decoder maps the internal mesh representation back to $Z_{n+1}^{t+1}$, defined on the latitude-longitude grid.

GenCast is trained on 40 years of ERA5 reanalysis data, from 1979 to 2018, using a standard diffusion-model denoising objective. Importantly, although GenCast is trained directly only on the single-step prediction task, it can generate 15-day ensemble forecasts through autoregressive rollout.
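The denoising objective mentioned above can be sketched generically. This is an EDM-style weighted denoising loss for a single noise level; the weighting, `sigma_data`, and the toy denoisers are illustrative conventions, not GenCast's exact training code.

``` py
import numpy as np

def denoising_loss(denoise_fn, x_clean: np.ndarray, sigma: float,
                   sigma_data: float = 1.0, rng=np.random) -> float:
    """EDM-style weighted denoising loss at one noise level sigma."""
    noise = sigma * rng.standard_normal(x_clean.shape)
    x_noisy = x_clean + noise
    x_pred = denoise_fn(x_noisy, sigma)  # model predicts the clean state
    weight = (sigma**2 + sigma_data**2) / (sigma * sigma_data) ** 2
    return float(weight * np.mean((x_pred - x_clean) ** 2))

# A perfect (oracle) denoiser recovers x_clean exactly, so its loss is zero.
x = np.ones((8, 8))
oracle = lambda x_noisy, sigma: x
assert denoising_loss(oracle, x, sigma=0.5) == 0.0
```

Training averages this loss over noise levels sampled from a prescribed range, such as the `noise_config` bounds in `gencast.yaml`.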
## 3. Model Setup

### 3.1 Environment Dependencies

* paddlepaddle
* matplotlib (for plotting)
* pickle (Python standard library; for storing and loading the graph template)
* xarray (for loading .nc data)
* trimesh (for building mesh data)
* scipy (for sparse-matrix operations in the spherical harmonic transform)
* math (Python standard library; for calculations in the spherical harmonic transform)
### 3.2 Model Files

- **xarray_tree.py**: An implementation of tree.map_structure that works with xarray.

- **denoiser.py**: The GenCast denoiser for one-step predictions.

- **dpm_solver_plus_plus_2s.py**: A sampler using DPM-Solver++ 2S from [1].

- **gencast.py**: Combines the GenCast model architecture with the sampler, wrapping it as a denoiser to generate predictions.

- **samplers_base.py**: Defines the sampler interface.

- **samplers_utils.py**: Utility methods for samplers.

- **sparse_transformer.py**: A general sparse transformer operating on TypedGraphs whose inputs and outputs are flat vectors of per-node and per-edge features. `predictor.py` uses one of these for the mesh graph neural network (GNN).

- **spherical_harmonic.py**: Spherical harmonic basis evaluation and differential operators.

- **main.py**: Evaluation and visualization script.

[1] DPM-Solver++: Fast Solver for Guided Sampling of Diffusion Probabilistic Models, https://arxiv.org/abs/2211.01095

## 4. Results

The figure below shows the ground truth, the prediction, and the error for 2-meter temperature.

<figure markdown>
![gencast_2m_t.png](https://paddle-org.bj.bcebos.com/paddlescience/docs/gencast/gencast_2m_t.png){ loading=lazy style="margin:0 auto;"}
<figcaption>Ground truth ("targets"), prediction ("prediction"), and error ("diff")</figcaption>
</figure>

The model's predictions are in close agreement with the ground truth.

## 5. References

* [GenCast: Diffusion-based ensemble forecasting for medium-range weather](https://arxiv.org/abs/2312.15796)
* [GraphCast: Learning skillful medium-range global weather forecasting](https://arxiv.org/abs/2212.12794)
* [GenCast on GitHub](https://github.com/deepmind/graphcast)
* [dinosaur on GitHub](https://github.com/neuralgcm/dinosaur)
gencast.yaml

Lines changed: 99 additions & 0 deletions

``` yaml
hydra:
  run:
    # dynamic output directory according to running time and override name
    dir: gencast/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname}
  job:
    name: ${mode} # name of logfile
    chdir: false # keep current working directory unchanged
  sweep:
    # output directory for multirun
    dir: ${hydra.run.dir}
    subdir: ./

# general settings
mode: eval # running mode: train/eval
seed: 2024
output_dir: ${hydra:run.dir}
log_freq: 20
num_ensemble_members: 8
input_duration: "24h"
target_lead_times: "12h"

type: gencast
data_path: data/dataset/source-era5_date-2019-03-29_res-1.0_levels-13_steps-12.nc
stddev_diffs_path: data/stats/gencast_stats_diffs_stddev_by_level.nc
stddev_path: data/stats/gencast_stats_stddev_by_level.nc
mean_path: data/stats/gencast_stats_mean_by_level.nc
min_path: data/stats/gencast_stats_min_by_level.nc
param_path: data/params/gencast_params_GenCast-1p0deg-Mini-_2019.pdparams

sampler_config:
  max_noise_level: 80.0
  min_noise_level: 0.03
  num_noise_levels: 20
  rho: 7.0
  stochastic_churn_rate: 2.5
  churn_min_noise_level: 0.75
  churn_max_noise_level: inf
  noise_level_inflation_factor: 1.05
```
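If `sampler_config` follows the common Karras-style ρ-spaced schedule (an assumption suggested by the `rho`, `max_noise_level`, and `min_noise_level` fields, not confirmed by this diff), the 20 noise levels would be spaced like this:

``` py
import numpy as np

def noise_schedule(sigma_max: float, sigma_min: float,
                   num_levels: int, rho: float) -> np.ndarray:
    """Karras-style rho-spaced noise levels, from sigma_max down to sigma_min."""
    ramp = np.linspace(0.0, 1.0, num_levels)
    inv_rho = 1.0 / rho
    return (sigma_max**inv_rho + ramp * (sigma_min**inv_rho - sigma_max**inv_rho)) ** rho

sigmas = noise_schedule(sigma_max=80.0, sigma_min=0.03, num_levels=20, rho=7.0)
# The schedule decreases monotonically and hits both endpoints.
assert abs(sigmas[0] - 80.0) < 1e-9 and abs(sigmas[-1] - 0.03) < 1e-12
```

Large `rho` concentrates most levels near the low-noise end, where fine detail is resolved.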
``` yaml
noise_config:
  training_noise_level_rho: 7.0
  training_max_noise_level: 88.0
  training_min_noise_level: 0.02

noise_encoder_config:
  apply_log_first: true
  base_period: 16.0
  num_frequencies: 32
  output_sizes: [32, 16]

denoiser_architecture_config:
  sparse_transformer_config:
    attention_k_hop: 16
    d_model: 512
    num_layers: 16
    num_heads: 4
    attention_type: triblockdiag_mha
    mask_type: lazy
    block_q: 1024
    block_kv: 512
    block_kv_compute: 256
    block_q_dkv: 512
    block_kv_dkv: 1024
    block_kv_dkv_compute: 1024
    ffw_winit_final_mult: 0.0
    attn_winit_final_mult: 0.0
    ffw_hidden: 2048
    mesh_node_dim: 186
    mesh_node_emb_dim: 512
    ffw_winit_mult: 2.0
    value_size: 128
    key_size: 128
    norm_conditioning_feat: 16
    activation: gelu
  mesh_size: 4
  latent_size: 512
  hidden_layers: 1
  radius_query_fraction_edge_length: 0.6
  norm_conditioning_features: ['noise_level_encodings']
  grid2mesh_aggregate_normalization: null
  node_output_size: 84
  grid_node_dim: 267
  grid_node_emb_dim: 512
  mesh_node_dim: 267
  mesh_node_emb_dim: 512
  mesh_edge_emb_dim: 512
  mesh_edge_dim: 4
  grid2mesh_edge_dim: 4
  grid2mesh_edge_emb_dim: 512
  mesh2grid_edge_dim: 4
  mesh2grid_edge_emb_dim: 512
  gnn_msg_steps: 16
  node_output_dim: 84
  norm_conditioning_feat: 16
  mesh_node_num: 2562
  grid_node_num: 65160
  resolution: 1.0
name: gencast
```
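Two of the sizes in the config can be cross-checked arithmetically: a refined icosahedral mesh has 10·4^m + 2 vertices after m subdivisions, and a 1.0° global grid has 181 latitudes × 360 longitudes, matching `mesh_node_num` and `grid_node_num` for `mesh_size: 4` and `resolution: 1.0` (the grid convention of inclusive latitudes and exclusive longitudes is an assumption that reproduces the stated count):

``` py
def icosphere_nodes(splits: int) -> int:
    # An icosahedron has 12 vertices; each subdivision quadruples the faces,
    # giving 10 * 4**m + 2 vertices after m subdivisions.
    return 10 * 4**splits + 2

def latlon_nodes(resolution_deg: float) -> int:
    # Inclusive latitudes from -90 to 90, exclusive longitudes from 0 to 360.
    n_lat = int(180 / resolution_deg) + 1
    n_lon = int(360 / resolution_deg)
    return n_lat * n_lon

assert icosphere_nodes(4) == 2562     # matches mesh_node_num
assert latlon_nodes(1.0) == 65160     # matches grid_node_num (181 * 360)
```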

jointContribution/gencast/data/dataset/.gitkeep

Whitespace-only changes.

jointContribution/gencast/data/params/.gitkeep

Whitespace-only changes.

jointContribution/gencast/data/stats/.gitkeep

Whitespace-only changes.

jointContribution/gencast/data/template_graph/.gitkeep

Whitespace-only changes.

jointContribution/gencast/denoiser.py

Lines changed: 171 additions & 0 deletions

``` py
# Copyright 2024 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Support for wrapping a general Predictor to act as a Denoiser."""

import copy
import os
import pickle
from typing import Optional
from typing import Sequence

import numpy as np
import paddle
import paddle.nn as nn
import xarray as xr
from graphcast import datasets
from graphcast import graphcast
from graphcast import graphtype
from graphcast import utils


class FourierFeaturesMLP(nn.Layer):
    """A simple MLP applied to Fourier features of values or their logarithms."""

    def __init__(
        self,
        base_period: float,
        num_frequencies: int,
        output_sizes: Sequence[int],
        apply_log_first: bool = False,
        w_init: Optional[nn.initializer.Initializer] = None,
        activation: Optional[nn.Layer] = nn.GELU(),
        **mlp_kwargs,
    ):
        """Initializes the module.

        Args:
            base_period:
                See model_utils.fourier_features. Note this would apply to log
                inputs if apply_log_first is used.
            num_frequencies:
                See model_utils.fourier_features.
            output_sizes:
                Layer sizes for the MLP.
            apply_log_first:
                Whether to take the log of the inputs before computing Fourier
                features.
            w_init:
                Weights initializer for the MLP; the default setting aims to
                produce approximately unit-variance outputs given the input
                sin/cos features.
            activation:
                Activation applied between the MLP layers.
            **mlp_kwargs:
                Further settings for the MLP.
        """
        super(FourierFeaturesMLP, self).__init__()
        self._base_period = base_period
        self._num_frequencies = num_frequencies
        self._apply_log_first = apply_log_first

        # Build the MLP: linear layers with the activation between them.
        layers = []
        input_size = 2 * num_frequencies
        num_layers = len(output_sizes)
        for i, output_size in enumerate(output_sizes):
            linear_layer = nn.Linear(input_size, output_size)
            layers.append(linear_layer)
            if i < num_layers - 1:
                layers.append(activation)
            input_size = output_size

        self._mlp = nn.Sequential(*layers)

    def forward(self, values: paddle.Tensor) -> paddle.Tensor:
        if self._apply_log_first:
            values = paddle.log(values)
        features = utils.fourier_features(
            values, self._base_period, self._num_frequencies
        )
        return self._mlp(features)
```
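The `utils.fourier_features` helper is imported from the `graphcast` package and not shown in this diff. A typical implementation maps each scalar to sin/cos pairs at a range of frequencies; the numpy sketch below (linearly spaced frequencies relative to `base_period`) is an assumption about its behavior, not the actual PaddleScience code. It does produce the `2 * num_frequencies` channels that the MLP above expects as input.

``` py
import numpy as np

def fourier_features(values: np.ndarray, base_period: float,
                     num_frequencies: int) -> np.ndarray:
    """Map each value to [sin, cos] features at frequencies 1..num_frequencies
    relative to base_period, concatenated along the last axis."""
    freqs = np.arange(1, num_frequencies + 1) / base_period   # (F,)
    angles = 2.0 * np.pi * values[..., None] * freqs          # (..., F)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=-1)

feats = fourier_features(np.array([0.5, 1.0]), base_period=16.0, num_frequencies=32)
assert feats.shape == (2, 64)   # 2 * num_frequencies channels per value
```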
``` py
class Denoiser(nn.Layer):
    """Wraps a general deterministic Predictor to act as a Denoiser.

    This passes an encoding of the noise level to the Predictor as an
    additional input 'noise_level_encodings' with shape
    ('batch', 'noise_level_encoding_channels'). It passes the noisy_targets as
    additional forcings (since they are also per-target-timestep data that the
    predictor needs to condition on) with the same names as the original
    target variables.
    """

    def __init__(
        self,
        cfg,
    ):
        super(Denoiser, self).__init__()
        self.cfg = cfg
        self._predictor = graphcast.GraphCastNet(
            config=cfg.denoiser_architecture_config,
        )

        self._noise_level_encoder = FourierFeaturesMLP(**cfg.noise_encoder_config)

    def forward(
        self,
        inputs: xr.Dataset,
        noisy_targets: xr.Dataset,
        noise_levels: xr.DataArray,
        forcings: Optional[xr.Dataset] = None,
        **kwargs,
    ) -> xr.Dataset:

        if forcings is None:
            forcings = xr.Dataset()
        forcings = forcings.assign(**noisy_targets)

        if noise_levels.dims != ("batch",):
            raise ValueError("noise_levels expected to be shape (batch,).")

        noise_level_encodings = self._noise_level_encoder(
            paddle.to_tensor(noise_levels.values)
        )

        # Stack inputs and forcings into a single (lat*lon, batch, channels)
        # array for the graph predictor.
        stacked_inputs = datasets.dataset_to_stacked(inputs)

        stacked_forcings = datasets.dataset_to_stacked(forcings)
        stacked_inputs = xr.concat([stacked_inputs, stacked_forcings], dim="channels")

        stacked_inputs = stacked_inputs.transpose("lat", "lon", ...)
        lat_dim, lon_dim, batch_dim, feat_dim = stacked_inputs.shape
        stacked_inputs = stacked_inputs.data.reshape(lat_dim * lon_dim, batch_dim, -1)

        # Load a cached grid-mesh graph template if available, otherwise
        # build one from the architecture config.
        graph_template_path = os.path.join(
            "data", "template_graph", f"{self.cfg.type}.pkl"
        )
        if os.path.exists(graph_template_path):
            with open(graph_template_path, "rb") as f:
                graph_template = pickle.load(f)
        else:
            graph_template = graphtype.GraphGridMesh(
                self.cfg.denoiser_architecture_config
            )
        graph = copy.deepcopy(graph_template)

        graph.grid_node_feat = np.concatenate(
            [stacked_inputs, graph.grid_node_feat], axis=-1
        )
        mesh_node_feat = np.zeros([graph.mesh_num_nodes, batch_dim, feat_dim])
        graph.mesh_node_feat = np.concatenate(
            [mesh_node_feat, graph.mesh_node_feat], axis=-1
        )
        graph.global_norm_conditioning = noise_level_encodings

        predictor = self._predictor(graph=graphtype.convert_np_to_tensor(graph))

        grid_node_outputs = predictor.grid_node_feat
        raw_predictions = predictor.grid_node_outputs_to_prediction(
            grid_node_outputs, noisy_targets
        )

        return raw_predictions
```
