Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
113 commits
Select commit Hold shift + click to select a range
b98d074
Partially revised config; model is still missing but proper setup of …
clessig Dec 30, 2025
bb12d5a
Changes necessary due to changed position of time keys and of run_id
clessig Dec 30, 2025
d510956
Handling of multiple loss terms / target_aux_calculators and non-Loss…
clessig Dec 30, 2025
793578e
Changed position of run_id in config
clessig Dec 30, 2025
6eabd27
Add function to extract batch size from mode_cfg
clessig Dec 30, 2025
d350b38
Changed position of run_id in config
clessig Dec 30, 2025
78b17af
Changes due to revised config. Also proper handling of target_aux_cal…
clessig Dec 30, 2025
d8a1291
Revised config structure, in particular for losses, and related changes
clessig Dec 30, 2025
0d4e471
Add missing copyright and minor changed to to_device()
clessig Dec 30, 2025
b99b5c9
Moved sanity checking from trainer here. Also learning_rate sub_part …
clessig Dec 30, 2025
9ef940b
Minor cleanups
clessig Dec 30, 2025
868e595
Changes due to changed structure of losses in config
clessig Dec 30, 2025
f005ef0
Changes due to changed structure of losses in config
clessig Dec 30, 2025
7b1d189
Minor changes due to changed position of run_id in config
clessig Dec 30, 2025
53eb0d0
Minor changes to accomodate new config, in particular target_aux_calc…
clessig Dec 30, 2025
cdbb696
Support batch_size > 1. Clean up of various smaller parts
clessig Dec 30, 2025
0b99f3e
Clean up and implementation for batch_size > 1.
clessig Dec 30, 2025
7d1226f
Fix to sharding problem with FSDP2
clessig Dec 30, 2025
0ca381d
Removed scatter offset computation which now happens on the fly in th…
clessig Dec 30, 2025
4d67ad2
Changes for revised config, simplify overall where possible
clessig Dec 30, 2025
66c83a2
Fix issues with source-target sample generation and matching. Work in…
clessig Dec 30, 2025
fff5749
Linting
clessig Dec 30, 2025
f28874b
Linting
clessig Dec 30, 2025
192930a
Linting
clessig Dec 30, 2025
32243f3
Linting
clessig Dec 30, 2025
0d11f87
Merge branch 'develop' of github.com:ecmwf/WeatherGenerator into cles…
clessig Dec 30, 2025
0b900d3
Type hint
clessig Dec 30, 2025
132a2be
Linting
clessig Dec 30, 2025
070e859
Linting
clessig Dec 31, 2025
39d02ef
Linting
clessig Dec 31, 2025
6413c0f
Renamed loss keys for consistency
clessig Dec 31, 2025
a93f7a2
implement reader merge
iluise Jan 7, 2026
a6ed021
Long list of fixes and improvements
clessig Jan 8, 2026
ed80885
Enabled support for minimal configs without rate
clessig Jan 8, 2026
7b3cf26
Merge branch 'develop' of github.com:ecmwf/WeatherGenerator into cles…
clessig Jan 8, 2026
1d7174d
Fixed validation. validation_io still broken
clessig Jan 8, 2026
6f076df
Fixed linting
clessig Jan 8, 2026
c622505
Fixed problem with target filtering for loss computation for SSL losses
clessig Jan 8, 2026
ce95b07
working version of merge reader
iluise Jan 8, 2026
fd9f64a
linter
iluise Jan 8, 2026
aaad059
lint
iluise Jan 8, 2026
cf34e92
fix lead time
iluise Jan 8, 2026
6bf0dc5
Merge branch 'develop' into iluise/fix-lead-time
iluise Jan 8, 2026
ceca952
Re-instantiated per loss-fct source/target correspondences. Introduce…
clessig Jan 8, 2026
4aa11ea
Fixed problem with undefined variable
clessig Jan 9, 2026
eedb67c
Revised config
clessig Jan 9, 2026
a22e1c0
Fixed bug with forecasting
clessig Jan 9, 2026
96b65b6
Added sanity check for config
clessig Jan 9, 2026
c2da8b1
Fix bug with duplicate targets
clessig Jan 9, 2026
47263de
Linting
clessig Jan 9, 2026
10b7a28
Fixed problem when losses is not specified in validation config
clessig Jan 9, 2026
1161923
Fix DINOv2
sophie-xhonneux Jan 9, 2026
fd4ac9e
Removed temporary patches; fixed properly in 10b7a28
clessig Jan 10, 2026
8131d5e
Linting
clessig Jan 10, 2026
b013066
Patched validation IO. Needs to be fixed properly.
clessig Jan 10, 2026
f4c8b24
Removed unused function
clessig Jan 10, 2026
8bedec6
Improved variable naming
clessig Jan 10, 2026
0264086
Improved encapsulation of functionality: total_batch_size
clessig Jan 10, 2026
65a52aa
Fixed broken inference
clessig Jan 10, 2026
2061fb2
Fixed problem with test where incorrect config was used
clessig Jan 10, 2026
5b48b36
Fixed processing and handling of spoof flag in loss calculation
clessig Jan 10, 2026
084039c
Fixed problem with pure masking where forecast_steps were 0. Removed …
clessig Jan 10, 2026
cf731d4
Fixed bug when output_streams is specified explicitly
clessig Jan 10, 2026
04ec6d4
Corrected config param for number of samples
clessig Jan 10, 2026
9bd3b25
Fixed bug in handling of spoof weight
clessig Jan 10, 2026
802a971
Improved clarity of logging statements
clessig Jan 10, 2026
4866cb6
Improved logging msgs
clessig Jan 10, 2026
15f5b60
Fix sinkhorn knopp
sophie-xhonneux Jan 11, 2026
401593e
Fix sinkhorn in multi-GPU mode
sophie-xhonneux Jan 11, 2026
5e0530d
Removed some old comments
clessig Jan 11, 2026
0bbba9b
Fixed inference overwrites
clessig Jan 11, 2026
36655a7
Merge branch 'clessig/develop/fix_config_1534' of github.com:ecmwf/We…
clessig Jan 11, 2026
0f76a40
Fixing empty output when masking
clessig Jan 11, 2026
a26622a
Intermediate stage to re-enable integration test
clessig Jan 11, 2026
5976ed2
Adjusted thresholds
clessig Jan 11, 2026
3169a04
Renaming
clessig Jan 11, 2026
8fad4ff
Removing old config files
clessig Jan 11, 2026
048d157
Adding copyright
clessig Jan 11, 2026
176b678
Revised default_config. This is a minimal example config for simple t…
clessig Jan 11, 2026
057d01d
Changed multiprocessing param
clessig Jan 11, 2026
88acc44
Adapation for new position of multiprocessing param
clessig Jan 11, 2026
8948af9
Adding example config that combines an SSL and physical loss term
clessig Jan 11, 2026
75f1e67
More cleanup
clessig Jan 11, 2026
bcc2360
Restoring some default values
clessig Jan 11, 2026
184f585
Merge branch 'develop' of github.com:ecmwf/WeatherGenerator into cles…
clessig Jan 11, 2026
a60aab5
Restoring default for decoder_type
clessig Jan 11, 2026
c335d97
Merge branch 'develop' into iluise/fix-lead-time
iluise Jan 12, 2026
33755c7
update to develop
iluise Jan 12, 2026
ecbd897
lint
iluise Jan 12, 2026
c714be2
Merge branch 'iluise/fix-lead-time' of github.com:ecmwf/WeatherGenera…
clessig Jan 12, 2026
da42ad6
Fixed problem where parameter was expected in old config place
clessig Jan 12, 2026
e4e3922
Fixed linting
clessig Jan 12, 2026
d741e2f
Simplified interface
clessig Jan 12, 2026
bcd561f
Re-enabled forecast step and location weighting
clessig Jan 12, 2026
8b0bb12
Merge branch 'develop' of github.com:ecmwf/WeatherGenerator into cles…
clessig Jan 13, 2026
72d92f9
Merge branches 'develop' and 'clessig/develop/fix_config_1534' of git…
clessig Jan 13, 2026
41c85a3
Merge branch 'develop' of github.com:ecmwf/WeatherGenerator into cles…
clessig Jan 13, 2026
a4f3eed
Linting
clessig Jan 13, 2026
78a7cd4
Using new option to have validate_before_training as an int arg that …
clessig Jan 13, 2026
4bfea4b
Added option to have validate_before_training as int argument (specif…
clessig Jan 13, 2026
9123bae
Refactored correspondence parsing
clessig Jan 13, 2026
a218319
Sophiex/dev/teacher overrides (#1557)
sophie-xhonneux Jan 13, 2026
841e027
Fixed missing default value
clessig Jan 13, 2026
5f2cb75
Bilinear decoder: adapt code for batchsize > 1 (#1592)
sophie-xhonneux Jan 13, 2026
71d43df
Changed defaults
clessig Jan 13, 2026
80e9181
Linting
clessig Jan 13, 2026
c4bd337
Fixed linting issue
clessig Jan 13, 2026
aae83c0
Reverting to ERA5-only as default
clessig Jan 13, 2026
eb235de
Fixed problem with train_continue
clessig Jan 13, 2026
b3cba32
Adding filtering of config based on enabled/disabled
clessig Jan 14, 2026
6541f06
Merge branch 'develop' of github.com:ecmwf/WeatherGenerator into cles…
clessig Jan 14, 2026
0e8c739
Fixed very hacking to get some plots
clessig Jan 14, 2026
2ef1eb9
Merge branch 'develop' of github.com:ecmwf/WeatherGenerator into cles…
clessig Jan 14, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/weathergen/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from weathergen.train.trainer_base import TrainerBase
from weathergen.train.utils import (
extract_batch_metadata,
filter_config_by_enabled,
get_batch_size_from_config,
get_target_idxs_from_cfg,
)
Expand Down Expand Up @@ -98,10 +99,19 @@ def init(self, cf: Config, devices):

self.freeze_modules = cf.get("freeze_modules", "")

# keys to filter for enabled/disabled
keys_to_filter = ["losses", "model_input", "target_input"]

# get training config and remove disabled options (e.g. because of overrides)
self.training_cfg = cf.get("training_config")
self.training_cfg = filter_config_by_enabled(self.training_cfg, keys_to_filter)

# validation and test configs are training configs, updated by specified keys
self.validation_cfg = merge_configs(self.training_cfg, cf.get("validation_config", {}))
self.validation_cfg = filter_config_by_enabled(self.validation_cfg, keys_to_filter)
# test cfg is derived from validation cfg with specified keys overwritten
self.test_cfg = merge_configs(self.validation_cfg, cf.get("test_config", {}))
self.test_cfg = filter_config_by_enabled(self.test_cfg, keys_to_filter)

# batch sizes
self.batch_size_per_gpu = get_batch_size_from_config(self.training_cfg)
Expand Down
17 changes: 16 additions & 1 deletion src/weathergen/train/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,21 @@ def get_target_idxs_from_cfg(cfg, loss_name) -> list[int] | None:

tc = [v.get("target_source_correspondence") for _, v in cfg.losses[loss_name].loss_fcts.items()]
tc = [list(t.keys()) for t in tc if t is not None]
target_idxs = list(set([i for t in tc for i in t])) if len(tc) > 0 else None
target_idxs = list(set([int(i) for t in tc for i in t])) if len(tc) > 0 else None

return target_idxs


def filter_config_by_enabled(cfg, keys):
"""
Filtered disabled entries from config
"""

for key in keys:
filtered = {}
for k, v in cfg.get(key, {}).items():
if v.get("enabled", True):
filtered[k] = v
cfg[key] = filtered

return cfg
21 changes: 11 additions & 10 deletions src/weathergen/utils/plot_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,12 +206,13 @@ def plot_lr(
if run_data.train.is_empty():
continue
run_id = run_data.run_id
x_col = next(filter(lambda c: x_axis in c, run_data.train.columns))
data_cols = list(filter(lambda c: "learning_rate" in c, run_data.train.columns))
# x_col = next(filter(lambda c: x_axis in c, run_data.train.columns))
# data_cols = list(filter(lambda c: "learning_rate" in c, run_data.train.columns))

plt.plot(
run_data.train[x_col],
run_data.train[data_cols],
# run_data.train[x_col],
# run_data.train[data_cols],
run_data.train,
linestyle,
color=colors[j % len(colors)],
)
Expand Down Expand Up @@ -384,7 +385,7 @@ def plot_loss_per_stream(
if run_data_mode.is_empty():
continue
# find the col of the request x-axis (e.g. samples)
x_col = next(filter(lambda c: x_axis in c, run_data_mode.columns))
# x_col = next(filter(lambda c: x_axis in c, run_data_mode.columns))
# find the cols of the requested metric (e.g. mse) for all streams
# TODO: fix captialization
data_cols = filter(
Expand All @@ -393,11 +394,11 @@ def plot_loss_per_stream(
)

for col in data_cols:
x_vals = np.array(run_data_mode[x_col])
# x_vals = np.array(run_data_mode[x_col])
y_data = np.array(run_data_mode[col])

plt.plot(
x_vals,
# x_vals,
y_data,
linestyle,
color=colors[j % len(colors)],
Expand Down Expand Up @@ -512,7 +513,7 @@ def plot_loss_per_run(
alpha = 0.35 if "train" in mode else alpha
run_data_mode = run_data.by_mode(mode)

x_col = [c for _, c in enumerate(run_data_mode.columns) if x_axis in c][0]
# x_col = [c for _, c in enumerate(run_data_mode.columns) if x_axis in c][0]
# find the cols of the requested metric (e.g. mse) for all streams
data_cols = [c for _, c in enumerate(run_data_mode.columns) if err in c]

Expand All @@ -525,11 +526,11 @@ def plot_loss_per_run(
if run_data_mode[col].shape[0] == 0:
continue

x_vals = np.array(run_data_mode[x_col])
# x_vals = np.array(run_data_mode[x_col])
y_data = np.array(run_data_mode[col])

plt.plot(
x_vals,
# x_vals,
y_data,
linestyle,
color=colors[j % len(colors)],
Expand Down
52 changes: 30 additions & 22 deletions src/weathergen/utils/train_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,20 +156,21 @@ def read(run_id: str, model_path: str = None, mini_epoch: int = -1) -> Metrics:
cols_train = ["dtime", "samples", "mse", "lr"]
cols1 = [_weathergen_timestamp, "num_samples", "loss_avg_mean", "learning_rate"]
for si in cf.streams:
for lf in cf.loss_fcts:
for lf, _ in cf.training_config.losses.items():
cols1 += [_key_loss(si["name"], lf[0])]
cols_train += [
si["name"].replace(",", "").replace("/", "_").replace(" ", "_") + ", " + lf[0]
]
with_stddev = [("stats" in lf) for lf in cf.loss_fcts]
if with_stddev:
for si in cf.streams:
cols1 += [_key_stddev(si["name"])]
cols_train += [
si["name"].replace(",", "").replace("/", "_").replace(" ", "_")
+ ", "
+ "stddev"
si["name"].replace(",", "").replace("/", "_").replace(" ", "_") + ", " + lf
]
# with_stddev = [("stats" in lf) for lf, _ in cf.training_config.losses.items()]
# if with_stddev:
# for si in cf.streams:
# cols1 += [_key_stddev(si["name"])]
# cols_train += [
# si["name"].replace(",", "").replace("/", "_").replace(" ", "_")
# + ", "
# + "stddev"
# ]

# read training log data
try:
with open(fname_log_train, "rb") as f:
Expand Down Expand Up @@ -214,20 +215,25 @@ def read(run_id: str, model_path: str = None, mini_epoch: int = -1) -> Metrics:
cols_val = ["dtime", "samples"]
cols2 = [_weathergen_timestamp, "num_samples"]
for si in cf.streams:
for lf in cf.loss_fcts_val:
cfg = (
cf.validation_config
if cf.validation_config.get("losses") is not None
else cf.training_config
)
for lf, _ in cfg.losses.items():
cols_val += [
si["name"].replace(",", "").replace("/", "_").replace(" ", "_") + ", " + lf[0]
si["name"].replace(",", "").replace("/", "_").replace(" ", "_") + ", " + lf
]
cols2 += [_key_loss(si["name"], lf[0])]
with_stddev = [("stats" in lf) for lf in cf.loss_fcts_val]
if with_stddev:
for si in cf.streams:
cols2 += [_key_stddev(si["name"])]
cols_val += [
si["name"].replace(",", "").replace("/", "_").replace(" ", "_")
+ ", "
+ "stddev"
]
# with_stddev = [("stats" in lf) for lf in cf.loss_fcts_val]
# if with_stddev:
# for si in cf.streams:
# cols2 += [_key_stddev(si["name"])]
# cols_val += [
# si["name"].replace(",", "").replace("/", "_").replace(" ", "_")
# + ", "
# + "stddev"
# ]
# read validation log data
try:
with open(fname_log_val, "rb") as f:
Expand Down Expand Up @@ -370,6 +376,8 @@ def clean_df(df, columns: list[str] | None):
idcs = [i for i in range(len(columns)) if columns[i] == "loss_avg_mean"]
if len(idcs) > 0:
columns[idcs[0]] = "loss_avg_0_mean"
# TODO, TODO, TODO
columns = ["LossPhysical.loss_avg"]
df = df.select(columns)
# Remove all rows where all columns are null
df = df.filter(~pl.all_horizontal(pl.col(c).is_null() for c in columns))
Expand Down