|
24 | 24 |
|
25 | 25 | import denoiser
|
26 | 26 | import dpm_solver_plus_plus_2s
|
| 27 | +import losses |
| 28 | +import numpy as np |
| 29 | +import paddle |
27 | 30 | import paddle.nn as nn
|
| 31 | +import samplers_utils |
28 | 32 | import xarray as xr
|
| 33 | +from graphcast import datasets |
29 | 34 |
|
30 | 35 |
|
31 | 36 | class GenCast(nn.Layer):
|
@@ -54,6 +59,7 @@ def __init__(
|
54 | 59 | self._sampler_config = cfg.sampler_config
|
55 | 60 | self._sampler = None
|
56 | 61 | self._noise_config = cfg.noise_config
|
| 62 | + self.cfg = cfg |
57 | 63 |
|
58 | 64 | def _c_in(self, noise_scale: xr.DataArray) -> xr.DataArray:
|
59 | 65 | """Scaling applied to the noisy targets input to the underlying network."""
|
@@ -81,22 +87,95 @@ def _preconditioned_denoiser(
|
81 | 87 | ) -> xr.Dataset:
|
82 | 88 | """The preconditioned denoising function D from the paper (Eqn 7)."""
|
83 | 89 | # Convert xarray DataArray to Paddle tensor for operations
|
84 |
| - raw_predictions = self._denoiser( |
| 90 | + raw_predictions, grid_node_outputs = self._denoiser( |
85 | 91 | inputs=inputs,
|
86 | 92 | noisy_targets=noisy_targets * self._c_in(noise_levels),
|
87 | 93 | noise_levels=noise_levels,
|
88 | 94 | forcings=forcings,
|
89 | 95 | **kwargs
|
90 | 96 | )
|
91 | 97 |
|
92 |
| - return raw_predictions * self._c_out( |
93 |
| - noise_levels |
94 |
| - ) + noisy_targets * self._c_skip(noise_levels) |
| 98 | + stacked_noisy_targets = datasets.dataset_to_stacked(noisy_targets) |
| 99 | + stacked_noisy_targets = stacked_noisy_targets.transpose("lat", "lon", ...) |
| 100 | + |
| 101 | + out = grid_node_outputs * paddle.to_tensor(self._c_out(noise_levels).data) |
| 102 | + skip = paddle.to_tensor( |
| 103 | + stacked_noisy_targets.data * self._c_skip(noise_levels).data |
| 104 | + ) |
| 105 | + grid_node_outputs = out + skip |
| 106 | + |
| 107 | + return ( |
| 108 | + raw_predictions * self._c_out(noise_levels) |
| 109 | + + noisy_targets * self._c_skip(noise_levels), |
| 110 | + grid_node_outputs, |
| 111 | + ) |
| 112 | + |
def loss(
    self,
    inputs: xr.Dataset,
    targets: xr.Dataset,
    forcings: Optional[xr.Dataset] = None,
):
    """Compute the noise-weighted training loss for one batch.

    Calls ``self.forward`` on the batch, scores the returned grid-node
    outputs against ``targets`` with ``losses.weighted_mse_loss_from_xarray``
    and rescales the result by ``self._loss_weighting`` evaluated at the
    noise levels that ``forward`` returned.

    Args:
        inputs: Input dataset, passed through to ``forward``.
        targets: Target dataset; handed to ``forward`` and used as the
            reference in the loss.
        forcings: Optional forcings dataset, passed through to ``forward``.

    Returns:
        A ``(loss, diagnostics)`` tuple as produced by the loss helper, with
        ``loss`` multiplied by the per-noise-level weighting.

    Raises:
        ValueError: If no noise config was supplied to the model.
    """
    if self._noise_config is None:
        raise ValueError("Noise config must be specified to train GenCast.")

    # forward() must hand back three values here (its training path does);
    # the denoised predictions are not needed to compute the loss.
    grid_outputs, _denoised, sampled_noise_levels = self.forward(
        inputs, targets, forcings
    )

    # Weights are same as we used for GraphCast.
    # Any variables not specified here are weighted as 1.0.
    variable_weights = {
        # A single-level variable, but an important headline variable
        # and also one which we have struggled to get good performance
        # on at short lead times, so leaving it weighted at 1.0, equal
        # to the multi-level variables:
        "2m_temperature": 1.0,
        # New single-level variables, which we don't weight too highly
        # to avoid hurting performance on other variables.
        "10m_u_component_of_wind": 0.1,
        "10m_v_component_of_wind": 0.1,
        "mean_sea_level_pressure": 0.1,
        "sea_surface_temperature": 0.1,
        "total_precipitation_12hr": 0.1,
    }

    loss, diagnostics = losses.weighted_mse_loss_from_xarray(
        grid_outputs,
        targets,
        per_variable_weights=variable_weights,
    )

    # Scale by the weighting associated with the sampled noise levels.
    noise_weighting = paddle.to_tensor(
        self._loss_weighting(sampled_noise_levels).data
    )
    loss *= noise_weighting
    return loss, diagnostics
95 | 149 |
|
def forward(self, inputs, targets_template, forcings=None, **kwargs):
    """Run the model in the mode selected by ``self.cfg.mode``.

    ``"eval"``: lazily builds ``dpm_solver_plus_plus_2s.Sampler`` around
    ``self._preconditioned_denoiser`` (cached on ``self._sampler``) and
    returns whatever the sampler returns. ``**kwargs`` are forwarded to the
    sampler.

    ``"train"``: draws one noise level per batch element, adds spherical
    white noise scaled by those levels to ``targets_template`` and runs the
    preconditioned denoiser on the noisy targets. ``**kwargs`` are not used
    on this path.

    Args:
        inputs: Input dataset; in train mode its ``sizes["batch"]`` gives the
            number of noise levels to sample.
        targets_template: Passed to the sampler in eval mode; in train mode
            it is the dataset that noise is added to.
        forcings: Optional forcings passed through to the sampler/denoiser.
        **kwargs: Extra arguments forwarded to the sampler (eval mode only).

    Returns:
        Eval mode: the sampler's output.
        Train mode: ``(grid_node_outputs, denoised_predictions, noise_levels)``.

    Raises:
        ValueError: If ``self.cfg.mode`` is neither ``"eval"`` nor ``"train"``.
    """
    mode = self.cfg.mode

    if mode == "eval":
        # Build the sampler on first use and reuse it afterwards.
        if self._sampler is None:
            self._sampler = dpm_solver_plus_plus_2s.Sampler(
                self._preconditioned_denoiser, **self._sampler_config
            )
        return self._sampler(inputs, targets_template, forcings, **kwargs)

    if mode != "train":
        # Fail loudly instead of silently producing no result, which would
        # only surface later as an opaque unpacking error in the caller.
        raise ValueError(
            f"Unsupported mode {mode!r}: expected 'train' or 'eval'."
        )

    # Sample noise levels:
    batch_size = inputs.sizes["batch"]
    noise_levels = xr.DataArray(
        data=samplers_utils.rho_inverse_cdf(
            min_value=self._noise_config.training_min_noise_level,
            max_value=self._noise_config.training_max_noise_level,
            rho=self._noise_config.training_noise_level_rho,
            cdf=np.random.uniform(size=(batch_size,)).astype("float32"),
        ),
        dims=("batch",),
    )

    # Sample noise and apply it to targets:
    noise = (
        samplers_utils.spherical_white_noise_like(targets_template)
        * noise_levels
    )
    noisy_targets = targets_template + noise

    denoised_predictions, grid_node_outputs = self._preconditioned_denoiser(
        inputs, noisy_targets, noise_levels, forcings
    )
    return grid_node_outputs, denoised_predictions, noise_levels
0 commit comments