
Commit 9e74b15

jleinonen, CharlelieLrt, and megnvidia authored and committed
Interpolation model example (#1149)
* Temporal interpolation training recipe
* Add README
* Docs changes based on comments
* Update docstrings and README
* Add temporal interpolation animation
* Add animation link
* Add shape check in loss
* Updates of configs + trainer
* Update config comments
* Update README.md style guide edits
* Added wandb logging
* Reformatted sections in docstring for GeometricL2Loss
* Update README and configs
* README changes + type hint fixes
* Update README.md
* Draft of validation script
* Update validation and README
* Fixed command in README.md for temporal_interpolation example
* Removed unused import in datapipe/climate_interp.py
* Updated license headers in temporal_interpolation example
* Renamed methods to avoid implicit shadowing in Trainer class
* Cosmetic changes in train.py and removed unused import in validate.py
* Added clamp in validate.py to make sure step does not go out of bounds
* Added the temporal_interpolation example to the docs + updated CHANGELOG.md
* Addressing remaining comments
* Merged two data source classes in climate_interp.py

Signed-off-by: Charlelie Laurent <[email protected]>
Co-authored-by: Charlelie Laurent <[email protected]>
Co-authored-by: megnvidia <[email protected]>
1 parent aecf8c2 commit 9e74b15

File tree

15 files changed: +1768 −1 lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions

@@ -33,6 +33,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
   to compose and define an active learning workflow is provided in `examples/active_learning`.
   The `moons` example provides a minimal (pedagogical) composition that is meant to
   illustrate how to define the necessary parts of the workflow.
+ - Added a new example for temporal interpolation of weather forecasts using ModAFNO.
+   Accessible in `examples/weather/temporal_interpolation`.

  ### Changed

docs/examples_weather.rst

Lines changed: 2 additions & 1 deletion

@@ -16,4 +16,5 @@ Weather and climate modeling examples using PhysicsNeMo.
   examples/weather/diagnostic/README.rst
   examples/weather/unified_recipe/README.rst
   examples/weather/corrdiff/README.rst
- examples/weather/stormcast/README.rst
+ examples/weather/stormcast/README.rst
+ examples/weather/temporal_interpolation/README.rst
docs/img/temporal_interpolation.gif

882 KB (binary file; the temporal interpolation animation added in this commit)

examples/README.md

Lines changed: 1 addition & 0 deletions

@@ -70,6 +70,7 @@ The several examples inside PhysicsNeMo can be classified based on their domains
  |[Medium-range global weather forecast using Mixture of Experts](./weather/mixture_of_experts/)|MoE Model|
  |[Generative Data Assimilation of Sparse Weather Observations](./weather/regen/)|Denoising Diffusion Model|
  |[Flood Forecasting](./weather/flood_modeling/)|GNN + KAN|
+ |[Temporal Interpolation of Weather Forecasts](./weather/temporal_interpolation/)|ModAFNO|

  ### Structural Mechanics

examples/weather/temporal_interpolation/README.md

Lines changed: 138 additions & 0 deletions

# Earth-2 Temporal Interpolation Model

The temporal interpolation model is used to increase the temporal resolution of AI-based
forecast models. These typically have a native temporal resolution of six hours; the
interpolation allows this to be improved to one hour. With appropriate training data, even
higher temporal resolutions might be achievable.

This PhysicsNeMo example shows how to train a ModAFNO-based temporal interpolation model
with a custom dataset. This architecture uses an embedding network to determine the
parameters of a shift-and-scale operation that modifies the behavior of the AFNO network
depending on a given conditioning variable. For temporal interpolation, the atmospheric
states at both ends of the interpolation interval are passed as inputs along with some
auxiliary data, such as orography, and the conditioning indicates which time step between
the endpoints the model should generate. The interpolation is deterministic and trained
with a latitude-weighted L2 loss. However, it can still be used to produce probabilistic
forecasts if it is applied to the outputs of probabilistic forecast models. More formally,
the ModAFNO $f_{\theta}$ is a conditional expected-value model that approximates:

$$
f_{\theta} (x_{t}, x_{t+T}, \Delta t) \approx
\mathbb{E} \left[ x_{t + \Delta t} \mid x_{t}, x_{t+T}, \Delta t \right]
$$

where $0 \leq \Delta t \leq T$. In the pre-trained model, $T = 6$ hours and
$\Delta t \in \{0, 1, 2, 3, 4, 5, 6\}$ hours.
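For instance, interpolating three hours into a six-hour interval amounts to evaluating

$$
f_{\theta} (x_{t}, x_{t+6\,\mathrm{h}}, 3\,\mathrm{h}) \approx
\mathbb{E} \left[ x_{t + 3\,\mathrm{h}} \mid x_{t}, x_{t+6\,\mathrm{h}} \right],
$$

that is, the expected atmospheric state halfway between the two six-hourly states the
model receives as input.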
For access to the pre-trained model, refer to the [wrapper in
Earth2Studio](https://nvidia.github.io/earth2studio/modules/generated/models/px/earth2studio.models.px.InterpModAFNO.html#earth2studio.models.px.InterpModAFNO).
A technical description of the model can be found in the paper ["Modulated Adaptive
Fourier Neural Operators for Temporal Interpolation of Weather
Forecasts"](https://arxiv.org/abs/2410.18904).

![Example of temporal interpolation of wind speed](../../../docs/img/temporal_interpolation.gif)

## Requirements

### Environment

You must have PhysicsNeMo installed on a GPU system. Training useful models, in
practice, requires a multi-GPU system; the original model was trained on 64 H100 GPUs.
Using the [PhysicsNeMo
container](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/physicsnemo/containers/physicsnemo)
is recommended.

Install the additional packages (MLflow) needed by this example:

```bash
pip install -r requirements.txt
```

### Data

To train a temporal interpolation model, ensure you have the following:

* A dataset of yearly HDF5 files at one-hour resolution. For more details, refer to
  the section ["Data Format and Structure" in the diagnostic model
  example](https://github.com/NVIDIA/physicsnemo/blob/5a64525c40eada2248cd3eacee0a6ac4735ae380/examples/weather/diagnostic/README.md#data-format-and-structure).
  These datasets can be very large: the dataset used to train the original model, with
  73 variables from 1980 to 2017, is approximately 100 TB in size. The data used to
  train the original model are on the ERA5 0.25 degree grid with shape `(721, 1440)`,
  but other resolutions can work too. The ERA5 data is freely accessible; a
  recommended way to download it is the [ERA5 interface in
  Earth2Studio](https://nvidia.github.io/earth2studio/modules/generated/data/earth2studio.data.CDS.html).
  The data downloaded from this interface must then be inserted into the HDF5 files.
* Statistics files containing the mean and standard deviation of each channel in the
  data files, stored as `stats/global_means.npy` and `stats/global_stds.npy` in your
  data directory. Each must be a `.npy` file containing a 1D array with length equal
  to the number of variables in the dataset, where each value gives the mean (for
  `global_means.npy`) or standard deviation (for `global_stds.npy`) of the
  corresponding variable; a sketch for computing these files follows this list.
* A JSON file with metadata about the contents of the HDF5 files. Refer to the [data
  sample](https://github.com/NVIDIA/physicsnemo/blob/main/examples/weather/temporal_interpolation/data/data.json)
  for an example describing the dataset used to train the original model.
* Optional: NetCDF4 files containing the orography and land-sea mask for the grid of
  the data. Each should contain a variable with the same spatial shape as the data.
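The recipe does not prescribe how to generate the statistics files; below is a minimal
sketch for computing them, assuming yearly files laid out as described in `data.json`
(a `fields` dataset with `(time, channel, lat, lon)` dimensions) and the
`/data/era5-73varQ-hourly` path used in the example configs. It streams one snapshot at
a time to keep memory use bounded:

```python
# Minimal sketch: compute per-channel mean/std over yearly HDF5 files and
# write the 1D arrays the datapipe expects. The "fields" dataset name and
# (time, channel, lat, lon) layout follow data.json; the file glob and data
# directory are assumptions to adapt to your system.
import glob
import os

import h5py
import numpy as np

data_dir = "/data/era5-73varQ-hourly"
files = sorted(glob.glob(os.path.join(data_dir, "*.h5")))

count = 0
total = None
total_sq = None
for path in files:
    with h5py.File(path, "r") as f:
        fields = f["fields"]  # shape (time, channel, lat, lon)
        for i in range(fields.shape[0]):  # one snapshot at a time
            x = fields[i].astype(np.float64)  # (channel, lat, lon)
            if total is None:
                total = np.zeros(x.shape[0])
                total_sq = np.zeros(x.shape[0])
            total += x.sum(axis=(1, 2))
            total_sq += (x**2).sum(axis=(1, 2))
            count += x.shape[1] * x.shape[2]

mean = total / count
std = np.sqrt(total_sq / count - mean**2)

os.makedirs(os.path.join(data_dir, "stats"), exist_ok=True)
np.save(os.path.join(data_dir, "stats", "global_means.npy"), mean)
np.save(os.path.join(data_dir, "stats", "global_stds.npy"), std)
```

A single pass over a ~100 TB dataset is slow; in practice you may want to parallelize
over files or subsample snapshots, since the statistics are only used for normalization.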
## Configuration

The model training is controlled by YAML configuration files, found in the `config`
directory and managed by [Hydra](https://hydra.cc/). The full configuration for
training the original model is
[`train_interp.yaml`](https://github.com/NVIDIA/physicsnemo/blob/main/examples/weather/temporal_interpolation/config/train_interp.yaml).
[`train_interp_lite.yaml`](https://github.com/NVIDIA/physicsnemo/blob/main/examples/weather/temporal_interpolation/config/train_interp_lite.yaml)
runs a short test with a lightweight model, which is not expected to produce useful
checkpoints but can be used to verify that training runs without errors.

See the comments in the configuration files for an explanation of each configuration
parameter. To replicate the model from the paper, you only need to change the file and
directory paths to correspond to those on your system. If you train with a custom
dataset, you might also need to change the `model.in_channels` and `model.out_channels`
parameters.
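For example, the path-related keys can be overridden for a custom dataset without
editing the file, using the Hydra `++` syntax described in the training section below.
All paths here are placeholders:

```bash
# Hypothetical example of pointing the full config at a custom dataset from
# the command line; replace the placeholder paths with your own.
python train.py --config-name=train_interp.yaml \
    ++datapipe.data_dir=/my/data/dir \
    ++datapipe.metadata_path=/my/data/dir/metadata/data.json \
    ++datapipe.geopotential_filename=/my/data/dir/invariants/orography.nc \
    ++datapipe.lsm_filename=/my/data/dir/invariants/land_sea_mask.nc
```

`model.in_channels` and `model.out_channels` can be overridden the same way if the
automatic inference from the datapipe does not match your data.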
## Starting Training

Test the training by running the `train.py` script with the "lite" configuration file
on a system with a GPU:

```bash
python train.py --config-name=train_interp_lite.yaml
```

For a multi-GPU or multi-node training job, launch the training with the
`train_interp.yaml` configuration file using `torchrun` or MPI. For example, to train on
eight nodes with eight GPUs each, for a total of 64 GPUs, start a distributed compute
job (for example, using SLURM or Run:ai) and use:

```bash
torchrun --nnodes=8 --nproc-per-node=8 train.py --config-name=train_interp.yaml
```

or the equivalent `mpirun` command. The code will automatically use all GPUs available
to the job. Remember that `datapipe.batch_size_train` in the configuration file sets
the batch size *per process*.
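As an illustration, a SLURM batch script for the eight-node job above might look like
the following sketch. The job name and rendezvous port are arbitrary, and
cluster-specific options (partition, account, container or module setup) are omitted:

```bash
#!/bin/bash
# Hypothetical SLURM script for an 8-node x 8-GPU torchrun launch; adjust
# partition, account, and environment setup for your cluster.
#SBATCH --job-name=temporal-interp
#SBATCH --nodes=8
#SBATCH --ntasks-per-node=1
#SBATCH --gpus-per-node=8

# Use the first node of the allocation as the torchrun rendezvous host.
RDZV_HOST=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)

srun torchrun \
    --nnodes="$SLURM_JOB_NUM_NODES" \
    --nproc-per-node=8 \
    --rdzv-backend=c10d \
    --rdzv-endpoint="${RDZV_HOST}:29500" \
    train.py --config-name=train_interp.yaml
```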
Configuration parameters can be overridden from the command line using the Hydra syntax.
For instance, to set the optimizer learning rate to 0.0001 for the current run, you
can use:

```bash
torchrun --nnodes=8 --nproc-per-node=8 train.py --config-name=train_interp.yaml ++training.optimizer_params.lr=0.0001
```

## Validation

To evaluate checkpoints, you can use the `validate.py` script. The script computes a
histogram of squared errors as a function of the interpolation step (+0 h to +6 h),
which can be used to produce a plot similar to Figure 3 of the paper. The validation
uses the same configuration files as training, with validation-specific options passed
through the `validation` configuration group. Refer to the docstring of `error_by_time`
in `validate.py` for the recognized options.

For example, to run the validation of a model trained with `train_interp.yaml` and save
the resulting error histogram to `validation.nc`:

```bash
python validate.py --config-name="train_interp" ++validation.output_path=validation.nc
```
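To turn the saved histogram into a Figure 3-style error-versus-step curve, a sketch
like the following can serve as a starting point. The variable and dimension names
inside `validation.nc` are assumptions here; inspect the file first (for example, with
`ncdump -h validation.nc`) and adjust accordingly:

```python
# Sketch: plot mean squared error against the interpolation step from the
# validate.py output. The variable/dimension names in validation.nc are
# assumed, not documented here; inspect the file and adapt.
import matplotlib.pyplot as plt
import xarray as xr

ds = xr.open_dataset("validation.nc")
print(ds)  # list the variables and dimensions actually present

errors = ds[list(ds.data_vars)[0]]  # assume the first variable holds the errors
step_dim = errors.dims[0]  # assume the leading dimension is the interpolation step
other_dims = [d for d in errors.dims if d != step_dim]
curve = errors.mean(dim=other_dims) if other_dims else errors

plt.plot(range(curve.sizes[step_dim]), curve.values, marker="o")
plt.xlabel("Interpolation step (h)")
plt.ylabel("Mean squared error")
plt.savefig("error_by_time.png", dpi=150)
```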
examples/weather/temporal_interpolation/config/train_interp.yaml

Lines changed: 64 additions & 0 deletions

# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

model:
  model_type: "modafno" # should always be "modafno"
  model_name: "modafno-cplxscale-smallpatch" # name for the model
  inp_shape: [720, 1440] # should be [720, 1440], must be divisible by patch_size
  in_channels: null # number of input channels to the model, null determines it from datapipe
  out_channels: null # number of output channels from the model, null determines it from datapipe
  patch_size: [2, 2] # size of AFNO patches
  embed_dim: 512 # embedding dimension
  mlp_ratio: 2.0 # multiplier for MLP hidden layer size (may be a non-integer value, e.g. 2.5)
  num_blocks: 12 # number of AFNO blocks

  scale_shift_mode: complex # type of numbers used for the ModAFNO modulation, "real" or "complex"
  embed_model:
    dim: 64 # width of time embedding net
    depth: 1 # depth of time embedding net
    method: sinusoidal # embedding type used in time embedding net, "sinusoidal" or "learned"

datapipe:
  data_dir: "/data/era5-73varQ-hourly" # directory where data files are located
  metadata_path: "/data/era5-73varQ-hourly/metadata/data.json" # path to metadata json file
  geopotential_filename: "/data/era5-wind_gust/invariants/orography.nc" # location of orography file
  lsm_filename: "/data/era5-wind_gust/invariants/land_sea_mask.nc" # location of land-sea mask file
  use_latlon: True # when True, return latitude and longitude from datapipe
  num_samples_per_year_train: null # number of training samples per year, null uses all available
  num_samples_per_year_valid: 64 # number of validation samples per year
  batch_size_train: 1 # batch size per GPU

training:
  max_epoch: 120 # number of data "epochs" (each epoch we save a checkpoint, run validation, update LR)
  samples_per_epoch: 50000 # number of samples per "epoch"
  load_epoch: "latest" # int, null or "latest"; "latest" loads the most recent checkpoint in checkpoint_dir
  checkpoint_dir: "/checkpoints/fcinterp/" # location where checkpoints are saved

  optimizer_params:
    lr: 5e-4 # learning rate
    betas: [0.9, 0.95] # beta parameters for Adam

logging:
  mlflow:
    use_mlflow: True # when True, produce logs with mlflow
    experiment_name: "Forecast interpolation model" # experiment name, can be set freely
    user_name: "PhysicsNeMo User" # user name, can be set freely
  wandb:
    use_wandb: False # when True, produce logs with wandb
    mode: "offline" # "online", "offline", or "disabled"
    project: "Temporal-Interpolation-Training" # project name for wandb
    entity: null # entity (username or team) for Weights & Biases
    results_dir: "./wandb/" # directory to save wandb logs
examples/weather/temporal_interpolation/config/train_interp_lite.yaml

Lines changed: 68 additions & 0 deletions

# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Config file for testing training. Does a very short run with a small model.
# Can be used to test that training runs without errors; not expected to
# produce useful checkpoints.

model:
  model_type: "modafno" # should always be "modafno"
  model_name: "modafno-test" # name for the model
  inp_shape: [720, 1440] # should be [720, 1440], must be divisible by patch_size
  in_channels: null # number of input channels to the model, null determines it from datapipe
  out_channels: null # number of output channels from the model, null determines it from datapipe
  patch_size: [8, 8] # size of AFNO patches
  embed_dim: 64 # embedding dimension
  mlp_ratio: 2.0 # multiplier for MLP hidden layer size (may be a non-integer value, e.g. 2.5)
  num_blocks: 2 # number of AFNO blocks

  scale_shift_mode: complex # type of numbers used for the ModAFNO modulation, "real" or "complex"
  embed_model:
    dim: 64 # width of time embedding net
    depth: 1 # depth of time embedding net
    method: sinusoidal # embedding type used in time embedding net, "sinusoidal" or "learned"

datapipe:
  data_dir: "/data/era5-73varQ-hourly" # directory where data files are located
  metadata_path: "/data/era5-73varQ-hourly/metadata/data.json" # path to metadata json file
  geopotential_filename: "/data/era5-wind_gust/invariants/orography.nc" # location of orography file
  lsm_filename: "/data/era5-wind_gust/invariants/land_sea_mask.nc" # location of land-sea mask file
  use_latlon: True # when True, return latitude and longitude from datapipe
  num_samples_per_year_train: null # number of training samples per year, null uses all available
  num_samples_per_year_valid: 64 # number of validation samples per year
  batch_size_train: 1 # batch size per GPU

training:
  max_epoch: 4 # number of data "epochs" (each epoch we save a checkpoint, run validation, update LR)
  samples_per_epoch: 50 # number of samples per "epoch"
  load_epoch: "latest" # int, null or "latest"; "latest" loads the most recent checkpoint in checkpoint_dir
  checkpoint_dir: "/checkpoints/fcinterp/" # location where checkpoints are saved

  optimizer_params:
    lr: 5e-4 # learning rate
    betas: [0.9, 0.95] # beta parameters for Adam

logging:
  mlflow:
    use_mlflow: True # when True, produce logs with mlflow
    experiment_name: "Forecast interpolation model" # experiment name, can be set freely
    user_name: "PhysicsNeMo User" # user name, can be set freely
  wandb:
    use_wandb: False # when True, produce logs with wandb
    mode: "offline" # "online", "offline", or "disabled"
    project: "Temporal-Interpolation-Training" # project name for wandb
    entity: null # entity (username or team) for Weights & Biases
    results_dir: "./wandb/" # directory to save wandb logs
examples/weather/temporal_interpolation/data/data.json

Lines changed: 90 additions & 0 deletions

{
    "dataset_name": "73ch-hourly",
    "attrs": {
        "description": "ERA5 data at 1 hourly frequency with snapshots at every hour 0000, 0100, 0200, 0300, ..., 2300 UTC. First snapshot in each file is Jan 01 0000 UTC. "
    },
    "h5_path": "fields",
    "dims": ["time", "channel", "lat", "lon"],
    "coords": {
        "channel": [
            "u10m", "v10m", "u100m", "v100m", "t2m", "sp", "msl", "tcwv",
            "u50", "u100", "u150", "u200", "u250", "u300", "u400", "u500",
            "u600", "u700", "u850", "u925", "u1000",
            "v50", "v100", "v150", "v200", "v250", "v300", "v400", "v500",
            "v600", "v700", "v850", "v925", "v1000",
            "z50", "z100", "z150", "z200", "z250", "z300", "z400", "z500",
            "z600", "z700", "z850", "z925", "z1000",
            "t50", "t100", "t150", "t200", "t250", "t300", "t400", "t500",
            "t600", "t700", "t850", "t925", "t1000",
            "q50", "q100", "q150", "q200", "q250", "q300", "q400", "q500",
            "q600", "q700", "q850", "q925", "q1000"
        ]
    }
}
