
Conversation

@kctezcan
Contributor

Description

Issue Number

Closes #1587

Checklist before asking for review

  • I have performed a self-review of my code
  • My changes comply with basic sanity checks:
    • I have fixed formatting issues with ./scripts/actions.sh lint
    • I have run unit tests with ./scripts/actions.sh unit-test
    • I have documented my code and I have updated the docstrings.
    • I have added unit tests, if relevant
  • I have tried my changes with data and code:
    • I have run the integration tests with ./scripts/actions.sh integration-test
    • (bigger changes) I have run a full training and I have written in the comment the run_id(s): launch-slurm.py --time 60
    • (bigger changes and experiments) I have shared a HedgeDoc in the GitHub issue with all the configurations and runs for these experiments
  • I have informed and aligned with people impacted by my change:
    • for config changes: the MatterMost channels and/or a design doc
    • for changes of dependencies: the MatterMost software development channel

@kctezcan
Contributor Author

Some design choices to discuss:

  1. The latent space has the default dimensions of 12288x2048. The optimal predictor dimension is 384. We use a linear layer without bias at the entrance of the predictor to map 2048->384, and as the last layer another bias-free linear map of 384->2048. These dimensions are all config parameters (see the sketch after this list).
  2. We hardcoded 4 as the MLP factor (the factor for the hidden dimension in the MLP layers), as this is the optimal value from the literature.
  3. We have introduced new parameters only in the JEPA config: pred_xxx
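
For concreteness, below is a minimal PyTorch sketch of the predictor I/O described in point 1 (bias-free 2048->384 entrance and 384->2048 exit) with the MLP hidden factor of 4 from point 2. The class and argument names are illustrative only and the attention layers are left out; this is not the actual WeatherGenerator code.

import torch.nn as nn

class JEPAPredictorSketch(nn.Module):
    # Bias-free entrance/exit projections around a stack of small residual
    # MLP blocks; attention is omitted here to keep the sketch short.
    def __init__(self, dim_latent=2048, dim_pred=384, num_blocks=12, mlp_factor=4):
        super().__init__()
        self.proj_in = nn.Linear(dim_latent, dim_pred, bias=False)    # 2048 -> 384
        self.blocks = nn.ModuleList(
            nn.Sequential(
                nn.LayerNorm(dim_pred),
                nn.Linear(dim_pred, mlp_factor * dim_pred),           # hidden dim = 4 * 384
                nn.GELU(),
                nn.Linear(mlp_factor * dim_pred, dim_pred),
            )
            for _ in range(num_blocks)
        )
        self.proj_out = nn.Linear(dim_pred, dim_latent, bias=False)   # 384 -> 2048

    def forward(self, x):
        # x: (batch, num_tokens, dim_latent), e.g. num_tokens = 12288
        x = self.proj_in(x)
        for block in self.blocks:
            x = x + block(x)
        return self.proj_out(x)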

Contributor

@sophie-xhonneux left a comment

Broadly, I am happy to approve once the comments are addressed.

But I would like to discuss whether it makes sense for the global assimilation engine to be part of the encoder, or whether for JEPA we should instead use the AggregationEngine (global transformer without the masked tokens).

# ### Example validation and training config for student-teacher with JEPA
validation_config:
  losses:
    LossPhysical: {weight: 0.0, loss_fcts: [['mse', 0.8], ['mae', 0.2]]}
Contributor

can we remove this instead of setting the weight to 0?

Contributor Author

this file is removed with the new configs

training_mode: "student_teacher" # "masking", "student_teacher", "forecast"
target_and_aux_calc: "EMATeacher"
losses:
  LossPhysical: {weight: 0.0, loss_fcts: [['mse', 0.8], ['mae', 0.2]]}
Contributor

Same here

Contributor Author

this file is removed with the new configs

@kctezcan
Contributor Author

I renamed the new parameters related to the JEPA predictor from pred_ to sslpred_ because I realized there were already parameters called pred_ for the decoder (i.e. prediction heads)

forecast_att_dense_rate: 1.0
with_step_conditioning: True # False

sslpred_num_blocks: 12
Collaborator

Where would this go into the new config structure?

Collaborator

Option A would be the model config, but it's ultimately something specific to the JEPA loss term.


for _ in range(self.cf.sslpred_num_blocks):
    self.pred_blocks.append(
        MultiSelfAttentionHead(
Collaborator

We should have a transformer_block module for attention + MLP (can/should be in a separate PR)
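
For illustration, one possible shape for such a block, assuming a pre-norm layout; torch.nn.MultiheadAttention stands in for MultiSelfAttentionHead, whose real signature may differ:

import torch.nn as nn

class TransformerBlock(nn.Module):
    # Pre-norm transformer block: self-attention followed by an MLP,
    # each with a residual connection.
    def __init__(self, dim, num_heads, mlp_factor=4):
        super().__init__()
        self.norm_att = nn.LayerNorm(dim)
        self.att = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm_mlp = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_factor * dim),
            nn.GELU(),
            nn.Linear(mlp_factor * dim, dim),
        )

    def forward(self, x):
        h = self.norm_att(x)
        x = x + self.att(h, h, h, need_weights=False)[0]   # attention + residual
        x = x + self.mlp(self.norm_mlp(x))                  # MLP + residual
        return x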

)
elif loss == "JEPA":
    self.latent_heads[loss] = LatentPredictionHead(
    self.latent_heads[loss] = TransformerPredictionHead(
Collaborator

Why is there no if-statement here to choose between LatentPredictionHead and TransformerPredictionHead? LatentPredictionHead should be renamed to LatentPredictionHeadMLP (and then I would also prefer LatentPredictionHeadTransformer).
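
A rough sketch of what that selection could look like; the stand-in class bodies and the head_type switch (e.g. a hypothetical sslpred_head_type config entry) are placeholders, not existing code:

import torch.nn as nn

# Stand-ins only: the real heads live in the WeatherGenerator code base
# and take different constructor arguments.
class LatentPredictionHeadMLP(nn.Module): ...
class LatentPredictionHeadTransformer(nn.Module): ...

def build_jepa_head(head_type):
    # head_type would come from the config, e.g. a (hypothetical) sslpred_head_type
    if head_type == "mlp":
        return LatentPredictionHeadMLP()
    if head_type == "transformer":
        return LatentPredictionHeadTransformer()
    raise ValueError(f"unknown JEPA head type: {head_type}")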

@sophie-xhonneux
Contributor

I cleaned up the PR a bit, but I didn't know how to open a PR against this PR, so I put it here: #1649

@clessig
Collaborator

clessig commented Jan 18, 2026

I cleaned up the PR a bit, but I didn't know how to open a PR against this PR, so I put it here: #1649

You need to open the PR against the repo Kerem is using (the MeteoSwiss fork of WeatherGenerator).

Successfully merging this pull request may close these issues: Implement transformer based predictors for JEPA
