12 changes: 12 additions & 0 deletions config/default_config.yml
@@ -3,6 +3,9 @@ streams_directory: "./config/streams/era5_1deg/"
embed_orientation: "channels"
embed_unembed_mode: "block"
embed_dropout_rate: 0.1
embed_gradient_checkpoint_enabled: True

stream_embed_gradient_checkpoint_enabled: True

target_cell_local_prediction: True

@@ -14,11 +17,13 @@ ae_local_with_qk_lnorm: True

ae_local_num_queries: 1
ae_local_queries_per_cell: False
ae_local_blocks_gradient_checkpoint_enabled: True
ae_adapter_num_heads: 16
ae_adapter_embed: 128
ae_adapter_with_qk_lnorm: True
ae_adapter_with_residual: True
ae_adapter_dropout_rate: 0.1
ae_adapter_gradient_checkpoint_enabled: True

ae_global_dim_embed: 2048
ae_global_num_blocks: 8
@@ -31,6 +36,7 @@ ae_global_att_dense_rate: 1.0
ae_global_block_factor: 64
ae_global_mlp_hidden_factor: 2
ae_global_trailing_layer_norm: False
ae_global_gradient_checkpoint_enabled: True

ae_aggregation_num_blocks: 2
ae_aggregation_num_heads: 32
@@ -39,12 +45,17 @@ ae_aggregation_with_qk_lnorm: True
ae_aggregation_att_dense_rate: 1.0
ae_aggregation_block_factor: 64
ae_aggregation_mlp_hidden_factor: 2
ae_aggregation_gradient_checkpoint_enabled: True

decoder_type: PerceiverIOCoordConditioning # CrossAttentionAdaNormConditioning
target_pred_engine_gradient_checkpoint_enabled: True
target_pred_engine_classic_gradient_checkpoint_enabled: True

pred_adapter_kv: False
pred_self_attention: True
pred_dyadic_dims: False
pred_mlp_adaln: True
pred_head_gradient_checkpoint_enabled: True
num_class_tokens: 1
num_register_tokens: 7

@@ -63,6 +74,7 @@ fe_dropout_rate: 0.1
fe_with_qk_lnorm: True
fe_layer_norm_after_blocks: [] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer
fe_impute_latent_noise_std: 0.0 # 1e-4
fe_gradient_checkpoint_enabled: True

healpix_level: 5

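Note: every new `*_gradient_checkpoint_enabled` flag above is read with `cf.get(name, True)` in the code below, so gradient checkpointing stays enabled whenever a flag is absent from the config. The flags are consumed through the `cond_checkpoint` helper imported from `weathergen.model.utils`; that module is not part of this diff, so the following is only a sketch of what the helper presumably does, inferred from its call sites:

```python
# Sketch only: inferred from the call sites in this PR. The actual helper lives in
# src/weathergen/model/utils.py, which is not shown in this diff.
from torch.utils.checkpoint import checkpoint


def cond_checkpoint(enabled: bool, module, *args, use_reentrant: bool = False, **kwargs):
    """Run `module` under activation checkpointing when `enabled`, else call it directly."""
    if enabled:
        return checkpoint(module, *args, use_reentrant=use_reentrant, **kwargs)
    # Checkpointing disabled: drop the checkpoint-only kwarg and run the module eagerly.
    return module(*args, **kwargs)
```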
34 changes: 25 additions & 9 deletions src/weathergen/model/embeddings.py
@@ -7,21 +7,24 @@
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from functools import partial

import numpy as np
import torch
from torch.utils.checkpoint import checkpoint

from weathergen.model.attention import MultiSelfAttentionHead
from weathergen.model.layers import MLP

# from weathergen.model.mlp import MLP
from weathergen.model.norms import RMSNorm
from weathergen.model.positional_encoding import positional_encoding_harmonic
from weathergen.model.utils import cond_checkpoint


class StreamEmbedTransformer(torch.nn.Module):
def __init__(
self,
cf,
mode,
num_tokens,
token_size,
@@ -57,6 +60,7 @@ def __init__(
self.num_blocks = num_blocks
self.num_heads = num_heads
self.unembed_mode = unembed_mode
self.cf = cf

norm = torch.nn.LayerNorm if norm_type == "LayerNorm" else RMSNorm

@@ -135,21 +139,29 @@ def __init__(

self.dropout_final = torch.nn.Dropout(0.1)

self.checkpoint_stream_embed = partial(
cond_checkpoint, self.cf.get("stream_embed_gradient_checkpoint_enabled", True)
)

def forward_channels(self, x_in):
peh = positional_encoding_harmonic

# embed provided input data
x = peh(checkpoint(self.embed, x_in.transpose(-2, -1), use_reentrant=False))
x = peh(
self.checkpoint_stream_embed(self.embed, x_in.transpose(-2, -1), use_reentrant=False)
)

for layer in self.layers:
x = checkpoint(layer, x, use_reentrant=False)
x = self.checkpoint_stream_embed(layer, x, use_reentrant=False)

# read out
if self.unembed_mode == "full":
out = checkpoint(self.unembed, self.ln_final(x.flatten(-2, -1)), use_reentrant=False)
out = self.checkpoint_stream_embed(
self.unembed, self.ln_final(x.flatten(-2, -1)), use_reentrant=False
)
elif self.unembed_mode == "block":
out = [
checkpoint(ue, ln(x[:, i]), use_reentrant=False)
self.checkpoint_stream_embed(ue, ln(x[:, i]), use_reentrant=False)
for i, (ue, ln) in enumerate(zip(self.unembed, self.ln_final, strict=True))
]
out = torch.stack(out, dim=1).flatten(-2, -1)
@@ -165,14 +177,18 @@ def forward_channels(self, x_in):

def forward_columns(self, x_in):
# embed provided input data
x = positional_encoding_harmonic(checkpoint(self.embed, x_in, use_reentrant=False))
x = positional_encoding_harmonic(
self.checkpoint_stream_embed(self.embed, x_in, use_reentrant=False)
)

for layer in self.layers:
x = checkpoint(layer, x, use_reentrant=False)
x = self.checkpoint_stream_embed(layer, x, use_reentrant=False)

out = checkpoint(self.unembed1, x, use_reentrant=False)
out = self.checkpoint_stream_embed(self.unembed1, x, use_reentrant=False)
out = self.unembed_nonlin(out)
out = checkpoint(self.unembed2, out.transpose(-2, -1), use_reentrant=False)
out = self.checkpoint_stream_embed(
self.unembed2, out.transpose(-2, -1), use_reentrant=False
)
out = out.flatten(-2, -1).unsqueeze(1)

# final normalize and dropout
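The embeddings changes above are mechanical: each direct `checkpoint(...)` call is replaced by the flag-bound partial, so outputs should be numerically identical with the flag on or off; only peak memory and backward-pass time change. A hypothetical smoke test for that invariant, assuming the `cond_checkpoint` signature sketched earlier:

```python
# Hypothetical check, not part of the PR: outputs must match with checkpointing on or off.
import torch

from weathergen.model.utils import cond_checkpoint

layer = torch.nn.Linear(8, 8)
x = torch.randn(4, 8, requires_grad=True)

y_on = cond_checkpoint(True, layer, x, use_reentrant=False)
y_off = cond_checkpoint(False, layer, x, use_reentrant=False)
assert torch.allclose(y_on, y_off)
```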
62 changes: 50 additions & 12 deletions src/weathergen/model/engines.py
@@ -8,10 +8,10 @@
# nor does it submit to any jurisdiction.

import dataclasses
from functools import partial

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

from weathergen.common.config import Config
from weathergen.model.attention import (
@@ -27,7 +27,7 @@
StreamEmbedTransformer,
)
from weathergen.model.layers import MLP
from weathergen.model.utils import ActivationFactory
from weathergen.model.utils import ActivationFactory, cond_checkpoint
from weathergen.utils.utils import get_dtype


@@ -72,6 +72,7 @@ def __init__(self, cf: Config, sources_size, stream_names: list[str]) -> None:
norm_type=self.cf.norm_type,
unembed_mode=self.cf.embed_unembed_mode,
stream_name=stream_name,
cf=self.cf,
)
elif si["embed"]["net"] == "linear":
self.embeds[stream_name] = StreamEmbedLinear(
@@ -158,9 +159,15 @@ def __init__(self, cf: Config) -> None:
)
)

self.checkpoint_ae_local_blocks = partial(
cond_checkpoint, self.cf.get("ae_local_blocks_grdient_checkpoint_enabled", True)
)

def forward(self, tokens_c, cell_lens_c, use_reentrant):
for block in self.ae_local_blocks:
tokens_c = checkpoint(block, tokens_c, cell_lens_c, use_reentrant=use_reentrant)
tokens_c = self.checkpoint_ae_local_blocks(
block, tokens_c, cell_lens_c, use_reentrant=use_reentrant
)
return tokens_c


@@ -223,9 +230,13 @@ def __init__(self, cf: Config) -> None:
)
)

self.checkpoint_ae_adapter = partial(
cond_checkpoint, self.cf.get("ae_adapter_grdient_checkpoint_enabled", True)
)

def forward(self, tokens_c, tokens_global_c, q_cells_lens_c, cell_lens_c, use_reentrant):
for block in self.ae_adapter:
tokens_global_c = checkpoint(
tokens_global_c = self.checkpoint_ae_adapter(
block,
tokens_global_c,
tokens_c,
@@ -301,9 +312,13 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None:
)
)

self.checkpoint_ae_aggregation = partial(
cond_checkpoint, self.cf.get("ae_aggregation_gradient_checkpoint_enabled", True)
)

def forward(self, tokens, use_reentrant):
for block in self.ae_aggregation_blocks:
tokens = checkpoint(block, tokens, use_reentrant=use_reentrant)
tokens = self.checkpoint_ae_aggregation(block, tokens, use_reentrant=use_reentrant)
return tokens


Expand Down Expand Up @@ -373,9 +388,17 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None:
torch.nn.LayerNorm(self.cf.ae_global_dim_embed, elementwise_affine=False)
)

self.ae_global_blocks.append(
torch.nn.LayerNorm(self.cf.ae_global_dim_embed, elementwise_affine=False)
)

self.checkpoint_ae_global = partial(
cond_checkpoint, self.cf.get("ae_global_gradient_checkpoint_enabled", True)
)

def forward(self, tokens, use_reentrant):
for block in self.ae_global_blocks:
tokens = checkpoint(block, tokens, use_reentrant=use_reentrant)
tokens = self.checkpoint_ae_global(block, tokens, use_reentrant=use_reentrant)
return tokens


@@ -463,13 +486,17 @@ def init_weights_final(m):
for block in self.fe_blocks:
block.apply(init_weights_final)

self.checkpoint_fe = partial(
cond_checkpoint, self.cf.get("fe_gradient_checkpoint_enabled", True)
)

def forward(self, tokens, fstep):
aux_info = None
for _b_idx, block in enumerate(self.fe_blocks):
if isinstance(block, torch.nn.modules.normalization.LayerNorm):
tokens = block(tokens)
else:
tokens = checkpoint(block, tokens, aux_info, use_reentrant=False)
tokens = self.checkpoint_fe(block, tokens, aux_info, use_reentrant=False)
return tokens


@@ -613,6 +640,11 @@ def __init__(
)
)

self.checkpoint_pred = partial(
cond_checkpoint,
self.cf.get("target_pred_engine_classic_gradient_checkpoint_enabled", True),
)

def forward(self, latent, output, latent_lens, output_lens, coordinates):
tc_tokens = output
tcs_lens = output_lens
@@ -622,9 +654,11 @@ def forward(self, latent, output, latent_lens, output_lens, coordinates):

for ib, block in enumerate(self.tte):
if self.cf.pred_self_attention and ib % 3 == 1:
tc_tokens = checkpoint(block, tc_tokens, tcs_lens, tcs_aux, use_reentrant=False)
tc_tokens = self.checkpoint_pred(
block, tc_tokens, tcs_lens, tcs_aux, use_reentrant=False
)
else:
tc_tokens = checkpoint(
tc_tokens = self.checkpoint_pred(
block,
tc_tokens,
tokens_stream,
@@ -780,6 +814,10 @@ def __init__(
f"{self.cf.decoder_type} is not implemented for prediction heads"
)

self.checkpoint_target = partial(
cond_checkpoint, self.cf.get("target_pred_engine_gradient_checkpoint_enabled", True)
)

def forward(self, latent, output, latent_lens, output_lens, coordinates):
latent = (
self.dropout(self.latent_in_norm(latent + self.pos_embed))
@@ -788,7 +826,7 @@ def forward(self, latent, output, latent_lens, output_lens, coordinates):
)
for layer in self.tte:
if isinstance(layer, OriginalPredictionBlock):
output = checkpoint(
output = self.checkpoint_target(
layer,
latent=latent.flatten(0, 1),
output=output,
@@ -798,7 +836,7 @@ def forward(self, latent, output, latent_lens, output_lens, coordinates):
use_reentrant=False,
)
elif isinstance(layer, CrossAttentionBlock):
output = checkpoint(
output = self.checkpoint_target(
layer,
x=output,
x_kv=latent.flatten(0, 1),
@@ -808,7 +846,7 @@ def forward(self, latent, output, latent_lens, output_lens, coordinates):
use_reentrant=False,
)
else:
output = checkpoint(
output = self.checkpoint_target(
layer,
x=output,
x_lens=output_lens,
18 changes: 14 additions & 4 deletions src/weathergen/model/model.py
@@ -12,13 +12,13 @@
import logging
import math
import warnings
from functools import partial

import astropy_healpix as hp
import astropy_healpix.healpy
import numpy as np
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

from weathergen.common.config import Config
from weathergen.datasets.batch import ModelBatch
@@ -32,7 +32,7 @@
TargetPredictionEngineClassic,
)
from weathergen.model.layers import MLP, NamedLinear
from weathergen.model.utils import get_num_parameters
from weathergen.model.utils import cond_checkpoint, get_num_parameters
from weathergen.utils.distributed import is_root
from weathergen.utils.utils import get_dtype

@@ -289,6 +289,14 @@ def __init__(self, cf: Config, sources_size, targets_num_channels, targets_coord
self.class_token_idx = cf.num_class_tokens + cf.num_register_tokens
self.register_token_idx = cf.num_register_tokens

self.checkpoint_fn_embed = partial(
cond_checkpoint, self.cf.get("embed_gradient_checkpoint_enabled", True)
)

self.checkpoint_fn_pred = partial(
cond_checkpoint, self.cf.get("pred_head_gradient_checkpoint_enabled", True)
)

#########################################
def create(self) -> "Model":
"""Create each individual module of the model"""
@@ -670,7 +678,7 @@ def predict(
)
# embed token coords
tc_embed = self.embed_target_coords[stream_name]
tc_tokens = checkpoint(tc_embed, t_coords, use_reentrant=False)
tc_tokens = self.checkpoint_fn_embed(tc_embed, t_coords, use_reentrant=False)

# skip when the coordinate embedding yields nan (i.e. the coord embedding network diverged)
if torch.isnan(tc_tokens).any():
@@ -705,7 +713,9 @@ def predict(
)

# final prediction head to map back to physical space
pred = checkpoint(self.pred_heads[stream_name], tc_tokens, use_reentrant=False)
pred = self.checkpoint_fn_pred(
self.pred_heads[stream_name], tc_tokens, use_reentrant=False
)

output.add_physical_prediction(fstep, stream_name, pred)

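Because every lookup defaults to `True`, existing configs keep checkpointing enabled for all components; setting an individual flag to `False` disables it only for that component. A minimal illustration of the lookup semantics (a plain dict stands in for the `Config` object, which exposes the same `.get` at the call sites above):

```python
# Plain dict standing in for weathergen's Config; only the .get default matters here.
cf = {"fe_gradient_checkpoint_enabled": False}

fe_on = cf.get("fe_gradient_checkpoint_enabled", True)       # False: forecast-engine blocks run eagerly
embed_on = cf.get("embed_gradient_checkpoint_enabled", True)  # True: flag absent, checkpointing stays on
```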