305 changes: 305 additions & 0 deletions config/config_jepa.yml
@@ -0,0 +1,305 @@
# (C) Copyright 2025 WeatherGenerator contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

embed_orientation: "channels"
embed_unembed_mode: "block"
embed_dropout_rate: 0.1

ae_local_dim_embed: 1024
ae_local_num_blocks: 2
ae_local_num_heads: 16
ae_local_dropout_rate: 0.1
ae_local_with_qk_lnorm: True

ae_local_num_queries: 1
ae_local_queries_per_cell: False
ae_adapter_num_heads: 16
ae_adapter_embed: 128
ae_adapter_with_qk_lnorm: True
ae_adapter_with_residual: True
ae_adapter_dropout_rate: 0.1

ae_global_dim_embed: 2048
ae_global_num_blocks: 2
ae_global_num_heads: 32
ae_global_dropout_rate: 0.1
ae_global_with_qk_lnorm: True
# TODO: switching to < 1 triggers triton-related issues.
# See https://github.com/ecmwf/WeatherGenerator/issues/1050
ae_global_att_dense_rate: 1.0
ae_global_block_factor: 64
ae_global_mlp_hidden_factor: 2
ae_global_trailing_layer_norm: False

ae_aggregation_num_blocks: 8
ae_aggregation_num_heads: 32
ae_aggregation_dropout_rate: 0.1
ae_aggregation_with_qk_lnorm: True
ae_aggregation_att_dense_rate: 1.0
ae_aggregation_block_factor: 64
ae_aggregation_mlp_hidden_factor: 2

decoder_type: PerceiverIOCoordConditioning # CrossAttentionAdaNormConditioning
pred_adapter_kv: False
pred_self_attention: True
pred_dyadic_dims: False
pred_mlp_adaln: True
num_class_tokens: 1
num_register_tokens: 7

fe_num_blocks: 6
fe_num_heads: 16
fe_dropout_rate: 0.1
fe_with_qk_lnorm: True
fe_layer_norm_after_blocks: [] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer
fe_impute_latent_noise_std: 0.0 # 1e-4
# currently fixed to 1.0 (due to limitations with flex_attention and triton)
forecast_att_dense_rate: 1.0
with_step_conditioning: True # False

healpix_level: 5
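# (level 5 corresponds to 12 * 4^5 = 12,288 HEALPix cells on the sphere)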

with_mixed_precision: True
with_flash_attention: True
compile_model: False
with_fsdp: False
attention_dtype: bf16
mixed_precision_dtype: bf16
mlp_norm_eps: 1e-5
norm_eps: 1e-4

latent_noise_kl_weight: 0.0 # 1e-5
latent_noise_gamma: 2.0
latent_noise_saturate_encodings: 5
latent_noise_use_additive_noise: False
latent_noise_deterministic_latents: True

latent_heads_allowed: True


freeze_modules: ""

norm_type: "LayerNorm"


#####################################

streams_directory: "./config/streams/era5_1deg/"
# streams_directory: "./config/streams/era5_nppatms_synop/"
streams: ???
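# ??? marks a mandatory value (OmegaConf convention) that must be supplied at runtime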

general:

# mutable parameters
istep: 0
rank: ???
world_size: ???

  # further mutable parameters set at runtime:
  #   local_rank, with_ddp, data_path_*, model_path, run_path, path_shared_

multiprocessing_method: "fork"

desc: ""
run_id: ???
run_history: []

# logging frequency in the training loop (in number of batches)
train_log_freq:
terminal: 10
metrics: 20
checkpoint: 250

# parameters for data loading
data_loading :

num_workers: 8
rng_seed: ???


# config for training
training_config:

# training_mode: "masking", "student_teacher", "latent_loss"
training_mode: ["student_teacher"]

num_mini_epochs: 32
samples_per_mini_epoch: 256
shuffle: True

start_date: 1979-01-01T00:00
end_date: 2022-12-31T00:00

time_window_step: 06:00:00
time_window_len: 06:00:00

  # number of steps offset applied to the first target window; if set to zero and
  # forecast num_steps (below) is also zero, then one is training an auto-encoder
  window_offset_prediction : 0

learning_rate_scheduling :
lr_start: 1e-6
lr_max: 5e-5
lr_final_decay: 1e-6
lr_final: 0.0
num_steps_warmup: 512
num_steps_cooldown: 512
policy_warmup: "cosine"
policy_decay: "constant"
policy_cooldown: "linear"
parallel_scaling_policy: "sqrt"
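    # Phase sketch (descriptive, not normative): the warmup ramps the lr from
    # lr_start to lr_max over num_steps_warmup steps following policy_warmup
    # ("cosine" here); the decay phase then follows policy_decay ("constant"
    # here, i.e. the lr holds at lr_max); finally num_steps_cooldown steps take
    # the lr down towards lr_final following policy_cooldown ("linear" here).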

optimizer:
grad_clip: 1.0
weight_decay: 0.1
log_grad_norms: False
adamw :
# parameters are scaled by number of DDP workers
beta1 : 0.975
beta2 : 0.9875
eps : 2e-08

losses : {
"student-teacher": {
enabled: True,
type: LossLatentSSLStudentTeacher,
weight: 1.0,
loss_fcts : {
"JEPA": {
"weight": 8, "loss_extra_args": {}, "out_dim": 2048, "head": transformer,
"pred_num_blocks": 24, "pred_num_heads": 12, "pred_with_qk_lnorm": True, "pred_intermediate_dim": 768,
"pred_dropout_rate": 0.1,
target_source_correspondence: {0 : {0 : "complement"} },
}
},
target_and_aux_calc: { "EMATeacher" :
{ ema_ramp_up_ratio : 0.09,
ema_halflife_in_thousands: 1e-3,
model_param_overrides : { latent_heads_allowed: False },
}
}
}
}
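  # EMATeacher presumably maintains the teacher as an exponential moving average
  # of the student, i.e. theta_teacher <- m * theta_teacher + (1 - m) * theta_student,
  # with the momentum m derived from ema_halflife_in_thousands and ramped up over
  # the first ema_ramp_up_ratio fraction of training.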

model_input: {
"random_easy" : {
# masking strategy: "random", "forecast"
masking_strategy: "random",
num_samples: 1,
num_steps_input: 1,
masking_strategy_config : {
diffusion_rn : True,
rate : 0.6,
rate_sampling: False
},
},
}
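  # With masking_strategy "random" and rate 0.6, roughly 60% of the input is
  # presumably masked out for the student; rate_sampling: False keeps the rate
  # fixed instead of sampling it per batch.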

target_input: {
"random_easy_target" : {
masking_strategy: "healpix",
num_samples: 1,
masking_strategy_config : { rate : 0.2, hl_mask: 0, rate_sampling: False },
},
}
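  # "healpix" masking presumably drops whole HEALPix cells (at mask hierarchy
  # level hl_mask) so that a 0.2 fraction of cells forms the prediction targets.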

forecast :
time_step: 00:00:00
num_steps: 0
policy: null
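  # time_step 00:00:00 with num_steps: 0 disables forecasting; together with
  # window_offset_prediction: 0 above this is the auto-encoder setting.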


# validation config; full validation config is merge of training and validation config
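# e.g. samples_per_mini_epoch and shuffle below override the training values,
# while keys not listed here (optimizer, losses, ...) are inherited from training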
validation_config:

samples_per_mini_epoch: 256
shuffle: False

start_date: 2023-10-01T00:00
end_date: 2023-12-31T00:00

# whether to track the exponential moving average of weights for validation
validate_with_ema:
enabled : False
ema_ramp_up_ratio: 0.09
ema_halflife_in_thousands: 1e-3

# number of validation samples that are written to disk
write_num_samples: 0
# output streams to write; default all
output_streams: null

# run validation before training starts (mainly for model development)
validate_before_training: False

# losses: {
# "physical": {
# type: LossPhysical,
# weight: 1.0,
# loss_fcts: {
# "mse": {
# weight: 1.0,
# },
# },
# },
# }

# Overriding model_input here requires the corresponding "enabled" flags
# model_input: {
# "random_easy" : {
# enabled : False,
# },
# "random_hard" : {
# enabled : False,
# },
# "strategy1" : {
# # Masking strategy to use for the model input: "random", "healpix", "forecast"
# masking_strategy: "forecast",
# num_samples: 1,
# masking_strategy_config : { diffusion_rn : True, rate : 0.4 },
# # relationship: "independent", "subset", "disjoint".
# relationship: "independent",
# num_steps_input: 1,
# }
# }

# test config; full test config is merge of validation and test config
# test config is used by default when running inference

# Tags for experiment tracking.
# These tags are logged in MLFlow along with completed runs for train, eval, and val.
# The tags are free-form, with the following rules:
# - tags should be primitive types (strings, numbers, booleans); NO lists or dictionaries
# - tags should not duplicate existing config entries
# - try to reuse existing tags where possible; MLFlow does not cope well with many unique tags
# - do not use long strings in values (fewer than 20 characters is a good rule of thumb; we may enforce this in the future)
wgtags:
# The name of the organization of the person running the experiment.
# This may be autofilled in the future. Expected values are lowercase strings
# e.g. "ecmwf", "cmcc", "metnor", "jsc", "escience"
org: null
  # The GitHub issue corresponding to this run (a number such as 1234).
  # GitHub issues are the central point of reference when running experiments and
  # contain links to hedgedocs, code branches, pull requests, etc.
  # It is recommended to associate a run with a GitHub issue.
issue: null
# The name of the experiment. This is a distinctive codename for the experiment campaign being run.
# This is expected to be the primary tag for comparing experiments in MLFlow, along with the
# issue number.
# Expected values are lowercase strings with no spaces, just underscores:
# Examples: "rollout_ablation_grid"
exp: null
# *** Experiment-specific tags ***
  # All extra tags (including lists, dictionaries, etc.) are treated as strings
  # by MLFlow, so treat all extra tags as simple string key: value pairs.
grid: null
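# Example (hypothetical values):
# wgtags:
#   org: "ecmwf"
#   issue: 1234
#   exp: "jepa_pretrain"
#   grid: "era5_1deg"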
6 changes: 6 additions & 0 deletions config/config_physical_jepa.yml
@@ -64,6 +64,12 @@ fe_impute_latent_noise_std: 0.0 # 1e-4
# currently fixed to 1.0 (due to limitations with flex_attention and triton)
forecast_att_dense_rate: 1.0

sslpred_num_blocks: 12
sslpred_num_heads: 12
sslpred_dropout_rate: 0.1
sslpred_with_qk_lnorm: True
sslpred_intermediate_dim: 384
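# These sslpred_* settings presumably configure the SSL predictor used by the
# student-teacher (JEPA) objective; cf. the pred_* overrides inside loss_fcts
# in config_jepa.yml.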

healpix_level: 5

with_mixed_precision: True
6 changes: 6 additions & 0 deletions integration_tests/jepa1.yaml
@@ -11,6 +11,12 @@ lr_steps_warmup: 2
lr_steps_cooldown: 2
loader_num_workers: 8

sslpred_num_blocks: 12
sslpred_num_heads: 12
sslpred_dropout_rate: 0.1
sslpred_with_qk_lnorm: True
sslpred_intermediate_dim: 384

train_log:
log_interval: 1
### Example validation and training config for student-teacher with JEPA