with `--graph` as `simple`, `multiscale` or `hierarchical`, and `--save` specifying the name of the output file.

## Models

Pretrained models can be downloaded from [Hugging Face](https://huggingface.co/deinal/spacecast-models) using:

```
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="deinal/spacecast-models",
    repo_type="model",
    local_dir="model_weights"
)
```

The download also includes metrics for the models and example forecasts for each run.

To reproduce results, run:

```
python -m neural_lam.plot_metrics \
    --metrics_dir model_weights/metrics \
    --forecasts_dir model_weights/forecasts \
    --output_dir model_weights/plots
```

Example forecast animations are available online for [Run 1](https://vimeo.com/1138703695), [Run 2](https://vimeo.com/1138703709), [Run 3](https://vimeo.com/1138703719) and [Run 4](https://vimeo.com/1138703728).

## Logging

If you'd like to log in and use [W&B](https://wandb.ai/), run `wandb login`. See [docs](https://docs.wandb.ai/) for more details.

## Training

For a full list of training options see `python -m neural_lam.train_model --help`.

The Graph-FM models were trained with commands like this:

```
python -m neural_lam.train_model \
    --config_path data/vlasiator_config.yaml \
    --model graphcast \
    --graph simple \
    --precision bf16-mixed \
    --epochs 250 \
    --scheduler_epochs 175 225 \
    --lr 0.001 \
    --batch_size 1 \
    --hidden_dim 256 \
    --processor_layers 12 \
    --decode_dim 128 \
    --ar_steps_train 4 \
    --div_weight 10 \
    --ar_steps_eval 4 \
    --num_sanity_val_steps 0 \
    --grad_checkpointing \
    --num_workers 4 \
    --num_nodes 4
```

The graph can be changed from `simple` to `multiscale` or `hierarchical`. For the `hierarchical` graph, change `--model` from `graphcast` to `graph_fm`. Distributed data parallel training is supported; the above script runs on 4 compute nodes. Gradient checkpointing is also turned on, as training with many autoregressive steps increases memory consumption.

The probabilistic Graph-EFM model can be trained like this:

```
python -m neural_lam.train_model \
    --config_path data/vlasiator_config.yaml \
    --num_workers 2 \
    --precision bf16-mixed \
    --model graph_efm \
    --graph multiscale \
    --hidden_dim 64 \
    --processor_layers 4 \
    --ensemble_size 5 \
    --batch_size 1 \
    --lr 0.001 \
    --kl_beta 0 \
    --crps_weight 0 \
    --ar_steps_train 1 \
    --epochs 500 \
    --val_interval 50 \
    --ar_steps_eval 4 \
    --val_steps_to_log 1 2 3
```

It is also possible to train without `--scheduler_epochs`. Sometimes it is more convenient to train each phase separately, for example to tune the loss weights. In that case, manually train the model in phases: 1. keep `--kl_beta 0` for autoencoder training, 2. turn it on with `--kl_beta 1` for 1-step ELBO training, 3. increase `--ar_steps_train` to a suitable value such as 4, 4. turn on `--crps_weight 1e6` once you see a decrease in the CRPS loss and an increase in SSR, and 5. optionally apply `--div_weight 1e7` (which can also be turned on earlier).

## Evaluation

Inference uses the same scripts as training, plus some evaluation-specific flags like `--eval test`, `--ar_steps_eval 30` and `--n_example_pred 1`, which evaluate 30 second forecasts on the test set with 1 example forecast plotted. Graph-FM can be evaluated using:

```
python -m neural_lam.train_model \
    --config_path data/vlasiator_config_4.yaml \
    --model graphcast \
    --graph simple \
    --precision bf16-mixed \
    --batch_size 1 \
    --hidden_dim 256 \
    --processor_layers 12 \
    --decode_dim 128 \
    --num_sanity_val_steps 0 \
    --num_workers 4 \
    --num_nodes 1 \
    --eval test \
    --ar_steps_eval 30 \
    --n_example_pred 0 \
    --load model_weights/graph_fm_simple.ckpt
```

Graph-EFM can be evaluated with an ensemble of 5 members (`--ensemble_size 5`) as follows:

```
python -m neural_lam.train_model \
    --config_path data/vlasiator_config.yaml \
    --precision bf16-mixed \
    --model graph_efm \
    --graph simple \
    --hidden_dim 256 \
    --processor_layers 12 \
    --decode_dim 128 \
    --ensemble_size 5 \
    --batch_size 1 \
    --num_sanity_val_steps 0 \
    --num_workers 4 \
    --num_nodes 1 \
    --eval test \
    --ar_steps_eval 30 \
    --n_example_pred 0 \
    --load model_weights/graph_efm_simple.ckpt
```

where a model checkpoint in `.ckpt` format is passed to `--load`. These examples use the pretrained models from [Hugging Face](https://huggingface.co/deinal/spacecast-models), stored in a `model_weights` directory.

## Cite

ML4PS paper:

```
...
  year={2025}
}
```

The workshop paper is based on code using a single run dataloader at commit: [fmihpc/spacecast@ce3cd1](https://github.com/fmihpc/spacecast/tree/ce3cd1).