Transformer predictor for JEPA #1590
base: develop
Conversation
Some design choices to discuss:
sophie-xhonneux
left a comment
Broadly, I am happy to approve once the comments are addressed.
But I would like to discuss whether the global assimilation engine makes sense as part of the encoder, or whether for JEPA we should instead use the AggregationEngine (the global transformer without the masked tokens).
config/default_config_jepa.yml
Outdated
    # ### Example validation and training config for student-teacher with JEPA
    validation_config:
      losses:
        LossPhysical: {weight: 0.0, loss_fcts: [['mse', 0.8], ['mae', 0.2]]}
Can we remove this entry instead of setting the weight to 0?
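For illustration, the suggestion amounts to dropping the dead loss term rather than zero-weighting it; this is a sketch, not the actual new config schema:

```yaml
# current: zero-weighted loss term kept in the config
validation_config:
  losses:
    LossPhysical: {weight: 0.0, loss_fcts: [['mse', 0.8], ['mae', 0.2]]}
---
# suggested: omit LossPhysical entirely; any active losses
# (e.g. the JEPA term) would stay listed here
validation_config:
  losses: {}
```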
This file is removed with the new configs.
config/default_config_jepa.yml
Outdated
    training_mode: "student_teacher" # "masking", "student_teacher", "forecast"
    target_and_aux_calc: "EMATeacher"
    losses:
      LossPhysical: {weight: 0.0, loss_fcts: [['mse', 0.8], ['mae', 0.2]]}
Same here
This file is removed with the new configs.
I renamed the new parameters related to the JEPA predictor from
    forecast_att_dense_rate: 1.0
    with_step_conditioning: True # False

    sslpred_num_blocks: 12
Where would this go into the new config structure?
Option A is the model section, but it's ultimately something specific to the JEPA loss term.
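To make the trade-off concrete, the two placements could look like this; the section and key names are assumptions about the new config schema, not the actual one:

```yaml
# Option A: a generic model hyperparameter
model:
  sslpred_num_blocks: 12

# Option B: scoped to the JEPA loss term that uses it
losses:
  JEPA:
    predictor_num_blocks: 12
```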
    for _ in range(self.cf.sslpred_num_blocks):
        self.pred_blocks.append(
            MultiSelfAttentionHead(
We should have a transformer_block module for attention + MLP (can/should be in a separate PR)
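A minimal sketch of such a transformer_block module (pre-norm attention + MLP, each with a residual connection), assuming plain PyTorch; the class name, dimensions, and constructor arguments are illustrative, not the repository's API:

```python
import torch
import torch.nn as nn

class TransformerBlock(nn.Module):
    """Pre-norm transformer block: self-attention + MLP, both residual."""

    def __init__(self, dim: int, num_heads: int, mlp_ratio: int = 4):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_ratio * dim),
            nn.GELU(),
            nn.Linear(mlp_ratio * dim, dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # residual self-attention on the normalized input
        h = self.norm1(x)
        x = x + self.attn(h, h, h, need_weights=False)[0]
        # residual MLP
        x = x + self.mlp(self.norm2(x))
        return x

# stack the blocks the way the predictor loop above does
blocks = nn.ModuleList(TransformerBlock(128, 8) for _ in range(12))
x = torch.randn(2, 16, 128)  # (batch, tokens, dim)
for block in blocks:
    x = block(x)
```

This would let the predictor loop append one `TransformerBlock` per iteration instead of a bare `MultiSelfAttentionHead`.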
        )
    elif loss == "JEPA":
    -   self.latent_heads[loss] = LatentPredictionHead(
    +   self.latent_heads[loss] = TransformerPredictionHead(
Why is there no if-statement here to choose between LatentPredictionHead and TransformerPredictionHead? LatentPredictionHead should be renamed to LatentPredictionHeadMLP (and then I would also prefer LatentPredictionHeadTransformer).
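A hypothetical sketch of the dispatch being asked for, with the proposed renaming applied; the classes here are stubs standing in for the real heads in the repository, and the `jepa_predictor` config flag is an assumed name:

```python
from types import SimpleNamespace

class LatentPredictionHeadMLP:          # stub for the existing MLP head
    pass

class LatentPredictionHeadTransformer:  # stub for the new transformer head
    pass

def build_latent_head(cf, loss: str):
    """Choose the JEPA prediction head based on a config flag."""
    if loss != "JEPA":
        raise ValueError(f"no latent head for loss {loss!r}")
    # cf.jepa_predictor: "mlp" or "transformer" (assumed flag name)
    if cf.jepa_predictor == "transformer":
        return LatentPredictionHeadTransformer()
    return LatentPredictionHeadMLP()

cf = SimpleNamespace(jepa_predictor="transformer")
head = build_latent_head(cf, "JEPA")
```

With a dispatch like this, both heads stay selectable from the config instead of one hard-coded replacement.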
I cleaned up the PR a bit, but didn't know how to open a PR against this PR, so I put it here: #1649
You need to open the PR against the repo Kerem is using (the MeteoSwiss fork of WeatherGenerator).
Description
Issue Number
Closes #1587
Checklist before asking for review
- ./scripts/actions.sh lint
- ./scripts/actions.sh unit-test
- ./scripts/actions.sh integration-test
- launch-slurm.py --time 60