Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 185 additions & 0 deletions config/default_config_geoinfo.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
streams_directory: "./config/streams/era5_1deg/"

embed_orientation: "channels"
embed_local_coords: True
embed_centroids_local_coords: False
embed_size_centroids: 0
embed_unembed_mode: "block"
embed_dropout_rate: 0.1

target_cell_local_prediction: True

ae_local_dim_embed: 512
ae_local_num_blocks: 2
ae_local_num_heads: 16
ae_local_dropout_rate: 0.1
ae_local_with_qk_lnorm: True

ae_local_num_queries: 1
ae_local_queries_per_cell: False
ae_adapter_num_heads: 16
ae_adapter_embed: 128
ae_adapter_with_qk_lnorm: True
ae_adapter_with_residual: True
ae_adapter_dropout_rate: 0.1

ae_global_dim_embed: 512
ae_global_num_blocks: 8
ae_global_num_heads: 32
ae_global_dropout_rate: 0.1
ae_global_with_qk_lnorm: True
# TODO: switching to < 1 triggers triton-related issues.
# See https://github.com/ecmwf/WeatherGenerator/issues/1050
ae_global_att_dense_rate: 1.0
ae_global_block_factor: 64
ae_global_mlp_hidden_factor: 2

decoder_type: PerceiverIOCoordConditioning # CrossAttentionAdaNormConditioning
pred_adapter_kv: False
pred_self_attention: True
pred_dyadic_dims: False
pred_mlp_adaln: True

# number of steps offset applied to first target window; if set to zero and forecast_steps=0 then
# one is training an auto-encoder
forecast_offset : 0
forecast_delta_hrs: 0
forecast_steps: 0
forecast_policy: null
forecast_att_dense_rate: 1.0
fe_num_blocks: 0
fe_num_heads: 16
fe_dropout_rate: 0.1
fe_with_qk_lnorm: True
impute_latent_noise_std: 0.0 # 1e-4

healpix_level: 5

with_mixed_precision: True
with_flash_attention: True
compile_model: False
with_fsdp: True
attention_dtype: bf16
mlp_norm_eps: 1e-5
norm_eps: 1e-4

latent_noise_kl_weight: 0.0 # 1e-5
latent_noise_gamma: 2.0
latent_noise_saturate_encodings: 5
latent_noise_use_additive_noise: False
latent_noise_deterministic_latents: True

loss_fcts:
-
- "mse"
- 1.0
loss_fcts_val:
-
- "mse"
- 1.0

batch_size_per_gpu: 1
batch_size_validation_per_gpu: 1

# a regex that needs to fully match the name of the modules you want to freeze
# e.g. ".*ERA5" will match any module whose name ends in ERA5
# encoders and decoders that exist per stream have the stream name attached at the end
freeze_modules: ""

# whether to track the exponential moving average of weights for validation
validate_with_ema: True
ema_ramp_up_ratio: 0.09
ema_halflife_in_thousands: 1e-3

# training mode: "forecast" or "masking" (masked token modeling)
# for "masking" to train with auto-encoder mode, forecast_offset should be 0
training_mode: "masking"
# masking rate when training mode is "masking"; ignored in forecast mode
masking_rate: 0.6
# sample the masking rate (with normal distribution centered at masking_rate)
# note that a sampled masking rate leads to varying requirements
masking_rate_sampling: True
# sample a subset of all target points, useful e.g. to reduce memory requirements (also can specify per-stream)
sampling_rate_target: 1.0
# include a masking strategy here, currently only supporting "random", "block", "healpix", "channel", "causal" and "combination"
masking_strategy: "random"
# masking_strategy_config is a dictionary of additional parameters for the masking strategy
# required for "healpix" and "channel" masking strategies
# "healpix": requires healpix mask level to be specified with `hl_mask`
# "channel": requires "mode" to be specified, either "per_cell" or "global"
masking_strategy_config: {"strategies": ["random", "healpix", "channel"],
"probabilities": [0.34, 0.33, 0.33],
"hl_mask": 3, "mode": "per_cell",
"same_strategy_per_batch": false
}

num_mini_epochs: 32
samples_per_mini_epoch: 4096
samples_per_validation: 512
shuffle: True

lr_scaling_policy: "sqrt"
lr_start: 1e-6
lr_max: 5e-5
lr_final_decay: 1e-6
lr_final: 0.0
lr_steps_warmup: 512
lr_steps_cooldown: 512
lr_policy_warmup: "cosine"
lr_policy_decay: "constant"
lr_policy_cooldown: "linear"

grad_clip: 1.0
weight_decay: 0.1
norm_type: "LayerNorm"
nn_module: "te"
log_grad_norms: False

