Skip to content
13 changes: 7 additions & 6 deletions config/default_config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ pred_mlp_adaln: True

# number of steps offset applied to first target window; if set to zero and forecast_steps=0 then
# one is training an auto-encoder
forecast_offset : 1
forecast_offset : 0
forecast_delta_hrs: 0
forecast_steps: 1
forecast_policy: "fixed"
Expand All @@ -50,6 +50,7 @@ fe_num_heads: 16
fe_dropout_rate: 0.1
fe_with_qk_lnorm: True
fe_diffusion_model: True
repeat_data: True
impute_latent_noise_std: 0.0 # 1e-4
# Diffusion related parameters
frequency_embedding_dim: 256
Expand Down Expand Up @@ -150,8 +151,8 @@ training_config:
loss: "LatentDiffusionLoss" # placeholder

num_mini_epochs: 32
samples_per_mini_epoch: 4096
samples_per_validation: 512
samples_per_mini_epoch: 2
samples_per_validation: 2

shuffle: True

Expand All @@ -174,9 +175,9 @@ log_grad_norms: False

# start_date: 197901010000
start_date: 201401010000
end_date: 202012310000
start_date_val: 202101010000
end_date_val: 202201010000
end_date: 201401011200
start_date_val: 201401010000
end_date_val: 201401011200
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's set the end_dates for train and val to 201401011800, such that we can keep (fsm + 1) below.

len_hrs: 6
step_hrs: 6
input_window_steps: 1
Expand Down
6 changes: 4 additions & 2 deletions src/weathergen/datasets/data_reader_anemoi.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,10 @@ def __init__(
ds: Dataset = anemoi_datasets.open_dataset(
ds0, **kwargs, start=tw_handler.t_start, end=tw_handler.t_end
)

period = np.timedelta64(ds.frequency)
if len(ds.dates) != 1:
period = np.timedelta64(ds.frequency)
else:
period = np.timedelta64(0, "s")
data_start_time = ds.dates[0]
data_end_time = ds.dates[-1]
assert data_start_time is not None and data_end_time is not None, (
Expand Down
17 changes: 15 additions & 2 deletions src/weathergen/datasets/multi_stream_data_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ def __init__(
self.forecast_policy = cf.forecast_policy

self.len = 100000000
self.samples_per_mini_epoch = samples_per_mini_epoch
self.repeat_data = cf.get("repeat_data", False)

self.streams_datasets: list[list[AnyDataReader]] = []
for _, stream_info in enumerate(cf.streams):
Expand Down Expand Up @@ -186,7 +188,12 @@ def __init__(

index_range = self.time_window_handler.get_index_range()
self.len = int(index_range.end - index_range.start)
self.len = min(self.len, samples_per_mini_epoch if samples_per_mini_epoch else self.len)
if not self.repeat_data:
self.len = min(self.len, samples_per_mini_epoch if samples_per_mini_epoch else self.len)
else:
assert samples_per_mini_epoch, "must specify samples_per_mini_epoch if repeat_data"
self.len = samples_per_mini_epoch

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if I understand correctly, but do we not want to run an epoch with e.g. samples_per_mini_epoch: 4096 while always repeating the same e.g. 4 samples? In this case we would need to introduce another config parameter, e.g. repeat_num_samples: 4 or repeat_num_idxs: [1,2,3,4] (to define specific indices — not sure if we would need this).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is addressed in line 278, but may not cover some edge cases.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe the underlying question is whether we need to adjust start_date and end_date in the config to shorten the dataset or if we can also repeat samples randomly from the entire dataset. We can discuss in 5 minutes :)

# adjust len to split loading across all workers and ensure it is multiple of batch_size
len_chunk = ((self.len // cf.world_size) // batch_size) * batch_size
self.len = min(self.len, len_chunk)
Expand Down Expand Up @@ -269,10 +276,16 @@ def reset(self):
idx_end = index_range.end
# native length of datasets, independent of mini_epoch length that has potentially been
# specified
forecast_len = (self.len_hrs * (fsm + 1)) // self.step_hrs
forecast_len = (self.len_hrs * (fsm)) // self.step_hrs # NOTE: why was it fsm +1?
idx_end -= forecast_len + self.forecast_offset

assert idx_end > 0, "dataset size too small for forecast range"
self.perms = np.arange(index_range.start, idx_end)
if self.repeat_data:
assert self.samples_per_mini_epoch == self.len, "length of sampler was set different from samples_per_mini_epoch –- aborting to avoid unintended effects"
assert self.len % len(self.perms) == 0, "length of permutations is not a multiple of length of available data –- aborting to avoid unintended effects"
self.perms = np.tile(self.perms, self.len // len(self.perms)) #TODO: maybe use samples_per_mini_epoch?

if self.shuffle:
self.perms = self.rng.permutation(self.perms)

Expand Down
3 changes: 3 additions & 0 deletions src/weathergen/model/model.py
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's remove this comment

Original file line number Diff line number Diff line change
Expand Up @@ -598,6 +598,9 @@ def forward(
(streams_data, _, target_coords_idxs, metadata) = batch

tokens, posteriors = self.encode(model_params=model_params, batch=batch)

#TODO: Check here that the tokens are always the same when overfitting to a single sample

if encode_only:
return tokens, posteriors

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def compute_loss(

loss_fsteps = torch.tensor(0.0, device=self.device, requires_grad=True)
ctr_fsteps = 0

for target_tokens, pred_tokens, fstep_loss_weight in zip(
target_tokens_all, pred_tokens_all, fstep_loss_weights, strict=True
):
Expand Down