1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -72,6 +72,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 experimental DiT architecture
 - Improved Transolver training recipe's configuration for checkpointing and normalization.
 - Bumped `multi-storage-client` version to 0.33.0 with rust client.
+- Improved configuration for DLWP Healpix (checkpoint directory) and GraphCast (W&B settings).
 
 ### Fixed

1 change: 1 addition & 0 deletions examples/weather/dlwp_healpix/configs/config.yaml
@@ -22,6 +22,7 @@ defaults:
 
 experiment_name: ${now:%Y-%m-%d}/${now:%H-%M-%S}
 output_dir: outputs/${experiment_name}
+checkpoint_dir: ${output_dir}/tensorboard/checkpoints
 # checkpoints names are in the form training-state-<name>.mdlus
 checkpoint_name: last
 load_weights_only: false

@@ -22,6 +22,7 @@ defaults:
 
 experiment_name: ${now:%Y-%m-%d}/${now:%H-%M-%S}
 output_dir: outputs/${experiment_name}
+checkpoint_dir: ${output_dir}/tensorboard/checkpoints
 # checkpoints names are in the form training-state-<name>.mdlus
 checkpoint_name: last
 load_weights_only: false

@@ -22,6 +22,7 @@ defaults:
 
 experiment_name: ${now:%Y-%m-%d}/${now:%H-%M-%S}
 output_dir: outputs/${experiment_name}
+checkpoint_dir: ${output_dir}/tensorboard/checkpoints
 # checkpoints names are in the form training-state-<name>.mdlus
 checkpoint_name: last
 load_weights_only: false

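The hunks above all add the same `checkpoint_dir` default, resolved from `output_dir` via OmegaConf interpolation. Below is a minimal sketch of how that resolution plays out; the `experiment_name` value is made up for illustration (the real one comes from Hydra's `${now:...}` resolver):

```python
from omegaconf import OmegaConf

# Illustrative values only; not taken from the actual configs.
cfg = OmegaConf.create(
    {
        "experiment_name": "2024-01-01/12-00-00",
        "output_dir": "outputs/${experiment_name}",
        "checkpoint_dir": "${output_dir}/tensorboard/checkpoints",
    }
)

# Resolve the interpolations to see the effective checkpoint directory.
print(OmegaConf.to_container(cfg, resolve=True)["checkpoint_dir"])
# outputs/2024-01-01/12-00-00/tensorboard/checkpoints
```

Because `checkpoint_dir` is now a top-level key, it can be overridden on its own (e.g. via a standard Hydra command-line override) without changing `output_dir`.
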
1 change: 1 addition & 0 deletions examples/weather/dlwp_healpix/configs/trainer/default.yaml
@@ -26,3 +26,4 @@ early_stopping_patience: null
 amp_mode: "fp16"
 graph_mode: "train_eval"
 output_dir: ${output_dir}
+checkpoint_dir: ${checkpoint_dir}

1 change: 1 addition & 0 deletions examples/weather/dlwp_healpix/configs/trainer/dlom.yaml
@@ -26,4 +26,5 @@ early_stopping_patience: null
 amp_mode: "fp16"
 graph_mode: "train_eval"
 output_dir: ${output_dir}
+checkpoint_dir: ${checkpoint_dir}
 max_norm: 0.25

1 change: 1 addition & 0 deletions examples/weather/dlwp_healpix/configs/trainer/dlwp.yaml
@@ -26,4 +26,5 @@ early_stopping_patience: null
 amp_mode: "fp16"
 graph_mode: "train_eval"
 output_dir: ${output_dir}
+checkpoint_dir: ${checkpoint_dir}
 max_norm: 0.25

@@ -18,5 +18,4 @@ _target_: torch.optim.lr_scheduler.CosineAnnealingLR
 optimizer: ${model.optimizer}
 T_max: ${trainer.max_epochs}
 eta_min: 4e-5
-last_epoch: -1
-verbose: false
+last_epoch: -1

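The scheduler hunk above drops the `verbose` key. A minimal sketch of the resulting `CosineAnnealingLR` instantiation; the model, optimizer, and `T_max` value are placeholders, not taken from the recipe:

```python
import torch

model = torch.nn.Linear(4, 4)  # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# Mirrors the YAML above: T_max comes from ${trainer.max_epochs} (300 is a placeholder).
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=300, eta_min=4e-5, last_epoch=-1
)
```
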
17 changes: 11 additions & 6 deletions examples/weather/dlwp_healpix/train.py
@@ -99,18 +99,23 @@ def train(cfg):
     iteration = 0
     epochs_since_improved = 0
 
+    # Determine checkpoint directory
+    if cfg.get("checkpoint_dir") is None:
+        # Fallback to default structure if not specified in config
+        cfg.checkpoint_dir = str(
+            Path(cfg.get("output_dir"), "tensorboard", "checkpoints")
+        )
+
+    checkpoint_dir = Path(cfg.checkpoint_dir)
+
     # Prepare training under consideration of checkpoint if given
     if cfg.get("checkpoint_name", None) is not None:
         checkpoint_path = Path(
-            cfg.get("output_dir"),
-            "tensorboard",
-            "checkpoints",
+            checkpoint_dir,
             "training-state-" + cfg.get("checkpoint_name") + ".mdlus",
         )
         optimizer_path = Path(
-            cfg.get("output_dir"),
-            "tensorboard",
-            "checkpoints",
+            checkpoint_dir,
             "optimizer-state-" + cfg.get("checkpoint_name") + ".ckpt",
         )
         if checkpoint_path.exists():

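A self-contained sketch of the fallback added in `train(cfg)` above, showing that a config which predates the new `checkpoint_dir` key still resolves to the old `<output_dir>/tensorboard/checkpoints` layout; the `output_dir` value here is made up:

```python
from pathlib import Path
from omegaconf import OmegaConf

# Older-style config without a checkpoint_dir key (illustrative values).
cfg = OmegaConf.create(
    {"output_dir": "outputs/2024-01-01/12-00-00", "checkpoint_name": "last"}
)

# Same fallback as in train.py: default to <output_dir>/tensorboard/checkpoints.
if cfg.get("checkpoint_dir") is None:
    cfg.checkpoint_dir = str(Path(cfg.get("output_dir"), "tensorboard", "checkpoints"))

checkpoint_dir = Path(cfg.checkpoint_dir)
print(checkpoint_dir / f"training-state-{cfg.checkpoint_name}.mdlus")
# outputs/2024-01-01/12-00-00/tensorboard/checkpoints/training-state-last.mdlus
```
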
7 changes: 6 additions & 1 deletion examples/weather/dlwp_healpix/trainer.py
@@ -53,6 +53,7 @@ def __init__(
         device: torch.device = torch.device("cpu"),
         output_dir: str = "/outputs/",
         max_norm: float = None,
+        checkpoint_dir: str = None,
     ):
         """
         Constructor.
@@ -105,6 +106,10 @@ def __init__(
             num_shards=self.dist.world_size, shard_id=self.dist.rank
         )
         self.output_dir_tb = os.path.join(output_dir, "tensorboard")
+        if checkpoint_dir is None:
+            self.checkpoint_dir = os.path.join(self.output_dir_tb, "checkpoints")
+        else:
+            self.checkpoint_dir = checkpoint_dir
 
         # set the other parameters
         self.optimizer = optimizer
@@ -524,7 +529,7 @@ def fit(
                 iteration,
                 validation_error,
                 epochs_since_improved,
-                self.output_dir_tb,
+                self.checkpoint_dir,
             ),
         )
         thread.start()

5 changes: 1 addition & 4 deletions examples/weather/dlwp_healpix/utils.py
@@ -51,10 +51,7 @@ def write_checkpoint(
     :param dst_path: Path where the checkpoint is written to
     :param keep_n_checkpoints: Number of best checkpoints that will be saved (worse checkpoints are overwritten)
     """
-    root_path = os.path.join(
-        dst_path,
-        "checkpoints",
-    )
+    root_path = dst_path
     # root_path = os.path.dirname(ckpt_dst_path)
     ckpt_dst_path = os.path.join(
         root_path,

2 changes: 2 additions & 0 deletions examples/weather/graphcast/conf/config.yaml
@@ -117,6 +117,8 @@ val_freq: 5 # Frequency (iterations) for performing light-weight v
 # │ Logging and Monitoring │
 # └───────────────────────────────────────────┘
 
+wb_entity: PhysicsNeMo
+wb_project: GraphCast
 wb_mode: online # Weights and Biases mode ["online", "offline", "disabled"]. If you don’t have a Weights and Biases API key, set this to "disabled".
 watch_model: false # If true, records the model parameter gradients through Weights and Biases.
 
4 changes: 2 additions & 2 deletions examples/weather/graphcast/train_graphcast.py
@@ -330,8 +330,8 @@ def main(cfg: DictConfig) -> None:
     # initialize loggers
     if dist.rank == 0:
         initialize_wandb(
-            project="GraphCast",
-            entity="PhysicsNeMo",
+            project=cfg.get("wb_project", "GraphCast"),
+            entity=cfg.get("wb_entity", "PhysicsNeMo"),
             name=f"GraphCast-{HydraConfig.get().job.name}",
             group="group",
             mode=cfg.wb_mode,
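The new `wb_project`/`wb_entity` keys are read with defaults, so GraphCast configs that do not define them keep the previously hard-coded values. A small sketch of that fallback; the config values here are illustrative:

```python
from omegaconf import OmegaConf

# Config without the new keys: the .get() defaults match the old hard-coded values.
old_cfg = OmegaConf.create({"wb_mode": "disabled"})
print(old_cfg.get("wb_project", "GraphCast"))   # GraphCast
print(old_cfg.get("wb_entity", "PhysicsNeMo"))  # PhysicsNeMo

# Config carrying the new keys, e.g. pointed at a private W&B team.
new_cfg = OmegaConf.create({"wb_project": "my-project", "wb_entity": "my-team"})
print(new_cfg.get("wb_project", "GraphCast"))   # my-project
print(new_cfg.get("wb_entity", "PhysicsNeMo"))  # my-team
```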