start_date: 199001010000
end_date: 202012310000
start_date_val: 202101010000
end_date_val: 202201010000
len_hrs: 6
step_hrs: 6
input_window_steps: 1

val_initial: False

loader_num_workers: 8
log_validation: 0
streams_output: ["ERA5"]

istep: 0
run_history: []

desc: ""
data_loader_rng_seed: ???
run_id: ???

# The period to log in the training loop (in number of batch steps)
train_log_freq:
terminal: 10
metrics: 20
checkpoint: 250


# Tags for experiment tracking
# These tags will be logged in MLFlow along with completed runs for train, eval, val
# The tags are free-form, with the following rules:
# - tags should be primitive types (strings, numbers, booleans). NO lists or dictionaries
# - tags should not duplicate existing config entries.
# - try to reuse existing tags where possible. MLFlow does not like having too many unique tags
# - do not use long strings in values (less than 20 characters is a good rule of thumb, we may enforce this in the future)
wgtags:
# The name of the organization of the person running the experiment.
# This may be autofilled in the future. Expected values are lowercase strings of
# the organizations' codenames in https://confluence.ecmwf.int/display/MAEL/Staff+Contact+List
# e.g. "ecmwf", "cmcc", "metnor", "jsc", "escience"
org: None
# The name of the experiment. This is a distinctive codename for the experiment campaign being run.
# This is expected to be the primary tag for comparing experiments in MLFlow.
# Expected values are lowercase strings with no spaces, just underscores:
# Examples: "rollout_ablation_grid"
exp: None
# *** Experiment-specific tags ***
grid: None
33 changes: 33 additions & 0 deletions config/streams/era5_1deg_cerra/cerra.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# (C) Copyright 2024 WeatherGenerator contributors.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's please remove all the config files. We need to find a place for them that is not the main repo.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The configs are here only for the reviewer to test the work, I will remove them before submitting the PR

#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

CERRA :
type : anemoi
filenames : ['cerra-rr-an-oper-se-al-ec-mars-5p5km-1985-2023-3h-v2.zarr']
geoinfo : ['al', 'lsm', 'orog']
loss_weight : 1.
masking_rate : 0.6
masking_rate_none : 0.05
token_size : 512
embed :
net : transformer
num_tokens : 1
num_heads : 2
dim_embed : 64
num_blocks : 2
embed_target_coords :
net : linear
dim_embed : 64
target_readout :
type : 'obs_value'
num_layers : 2
num_heads : 4
pred_head :
ens_size : 1
num_layers : 1
38 changes: 38 additions & 0 deletions config/streams/era5_1deg_cerra/era5.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# (C) Copyright 2024 WeatherGenerator contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

