Commit 4b994d8
Add training code for gencast (#1154)
* Modify the document of heat_pinn
* add gencast code
* add gencast code
* add gencast code
* add gencast code
* add gencast code
* modify gencast code
* Corrected model type error
* Add training code
* Add training code for gencast
1 parent f8be42e commit 4b994d8

File tree

10 files changed: 487 additions & 39 deletions

docs/zh/examples/gencast.md

Lines changed: 12 additions & 2 deletions
````diff
@@ -5,17 +5,27 @@
 - Download all files under the `dm_graphcast/gencast/stats` directory and place them in the `./data/stats/` directory.
 - Download any or all files under the `dm_graphcast/gencast/dataset` directory (for example: source-era5_date-2019-03-29_res-1.0_levels-13_steps-12.nc) and place them in the `./data/dataset/` directory.
 
+=== "Model training command"
+
+    ``` sh
+    # Set the path to the PaddleScience/jointContribution folder
+    cd PaddleScience/jointContribution
+    export PYTHONPATH=$PWD:$PYTHONPATH
+    # Run the training script
+    python run_gencast.py mode=train
+    ```
+
 === "Model evaluation command"
 
     ``` sh
     # Set the path to the PaddleScience/jointContribution folder
     cd PaddleScience/jointContribution
-    export PYTHONPATH=$PWD:$PYTHONPAT
+    export PYTHONPATH=$PWD:$PYTHONPATH
     # Download the model parameters
     cd gencast/
     wget -nc https://paddle-org.bj.bcebos.com/paddlescience/models/gencast/gencast_params_GenCast-1p0deg-Mini-_2019.pdparams -P ./data/params/
     # Run the evaluation script
-    python run_gencast.py
+    python run_gencast.py mode=eval
     ```
 
 ## 1. Background
````

jointContribution/gencast/conf/gencast.yaml

Lines changed: 10 additions & 2 deletions
```diff
@@ -28,6 +28,13 @@ mean_path: data/stats/gencast_stats_mean_by_level.nc
 min_path: data/stats/gencast_stats_min_by_level.nc
 param_path: data/params/gencast_params_GenCast-1p0deg-Mini-_2019.pdparams
 
+train:
+  learning_rate: 0.001
+  weight_decay: 0.1
+  num_epochs: 2000000
+  batch_size: 1
+  snapshot_freq: 10
+
 sampler_config:
   max_noise_level: 80.0
   min_noise_level: 0.03
@@ -63,8 +70,9 @@ denoiser_architecture_config:
   block_q_dkv: 512
   block_kv_dkv: 1024
   block_kv_dkv_compute: 1024
-  ffw_winit_final_mult: 0.0
-  attn_winit_final_mult: 0.0
+  ffw_winit_final_mult: 1.0
+  attn_winit_final_mult: 1.0
+  attn_winit_mult: 2.0
   ffw_hidden: 2048
   mesh_node_dim: 186
   mesh_node_emb_dim: 512
```
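The new `train` block holds the optimizer and loop hyperparameters. As a minimal sketch of how such a block could be consumed when building a Paddle optimizer (the `OmegaConf` usage and the stand-in model below are illustrative assumptions, not code from this commit):

```python
import paddle
from omegaconf import OmegaConf

# Values mirror the `train` section added to conf/gencast.yaml.
cfg = OmegaConf.create(
    {
        "train": {
            "learning_rate": 0.001,
            "weight_decay": 0.1,
            "num_epochs": 2000000,
            "batch_size": 1,
            "snapshot_freq": 10,
        }
    }
)

model = paddle.nn.Linear(4, 4)  # stand-in for the real GenCast model
optimizer = paddle.optimizer.AdamW(
    learning_rate=cfg.train.learning_rate,
    weight_decay=cfg.train.weight_decay,
    parameters=model.parameters(),
)
```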

jointContribution/gencast/denoiser.py

Lines changed: 16 additions & 3 deletions
```diff
@@ -14,6 +14,7 @@
 """Support for wrapping a general Predictor to act as a Denoiser."""
 
 import copy
+import math
 import os
 import pickle
 from typing import Optional
@@ -66,12 +67,16 @@ def __init__(
         self._num_frequencies = num_frequencies
         self._apply_log_first = apply_log_first
 
-        # 创建 MLP
+        # Creating MLP
         layers = []
         input_size = 2 * num_frequencies
         num_layers = len(output_sizes)
         for i, output_size in enumerate(output_sizes):
-            linear_layer = nn.Linear(input_size, output_size)
+            limit = math.sqrt(6 / input_size)
+            weight_attr = paddle.framework.ParamAttr(
+                initializer=paddle.nn.initializer.Uniform(low=-limit, high=limit)
+            )
+            linear_layer = nn.Linear(input_size, output_size, weight_attr=weight_attr)
             layers.append(linear_layer)
             if i < num_layers - 1:
                 layers.append(activation)
@@ -168,4 +173,12 @@ def forward(
             grid_node_outputs, noisy_targets
         )
 
-        return raw_predictions
+        resolution = self.cfg.denoiser_architecture_config.resolution
+        grid_lat = np.arange(-90.0, 90.0 + resolution, resolution).astype(np.float32)
+        grid_lon = np.arange(0.0, 360.0, resolution).astype(np.float32)
+        grid_shape = [grid_lat.shape[0], grid_lon.shape[0]]
+        grid_outputs_lat_lon_leading = grid_node_outputs.reshape(
+            grid_shape + grid_node_outputs.shape[1:]
+        )
+
+        return raw_predictions, grid_outputs_lat_lon_leading
```
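For reference, the bound `limit = sqrt(6 / input_size)` used above is a Kaiming/He-style uniform fan-in bound. A small standalone sketch of the same `ParamAttr` pattern, with sizes chosen here purely for illustration:

```python
import math

import paddle
import paddle.nn as nn

fan_in, fan_out = 64, 512  # illustrative sizes; the Denoiser's first layer sees 2 * num_frequencies inputs
limit = math.sqrt(6 / fan_in)
weight_attr = paddle.framework.ParamAttr(
    initializer=paddle.nn.initializer.Uniform(low=-limit, high=limit)
)
layer = nn.Linear(fan_in, fan_out, weight_attr=weight_attr)

# All initialized weights should land inside [-limit, limit].
assert float(layer.weight.abs().max()) <= limit
```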

jointContribution/gencast/dpm_solver_plus_plus_2s.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -173,7 +173,7 @@ def init_noise(template):
         mid_over_current = mid_noise_level / noise_level
         # x = xr.open_dataset('/workspace/workspace/graphcast/x.nc')
 
-        x_denoised = denoiser(noise_level, x)
+        x_denoised, _ = denoiser(noise_level, x)
         # This turns out to be a convex combination of current and denoised x,
         # which isn't entirely apparent from the paper formulae:
         x_mid = (
@@ -182,7 +182,7 @@ def init_noise(template):
         )
 
         next_over_current = next_noise_level / noise_level
-        x_mid_denoised = denoiser(mid_noise_level, x_mid)
+        x_mid_denoised, _ = denoiser(mid_noise_level, x_mid)
         x_next = (
             next_over_current.numpy() * x
             + (1 - next_over_current.numpy()) * x_mid_denoised
```

jointContribution/gencast/gencast.py

Lines changed: 87 additions & 8 deletions
```diff
@@ -24,8 +24,13 @@
 
 import denoiser
 import dpm_solver_plus_plus_2s
+import losses
+import numpy as np
+import paddle
 import paddle.nn as nn
+import samplers_utils
 import xarray as xr
+from graphcast import datasets
 
 
 class GenCast(nn.Layer):
@@ -54,6 +59,7 @@ def __init__(
         self._sampler_config = cfg.sampler_config
         self._sampler = None
         self._noise_config = cfg.noise_config
+        self.cfg = cfg
 
     def _c_in(self, noise_scale: xr.DataArray) -> xr.DataArray:
         """Scaling applied to the noisy targets input to the underlying network."""
@@ -81,22 +87,95 @@ def _preconditioned_denoiser(
     ) -> xr.Dataset:
         """The preconditioned denoising function D from the paper (Eqn 7)."""
         # Convert xarray DataArray to Paddle tensor for operations
-        raw_predictions = self._denoiser(
+        raw_predictions, grid_node_outputs = self._denoiser(
             inputs=inputs,
             noisy_targets=noisy_targets * self._c_in(noise_levels),
             noise_levels=noise_levels,
             forcings=forcings,
             **kwargs
         )
 
-        return raw_predictions * self._c_out(
-            noise_levels
-        ) + noisy_targets * self._c_skip(noise_levels)
+        stacked_noisy_targets = datasets.dataset_to_stacked(noisy_targets)
+        stacked_noisy_targets = stacked_noisy_targets.transpose("lat", "lon", ...)
+
+        out = grid_node_outputs * paddle.to_tensor(self._c_out(noise_levels).data)
+        skip = paddle.to_tensor(
+            stacked_noisy_targets.data * self._c_skip(noise_levels).data
+        )
+        grid_node_outputs = out + skip
+
+        return (
+            raw_predictions * self._c_out(noise_levels)
+            + noisy_targets * self._c_skip(noise_levels),
+            grid_node_outputs,
+        )
+
+    def loss(
+        self,
+        inputs: xr.Dataset,
+        targets: xr.Dataset,
+        forcings: Optional[xr.Dataset] = None,
+    ):
+
+        if self._noise_config is None:
+            raise ValueError("Noise config must be specified to train GenCast.")
+
+        grid_node_outputs, denoised_predictions, noise_levels = self.forward(
+            inputs, targets, forcings
+        )
+
+        loss, diagnostics = losses.weighted_mse_loss_from_xarray(
+            grid_node_outputs,
+            targets,
+            # Weights are same as we used for GraphCast.
+            per_variable_weights={
+                # Any variables not specified here are weighted as 1.0.
+                # A single-level variable, but an important headline variable
+                # and also one which we have struggled to get good performance
+                # on at short lead times, so leaving it weighted at 1.0, equal
+                # to the multi-level variables:
+                "2m_temperature": 1.0,
+                # New single-level variables, which we don't weight too highly
+                # to avoid hurting performance on other variables.
+                "10m_u_component_of_wind": 0.1,
+                "10m_v_component_of_wind": 0.1,
+                "mean_sea_level_pressure": 0.1,
+                "sea_surface_temperature": 0.1,
+                "total_precipitation_12hr": 0.1,
+            },
+        )
+        loss *= paddle.to_tensor(self._loss_weighting(noise_levels).data)
+        return loss, diagnostics
 
     def forward(self, inputs, targets_template, forcings=None, **kwargs):
+        if self.cfg.mode == "eval":
+            if self._sampler is None:
+                self._sampler = dpm_solver_plus_plus_2s.Sampler(
+                    self._preconditioned_denoiser, **self._sampler_config
+                )
+            return self._sampler(inputs, targets_template, forcings, **kwargs)
+        if self.cfg.mode == "train":
+            # Sample noise levels:
+            batch_size = inputs.sizes["batch"]
+            noise_levels = xr.DataArray(
+                data=samplers_utils.rho_inverse_cdf(
+                    min_value=self._noise_config.training_min_noise_level,
+                    max_value=self._noise_config.training_max_noise_level,
+                    rho=self._noise_config.training_noise_level_rho,
+                    cdf=np.random.uniform(size=(batch_size,)).astype("float32"),
+                ),
+                dims=("batch",),
+            )
+
+            # Sample noise and apply it to targets:
+            noise = (
+                samplers_utils.spherical_white_noise_like(targets_template)
+                * noise_levels
+            )
+
+            noisy_targets = targets_template + noise
 
-        if self._sampler is None:
-            self._sampler = dpm_solver_plus_plus_2s.Sampler(
-                self._preconditioned_denoiser, **self._sampler_config
+            denoised_predictions, grid_node_outputs = self._preconditioned_denoiser(
+                inputs, noisy_targets, noise_levels, forcings
             )
-        return self._sampler(inputs, targets_template, forcings, **kwargs)
+            return grid_node_outputs, denoised_predictions, noise_levels
```
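Taken together, `mode=eval` keeps the original DPM-Solver++ 2S sampling path, while `mode=train` draws noise levels, perturbs the targets, and runs the preconditioned denoiser so that `loss()` can apply the noise-level-dependent weighting. A rough sketch of how a training step might drive this new API from `run_gencast.py` follows; the loop structure, checkpoint naming, and the `model`/`optimizer`/data objects are assumptions for illustration, not code from this commit:

```python
# Hypothetical driver loop (illustration only). `model` is a GenCast instance built
# with cfg.mode == "train"; `optimizer` is e.g. the AdamW optimizer sketched earlier;
# `inputs`, `targets`, `forcings` are xarray Datasets from the GraphCast data pipeline.
for epoch in range(cfg.train.num_epochs):
    loss, diagnostics = model.loss(inputs, targets, forcings)
    loss.backward()
    optimizer.step()
    optimizer.clear_grad()

    if epoch % cfg.train.snapshot_freq == 0:
        paddle.save(model.state_dict(), f"./data/params/gencast_epoch_{epoch}.pdparams")
```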
