Commit ed0441b
Update readme
1 parent f0d7a9f commit ed0441b

1 file changed
+136 -39 lines changed

README.md
Lines changed: 136 additions & 39 deletions
@@ -42,8 +42,22 @@ Training can then be run immediately on the preprocessed data with readily avail
 ```
 python -m neural_lam.train_model \
     --config_path data_small/vlasiator_config.yaml \
-    --model graph_efm \
-    ...
+    --model graphcast \
+    --graph simple \
+    --epochs 10 \
+    --lr 0.001 \
+    --batch_size 4 \
+    --hidden_dim 32 \
+    --processor_layers 6 \
+    --decode_dim 16 \
+    --div_weight 10 \
+    --ar_steps_train 2 \
+    --ar_steps_eval 2
+```
+
+For more commands see:
+```
+python -m neural_lam.train_model --help
 ```
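The `--ar_steps_train` and `--ar_steps_eval` options set how many autoregressive steps are unrolled: each one-step prediction is fed back as the input for the next step. A minimal sketch of the idea, with a toy one-step function standing in for the model (not the repository's implementation):

```
def rollout(step_fn, state, n_steps):
    """Autoregressive rollout: feed each one-step prediction back in
    as the input for the next step, collecting the trajectory."""
    trajectory = []
    for _ in range(n_steps):
        state = step_fn(state)  # one-step forecast
        trajectory.append(state)
    return trajectory

# Toy one-step "model": double the state each step.
print(rollout(lambda x: 2 * x, 1, 3))  # → [2, 4, 8]
```

Longer rollouts expose the model to its own accumulated errors during training, at the cost of extra memory per step.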
 
 ## Data
@@ -88,9 +102,17 @@ This produces training-ready zarr stores in the data directory.
 
 Simple, multiscale, and hierarchical graphs are included already, but can be created using the following commands:
 ```
-python -m neural_lam.create_graph --config_path data/vlasiator_config.yaml --name simple --levels 1 --coarsen-factor 5 --plot
-python -m neural_lam.create_graph --config_path data/vlasiator_config.yaml --name multiscale --levels 3 --coarsen-factor 5 --plot
-python -m neural_lam.create_graph --config_path data/vlasiator_config.yaml --name hierarchical --levels 3 --coarsen-factor 5 --hierarchical --plot
+python -m neural_lam.create_graph \
+    --config_path data/vlasiator_config.yaml \
+    --name simple --levels 1 --coarsen-factor 5 --plot
+
+python -m neural_lam.create_graph \
+    --config_path data/vlasiator_config.yaml \
+    --name multiscale --levels 3 --coarsen-factor 5 --plot
+
+python -m neural_lam.create_graph \
+    --config_path data/vlasiator_config.yaml \
+    --name hierarchical --levels 3 --coarsen-factor 5 --hierarchical --plot
 ```
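As a rough intuition for `--levels` and `--coarsen-factor`: assuming each successive mesh level keeps about 1/coarsen-factor of the previous level's nodes (an approximation for illustration, not the exact graph construction), the level sizes shrink geometrically:

```
def level_sizes(n_nodes, levels, coarsen_factor):
    """Approximate node count per mesh level, assuming each level keeps
    about 1/coarsen_factor of the level below (illustrative only)."""
    sizes = [n_nodes]
    for _ in range(levels - 1):
        sizes.append(max(1, sizes[-1] // coarsen_factor))
    return sizes

# e.g. a 10,000-node grid with --levels 3 --coarsen-factor 5:
print(level_sizes(10_000, 3, 5))  # → [10000, 2000, 400]
```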
 
 To plot the graphs and store as `.html` files run:
@@ -99,6 +121,30 @@ python -m neural_lam.plot_graph --datastore_config_path data/vlasiator_config.ya
 ```
 with `--graph` as `simple`, `multiscale` or `hierarchical`, and `--save` specifying the name of the output file.
 
+## Models
+
+Pretrained models can be downloaded from [Hugging Face](https://huggingface.co/deinal/spacecast-models) using:
+```
+from huggingface_hub import snapshot_download
+
+snapshot_download(
+    repo_id="deinal/spacecast-models",
+    repo_type="model",
+    local_dir="model_weights"
+)
+```
+The download also includes metrics for the models, and example forecasts for each run.
+
+To reproduce results, run:
+```
+python -m neural_lam.plot_metrics \
+    --metrics_dir model_weights/metrics \
+    --forecasts_dir model_weights/forecasts \
+    --output_dir model_weights/plots
+```
+
+Example forecast animations are available online for [Run 1](https://vimeo.com/1138703695), [Run 2](https://vimeo.com/1138703709), [Run 3](https://vimeo.com/1138703719) and [Run 4](https://vimeo.com/1138703728).
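The `plot_metrics` command expects `metrics` and `forecasts` subdirectories under `model_weights`. A quick sanity check that the snapshot downloaded completely (this helper is not part of the repository):

```
from pathlib import Path

def missing_inputs(root="model_weights"):
    """Return which of the expected input subdirectories are absent."""
    return [name for name in ("metrics", "forecasts")
            if not (Path(root) / name).is_dir()]

missing = missing_inputs()
if missing:
    print(f"download incomplete, missing: {missing}")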
+
 ## Logging
 
 If you'd like to login and use [W&B](https://wandb.ai/), run:
@@ -113,53 +159,104 @@ See [docs](https://docs.wandb.ai/) for more details.
 
 ## Training
 
-The first stage of a probabilistic model can be trained something like this (where in later stages you add `kl_beta` and `crps_weight`):
+For a full list of training options see `python -m neural_lam.train_model --help`.
+
+The Graph-FM models were trained with commands like this:
+```
+python -m neural_lam.train_model \
+    --config_path data/vlasiator_config.yaml \
+    --model graphcast \
+    --graph simple \
+    --precision bf16-mixed \
+    --epochs 250 \
+    --scheduler_epochs 175 225 \
+    --lr 0.001 \
+    --batch_size 1 \
+    --hidden_dim 256 \
+    --processor_layers 12 \
+    --decode_dim 128 \
+    --ar_steps_train 4 \
+    --div_weight 10 \
+    --ar_steps_eval 4 \
+    --num_sanity_val_steps 0 \
+    --grad_checkpointing \
+    --num_workers 4 \
+    --num_nodes 4
+```
+The graph can be changed from `simple` to `multiscale` or `hierarchical`; for the `hierarchical` graph, change `--model` from `graphcast` to `graph_fm`. Distributed data parallel training is supported, and the above command runs on 4 compute nodes. Gradient checkpointing is turned on because training with many autoregressive steps increases memory consumption.
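`--scheduler_epochs` lists milestone epochs for the learning-rate schedule. Assuming it behaves like a standard multi-step decay (an assumption here; the 0.1 factor below is illustrative, not the repository's value), the effective rate drops at each milestone passed:

```
def lr_at_epoch(base_lr, epoch, milestones, factor=0.1):
    """Learning rate after multiplying by `factor` once per milestone
    the current epoch has passed (multi-step decay sketch)."""
    passed = sum(1 for m in milestones if epoch >= m)
    return base_lr * factor ** passed

# With --lr 0.001 --scheduler_epochs 175 225:
for epoch in (100, 200, 240):
    print(epoch, lr_at_epoch(0.001, epoch, [175, 225]))
```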
 
+The probabilistic Graph-EFM model can be trained like this:
 ```
 python -m neural_lam.train_model \
-    --config_path data/vlasiator_config.yaml \
-    --num_workers 2 \
-    --precision bf16-mixed \
-    --model graph_efm \
-    --graph multiscale \
-    --hidden_dim 64 \
-    --processor_layers 4 \
-    --ensemble_size 5 \
-    --batch_size 1 \
-    --lr 0.001 \
-    --kl_beta 0 \
-    --crps_weight 0 \
-    --ar_steps_train 1 \
-    --epochs 500 \
-    --val_interval 50 \
-    --ar_steps_eval 4 \
-    --val_steps_to_log 1 2 3
-```
-
-Distributed data parallel training is supported. Specify number of nodes with the `--node` argument. For a full list of training options see `python neural_lam.train_model --help`.
+    --config_path data/vlasiator_config.yaml \
+    --precision bf16-mixed \
+    --model graph_efm \
+    --graph multiscale \
+    --hidden_dim 256 \
+    --processor_layers 12 \
+    --decode_dim 128 \
+    --batch_size 1 \
+    --lr 0.001 \
+    --kl_beta 1 \
+    --ar_steps_train 4 \
+    --div_weight 1e8 \
+    --crps_weight 1e6 \
+    --epochs 250 \
+    --scheduler_epochs 100 150 200 225 \
+    --val_interval 5 \
+    --ar_steps_eval 1 \
+    --val_steps_to_log 1 \
+    --num_sanity_val_steps 0 \
+    --var_leads_val_plot '{"0":[1], "3":[1], "6":[1], "9":[1]}' \
+    --grad_checkpointing \
+    --num_workers 4 \
+    --num_nodes 4
+```
+It is also possible to train without `--scheduler_epochs`; it is sometimes more convenient to train each phase separately, for example to tune the loss weights. In that case, train the model manually in phases: 1. start with `--kl_beta 0` for autoencoder training, 2. turn it on with `--kl_beta 1` for 1-step ELBO training, 3. increase `--ar_steps_train` to a suitable value, 4. turn on `--crps_weight 1e6` once you see the CRPS loss decrease and the SSR increase, and 5. optionally apply `--div_weight 1e7` (which can also be turned on earlier).
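The CRPS term scores the whole ensemble against the target: for a scalar target y and members x_1..x_m, a common empirical estimator is mean_i |x_i - y| - (1 / 2m^2) * sum_{i,j} |x_i - x_j|. A pure-Python sketch of that estimator (an illustration, not the repository's implementation):

```
def crps_ensemble(members, y):
    """Empirical CRPS: mean distance of ensemble members to the target,
    minus half the mean pairwise distance between members."""
    m = len(members)
    mae = sum(abs(x - y) for x in members) / m
    spread = sum(abs(a - b) for a in members for b in members) / (m * m)
    return mae - 0.5 * spread

# A sharp, well-centred ensemble scores lower (better) than a wide one:
print(crps_ensemble([0.9, 1.0, 1.1], 1.0))
print(crps_ensemble([0.0, 1.0, 2.0], 1.0))
```

Lower is better; the spread term rewards ensembles that are not over-confident about a single outcome.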
 
 ## Evaluation
 
-Inference uses the same script as training, with the same choice of parameters, and some to have an extra look at like `--eval test`, `--ar_steps_eval 30` and `--n_example_pred 1` to evaluate 30 second forecasts on the test set with 1 example forecast plotted.
+Inference uses the same script as training, with some evaluation-specific flags such as `--eval test`, `--ar_steps_eval 30` and `--n_example_pred 1` to evaluate 30 second forecasts on the test set with 1 example forecast plotted. Graph-FM can be evaluated using:
+```
+python -m neural_lam.train_model \
+    --config_path data/vlasiator_config_4.yaml \
+    --model graphcast \
+    --graph simple \
+    --precision bf16-mixed \
+    --batch_size 1 \
+    --hidden_dim 256 \
+    --processor_layers 12 \
+    --decode_dim 128 \
+    --num_sanity_val_steps 0 \
+    --num_workers 4 \
+    --num_nodes 1 \
+    --eval test \
+    --ar_steps_eval 30 \
+    --n_example_pred 0 \
+    --load model_weights/graph_fm_simple.ckpt
+```
 
+Graph-EFM can be evaluated with an ensemble of 5 members (`--ensemble_size 5`) as follows:
 ```
 python -m neural_lam.train_model \
     --config_path data/vlasiator_config.yaml \
+    --precision bf16-mixed \
     --model graph_efm \
-    --graph hierarchical \
-    --num_nodes 1 \
-    --num_workers 2 \
-    --batch_size 1 \
-    --hidden_dim 64 \
-    --processor_layers 2 \
+    --graph simple \
+    --hidden_dim 256 \
+    --processor_layers 12 \
+    --decode_dim 128 \
     --ensemble_size 5 \
-    --ar_steps_eval 30 \
-    --precision bf16-mixed \
-    --n_example_pred 1 \
+    --batch_size 1 \
+    --num_sanity_val_steps 0 \
+    --num_workers 4 \
+    --num_nodes 1 \
     --eval test \
-    --load ckpt_path
+    --ar_steps_eval 30 \
+    --n_example_pred 0 \
+    --load model_weights/graph_efm_simple.ckpt
 ```
-where a model checkpoint from a given path given to the `--load` in `.ckpt` format. Already trained model weights are available on [Zenodo](https://zenodo.org/records/16930055).
+where `--load` takes the path to a model checkpoint in `.ckpt` format. These examples use the pretrained models from [Hugging Face](https://huggingface.co/deinal/spacecast-models), stored in a `model_weights` directory.
 
 ## Cite
 
@@ -184,4 +281,4 @@ ML4PS paper
 year={2025}
 }
 ```
-This work is based on code using a single run dataloader at commit: https://github.com/fmihpc/spacecast/commit/937094079c1364ec484d3d1647e758f4a388ad97.
+The workshop paper is based on code using a single run dataloader at commit: [fmihpc/spacecast@ce3cd1](https://github.com/fmihpc/spacecast/tree/ce3cd1).