ERA5 :
type : anemoi
filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2023-6h-v8.zarr']
source_exclude : ['w_', 'skt', 'tcw', 'cp', 'tp']
target_exclude : ['w_', 'slor', 'sdor', 'tcw', 'cp', 'tp']
geoinfo : ['z', 'lsm', 'sdor', 'slor']
loss_weight : 1.
masking_rate : 0.6
masking_rate_none : 0.05
token_size : 8
tokenize_spacetime : True
max_num_targets: -1
embed :
net : transformer
num_tokens : 1
num_heads : 8
dim_embed : 64
num_blocks : 2
embed_target_coords :
net : linear
dim_embed : 64
target_readout :
type : 'obs_value' # token or obs_value
num_layers : 2
num_heads : 4
# sampling_rate : 0.2
pred_head :
ens_size : 1
num_layers : 1
99 changes: 95 additions & 4 deletions src/weathergen/datasets/data_reader_anemoi.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,52 @@ def __init__(
# get target channel weights from stream config
self.target_channel_weights = self.parse_target_channel_weights()

self.geoinfo_channels = []
self.geoinfo_idx = []
# select/filter requested geoinfo channels (static/constant-in-time variables)
self.geoinfo_idx = self.select_geoinfo_channels(ds0)
self.geoinfo_channels = [ds.variables[i] for i in self.geoinfo_idx]

# set geoinfo normalization statistics and cache geoinfo data
if len(self.geoinfo_idx) > 0:
self.mean_geoinfo = ds.statistics["mean"][self.geoinfo_idx]
self.stdev_geoinfo = ds.statistics["stdev"][self.geoinfo_idx]
# Cache geoinfo data once (constant in time, no need to read every epoch)
# Read from first timestep and store for reuse
geoinfo_data = ds[0:1][:, list(self.geoinfo_idx), 0]
# Shape: (1, num_geoinfo, num_gridpoints) -> (num_gridpoints, num_geoinfo)
self._cached_geoinfo = geoinfo_data[0].transpose().astype(np.float32)

# Log diagnostic info for geoinfo statistics
ds_name = stream_info["name"]
for i, ch_idx in enumerate(self.geoinfo_idx):
ch_name = ds.variables[ch_idx]
mean_val = self.mean_geoinfo[i]
stdev_val = self.stdev_geoinfo[i]
nan_count = np.isnan(self._cached_geoinfo[:, i]).sum()
if stdev_val == 0 or np.isclose(stdev_val, 0):
_logger.warning(
f"{ds_name}: geoinfo channel '{ch_name}' has stdev=0 "
"(constant field, will skip division in normalization)"
)
if nan_count > 0:
_logger.warning(
f"{ds_name}: geoinfo channel '{ch_name}' has {nan_count} NaN values "
f"({100 * nan_count / len(self._cached_geoinfo):.2f}% of grid points)"
)
_logger.debug(
f"{ds_name}: geoinfo '{ch_name}' - mean={mean_val:.4f}, stdev={stdev_val:.4f}"
)

# Replace NaN values in cached geoinfo with 0 (after normalization this will be neutral)
total_nans = np.isnan(self._cached_geoinfo).sum()
if total_nans > 0:
_logger.warning(
f"{ds_name}: Replacing {total_nans} total NaN values in geoinfo with 0"
)
self._cached_geoinfo = np.nan_to_num(self._cached_geoinfo, nan=0.0)
else:
self.mean_geoinfo = np.zeros(0)
self.stdev_geoinfo = np.ones(0)
self._cached_geoinfo = None

ds_name = stream_info["name"]
_logger.info(f"{ds_name}: source channels: {self.source_channels}")
Expand All @@ -132,6 +176,7 @@ def init_empty(self) -> None:
super().init_empty()
self.ds = None
self.len = 0
self._cached_geoinfo = None

@override
def length(self) -> int:
Expand Down Expand Up @@ -196,8 +241,12 @@ def _get(self, idx: TIndex, channels_idx: list[int]) -> ReaderData:
# repeat latlon len(t_idxs) times
coords = np.vstack((latlon,) * len(t_idxs))

# empty geoinfos for anemoi
geoinfos = np.zeros((len(data), 0), dtype=data.dtype)
# use cached geoinfo data (no disk read needed - already loaded during init)
if self._cached_geoinfo is not None:
# repeat cached geoinfo for all timesteps
geoinfos = np.vstack((self._cached_geoinfo,) * len(t_idxs))
else:
geoinfos = np.zeros((len(data), 0), dtype=data.dtype)

# date time matching #data points of data
# Assuming a fixed frequency for the dataset
Expand Down Expand Up @@ -255,6 +304,48 @@ def select_channels(self, ds0: anemoi_datasets, ch_type: str) -> NDArray[np.int6

return np.array(chs_idx, dtype=np.int64)

def select_geoinfo_channels(self, ds0: anemoi_datasets) -> NDArray[np.int64]:
    """
    Select geoinfo channels (static/constant-in-time variables)

    A dataset variable is selected when it is flagged as constant in time, is not a
    computed forcing, and its name contains at least one of the strings listed under
    the "geoinfo" entry of the stream config (substring match).

    Parameters
    ----------
    ds0 :
        raw anemoi dataset with available channels

    Returns
    -------
    NDArray of channel indices for geoinfo variables, sorted in ascending order
    (empty when no geoinfo is requested or nothing matches)

    """

    requested = self.stream_info.get("geoinfo", [])

    if not requested:
        return np.array([], dtype=np.int64)

    # A bare string would otherwise be iterated character by character, turning
    # e.g. "lsm" into the three patterns "l", "s" and "m".
    if isinstance(requested, str):
        requested = [requested]

    # Variables eligible as geoinfo: constant in time and not a computed forcing
    eligible = [
        name
        for name, var in ds0.typed_variables.items()
        if var.is_constant_in_time and not var.is_computed_forcing
    ]
    selected = [name for name in eligible if any(pat in name for pat in requested)]

    stream_name = self.stream_info["name"]
    if not selected:
        _logger.warning(
            f"No matching geoinfo channels found for {stream_name}. "
            f"Requested: {requested}"
        )
    else:
        # Report requested entries that were silently dropped because no eligible
        # variable contains them.
        unmatched = [pat for pat in requested if not any(pat in name for name in selected)]
        if unmatched:
            _logger.warning(
                f"Some requested geoinfo channels were not found for {stream_name}. "
                f"Missing: {unmatched}"
            )

    # Build as int64 up front so the empty case does not pass through float64
    chs_idx = np.array([ds0.name_to_index[name] for name in selected], dtype=np.int64)
    return np.sort(chs_idx)


def _clip_lat(lats: NDArray) -> NDArray[np.float32]:
"""
Expand Down
Loading
Loading