
Commit dcadcc6

feat(weather): Improve configuration for DLWP Healpix and GraphCast examples
- Added a configurable checkpoint directory to the DLWP Healpix config and training script.
- Implemented Trainer logic to use the specified checkpoint directory.
- Updated utils.py to respect the exact checkpoint path.
- Made the Weights & Biases entity and project configurable in the GraphCast example.
1 parent cd2a314 commit dcadcc6

File tree: 12 files changed (+29, -13 lines)


CHANGELOG.md (1 addition, 0 deletions)

@@ -71,6 +71,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Support passing custom tokenizer, detokenizer, and attention `Module`s in
   experimental DiT architecture
 - Improved Transolver training recipe's configuration for checkpointing and normalization.
+- Improved configuration for DLWP Healpix (checkpoint directory) and GraphCast (W&B settings).
 - Bumped `multi-storage-client` version to 0.33.0 with rust client.
 
 ### Fixed

examples/weather/dlwp_healpix/configs/config.yaml (1 addition, 0 deletions)

@@ -22,6 +22,7 @@ defaults:
 
 experiment_name: ${now:%Y-%m-%d}/${now:%H-%M-%S}
 output_dir: outputs/${experiment_name}
+checkpoint_dir: ${output_dir}/tensorboard/checkpoints
 # checkpoints names are in the form training-state-<name>.mdlus
 checkpoint_name: last
 load_weights_only: false

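The new checkpoint_dir key keeps the previous layout as its default through OmegaConf interpolation, but it can now be redirected without touching output_dir. A minimal sketch of how the resolution behaves (not part of the commit; the hard-coded experiment_name stands in for the ${now:...} resolver Hydra supplies at runtime):

from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "experiment_name": "2024-01-01/12-00-00",  # stands in for ${now:...}
        "output_dir": "outputs/${experiment_name}",
        "checkpoint_dir": "${output_dir}/tensorboard/checkpoints",
    }
)

print(cfg.checkpoint_dir)  # outputs/2024-01-01/12-00-00/tensorboard/checkpoints

# Overriding checkpoint_dir no longer requires changing output_dir:
cfg.checkpoint_dir = "/scratch/my_checkpoints"
print(cfg.checkpoint_dir)  # /scratch/my_checkpoints
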
examples/weather/dlwp_healpix/configs/config_hpx32_coupled_dlom.yaml (1 addition, 0 deletions)

@@ -22,6 +22,7 @@ defaults:
 
 experiment_name: ${now:%Y-%m-%d}/${now:%H-%M-%S}
 output_dir: outputs/${experiment_name}
+checkpoint_dir: ${output_dir}/tensorboard/checkpoints
 # checkpoints names are in the form training-state-<name>.mdlus
 checkpoint_name: last
 load_weights_only: false

examples/weather/dlwp_healpix/configs/config_hpx32_coupled_dlwp.yaml (1 addition, 0 deletions)

@@ -22,6 +22,7 @@ defaults:
 
 experiment_name: ${now:%Y-%m-%d}/${now:%H-%M-%S}
 output_dir: outputs/${experiment_name}
+checkpoint_dir: ${output_dir}/tensorboard/checkpoints
 # checkpoints names are in the form training-state-<name>.mdlus
 checkpoint_name: last
 load_weights_only: false

examples/weather/dlwp_healpix/configs/trainer/default.yaml (1 addition, 0 deletions)

@@ -26,3 +26,4 @@ early_stopping_patience: null
 amp_mode: "fp16"
 graph_mode: "train_eval"
 output_dir: ${output_dir}
+checkpoint_dir: ${checkpoint_dir}

examples/weather/dlwp_healpix/configs/trainer/dlom.yaml (1 addition, 0 deletions)

@@ -26,4 +26,5 @@ early_stopping_patience: null
 amp_mode: "fp16"
 graph_mode: "train_eval"
 output_dir: ${output_dir}
+checkpoint_dir: ${checkpoint_dir}
 max_norm: 0.25

examples/weather/dlwp_healpix/configs/trainer/dlwp.yaml (1 addition, 0 deletions)

@@ -26,4 +26,5 @@ early_stopping_patience: null
 amp_mode: "fp16"
 graph_mode: "train_eval"
 output_dir: ${output_dir}
+checkpoint_dir: ${checkpoint_dir}
 max_norm: 0.25

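In the three trainer configs above, checkpoint_dir: ${checkpoint_dir} forwards the top-level key into the trainer group. A small sketch of how that interpolation resolves against the config root once Hydra composes the config (illustrative values, not from the commit):

from omegaconf import OmegaConf

# Simplified stand-in for the composed Hydra config.
cfg = OmegaConf.create(
    {
        "output_dir": "outputs/exp",
        "checkpoint_dir": "${output_dir}/tensorboard/checkpoints",
        "trainer": {"checkpoint_dir": "${checkpoint_dir}"},  # as in trainer/*.yaml
    }
)

# ${checkpoint_dir} resolves from the config root, so the trainer group
# picks up whatever the top-level key (or a CLI override) provides.
print(cfg.trainer.checkpoint_dir)  # outputs/exp/tensorboard/checkpoints
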
examples/weather/dlwp_healpix/train.py (11 additions, 6 deletions)

@@ -99,18 +99,23 @@ def train(cfg):
     iteration = 0
     epochs_since_improved = 0
 
+    # Determine checkpoint directory
+    if cfg.get("checkpoint_dir") is None:
+        # Fallback to default structure if not specified in config
+        cfg.checkpoint_dir = str(
+            Path(cfg.get("output_dir"), "tensorboard", "checkpoints")
+        )
+
+    checkpoint_dir = Path(cfg.checkpoint_dir)
+
     # Prepare training under consideration of checkpoint if given
     if cfg.get("checkpoint_name", None) is not None:
         checkpoint_path = Path(
-            cfg.get("output_dir"),
-            "tensorboard",
-            "checkpoints",
+            checkpoint_dir,
             "training-state-" + cfg.get("checkpoint_name") + ".mdlus",
         )
         optimizer_path = Path(
-            cfg.get("output_dir"),
-            "tensorboard",
-            "checkpoints",
+            checkpoint_dir,
             "optimizer-state-" + cfg.get("checkpoint_name") + ".ckpt",
         )
         if checkpoint_path.exists():

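The hunk above boils down to: use cfg.checkpoint_dir when it is set, otherwise fall back to the old <output_dir>/tensorboard/checkpoints layout, then build both state paths from that directory. A condensed sketch of the same precedence with a hypothetical resolve_checkpoint_paths helper (not a function in the example):

from pathlib import Path


def resolve_checkpoint_paths(cfg: dict, checkpoint_name: str):
    # An explicit checkpoint_dir wins; otherwise fall back to the legacy layout.
    ckpt_dir = Path(
        cfg.get("checkpoint_dir")
        or Path(cfg["output_dir"], "tensorboard", "checkpoints")
    )
    return (
        ckpt_dir / f"training-state-{checkpoint_name}.mdlus",
        ckpt_dir / f"optimizer-state-{checkpoint_name}.ckpt",
    )


print(resolve_checkpoint_paths({"output_dir": "outputs/run"}, "last")[0])
# outputs/run/tensorboard/checkpoints/training-state-last.mdlus
print(resolve_checkpoint_paths({"output_dir": "outputs/run", "checkpoint_dir": "/scratch/ckpts"}, "last")[0])
# /scratch/ckpts/training-state-last.mdlus
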
examples/weather/dlwp_healpix/trainer.py (6 additions, 1 deletion)

@@ -53,6 +53,7 @@ def __init__(
         device: torch.device = torch.device("cpu"),
         output_dir: str = "/outputs/",
         max_norm: float = None,
+        checkpoint_dir: str = None,
     ):
         """
         Constructor.
@@ -105,6 +106,10 @@ def __init__(
             num_shards=self.dist.world_size, shard_id=self.dist.rank
         )
         self.output_dir_tb = os.path.join(output_dir, "tensorboard")
+        if checkpoint_dir is None:
+            self.checkpoint_dir = os.path.join(self.output_dir_tb, "checkpoints")
+        else:
+            self.checkpoint_dir = checkpoint_dir
 
         # set the other parameters
         self.optimizer = optimizer
@@ -524,7 +529,7 @@ def fit(
                 iteration,
                 validation_error,
                 epochs_since_improved,
-                self.output_dir_tb,
+                self.checkpoint_dir,
             ),
         )
         thread.start()

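The constructor change mirrors the same precedence inside the Trainer itself, and the checkpoint-writer thread now receives self.checkpoint_dir instead of the tensorboard directory. A hypothetical, stripped-down stand-in (not the real Trainer class) showing just that precedence:

import os


class CheckpointDirOnlyTrainer:  # hypothetical stand-in, not the real Trainer
    def __init__(self, output_dir: str = "/outputs/", checkpoint_dir: str = None):
        self.output_dir_tb = os.path.join(output_dir, "tensorboard")
        if checkpoint_dir is None:
            # Default: keep checkpoints next to the tensorboard logs.
            self.checkpoint_dir = os.path.join(self.output_dir_tb, "checkpoints")
        else:
            # An explicit directory from the config takes precedence.
            self.checkpoint_dir = checkpoint_dir


print(CheckpointDirOnlyTrainer("outputs/run").checkpoint_dir)
# outputs/run/tensorboard/checkpoints
print(CheckpointDirOnlyTrainer("outputs/run", checkpoint_dir="/scratch/ckpts").checkpoint_dir)
# /scratch/ckpts
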
examples/weather/dlwp_healpix/utils.py (1 addition, 4 deletions)

@@ -51,10 +51,7 @@ def write_checkpoint(
     :param dst_path: Path where the checkpoint is written to
     :param keep_n_checkpoints: Number of best checkpoints that will be saved (worse checkpoints are overwritten)
     """
-    root_path = os.path.join(
-        dst_path,
-        "checkpoints",
-    )
+    root_path = dst_path
     # root_path = os.path.dirname(ckpt_dst_path)
     ckpt_dst_path = os.path.join(
         root_path,

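With write_checkpoint no longer appending a "checkpoints" component, whatever directory the caller passes (now self.checkpoint_dir) is used verbatim. A small before/after path sketch (hypothetical values; the training-state-<name>.mdlus pattern comes from the config comment above):

import os

dst_path = "/scratch/ckpts"  # hypothetical caller-provided checkpoint directory
name = "last"

old_root = os.path.join(dst_path, "checkpoints")  # previous behavior: sub-directory appended
new_root = dst_path                               # new behavior: path respected as-is

print(os.path.join(old_root, f"training-state-{name}.mdlus"))
# /scratch/ckpts/checkpoints/training-state-last.mdlus
print(os.path.join(new_root, f"training-state-{name}.mdlus"))
# /scratch/ckpts/training-state-last.mdlus
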