2 changes: 1 addition & 1 deletion packages/common/src/weathergen/common/config.py
@@ -225,7 +225,7 @@ def load_config(
# use OmegaConf.unsafe_merge if too slow
c = OmegaConf.merge(base_config, private_config, *overwrite_configs)
assert isinstance(c, Config)

# Ensure the config has mini-epoch notation
if hasattr(c, "samples_per_epoch"):
c.samples_per_mini_epoch = c.samples_per_epoch
4 changes: 3 additions & 1 deletion packages/dashboard/atmo_eval.py
@@ -77,7 +77,9 @@ def get_score_step_48h(score_col: str) -> pl.DataFrame:
.sort("start_time")
.filter(pl.col(score_col).is_not_null())
)
_logger.info(f"Getting score data for {score_col} at 48h (step={step_48h}): len={len(score_data)}")
_logger.info(
f"Getting score data for {score_col} at 48h (step={step_48h}): len={len(score_data)}"
)

# Iterate over the runs to get the metric at step 48h
scores_dt: list[float | None] = []
208 changes: 208 additions & 0 deletions src/weathergen/model/attention.py
@@ -17,6 +17,14 @@


class MultiSelfAttentionHeadVarlen(torch.nn.Module):
"""Multi-head self-attention with variable length sequences.

This module implements multi-head self-attention for variable length sequences packed into a
single tensor. It leverages FlashAttention's variable length API (`flash_attn_varlen_func`)
to efficiently handle batches of sequences with differing lengths without padding, using
cumulative length indices to define sequence boundaries.
"""

def __init__(
self,
dim_embed,
@@ -32,6 +40,21 @@ def __init__(
norm_eps=1e-5,
attention_dtype=torch.bfloat16,
):
"""Initialize the MultiSelfAttentionHeadVarlen module.

:param dim_embed: Embedding dimension.
:param num_heads: Number of attention heads.
:param dim_head_proj: Projection dimension per attention head.
:param dropout_rate: Dropout rate.
:param with_residual: Whether to use residual connections.
:param with_qk_lnorm: Whether to use layer normalization for query and key.
:param with_flash: Whether to use flash attention.
:param norm_type: Type of normalization.
:param softcap: Softcap for attention.
:param dim_aux: Dimension of auxiliary data.
:param norm_eps: Epsilon for normalization.
:param attention_dtype: Data type for attention.
"""
super(MultiSelfAttentionHeadVarlen, self).__init__()

self.num_heads = num_heads
@@ -69,6 +92,14 @@ def __init__(
assert with_flash, "Only flash attention supported at the moment"

def forward(self, x, x_lens, ada_ln_aux=None):
"""Forward pass of the MultiSelfAttentionHeadVarlen module.

:param x: Input tensor.
:param x_lens: Lengths of the input sequences.
:param ada_ln_aux: Auxiliary data for adaptive layer normalization.

:return out: Output tensor.
"""
if self.with_residual:
x_in = x
x = self.lnorm(x) if ada_ln_aux is None else self.lnorm(x, ada_ln_aux)
@@ -106,6 +137,14 @@ def forward(self, x, x_lens, ada_ln_aux=None):


class MultiSelfAttentionHeadVarlenFlex(torch.nn.Module):
"""Multi-head self-attention with variable length sequences and flex attention.

This module implements multi-head self-attention using PyTorch's FlexAttention. It allows
for defining custom sparse attention patterns via a score modification function. This is
particularly useful for optimizing attention mechanisms where full NxN interactions are not
required or desired, enabling flexible and efficient attention computations.
"""

def __init__(
self,
dim_embed,
@@ -120,6 +159,20 @@ def __init__(
norm_eps=1e-5,
attention_dtype=torch.bfloat16,
):
"""Initialize the MultiSelfAttentionHeadVarlenFlex module.

:param dim_embed: Embedding dimension.
:param num_heads: Number of attention heads.
:param dim_head_proj: Projection dimension per attention head.
:param dropout_rate: Dropout rate.
:param with_residual: Whether to use residual connections.
:param with_qk_lnorm: Whether to use layer normalization for query and key.
:param with_flash: Whether to use flash attention.
:param norm_type: Type of normalization.
:param softcap: Softcap for attention.
:param norm_eps: Epsilon for normalization.
:param attention_dtype: Data type for attention.
"""
super(MultiSelfAttentionHeadVarlenFlex, self).__init__()

self.num_heads = num_heads
@@ -160,6 +213,13 @@ def sparsity_mask(score, b, h, q_idx, kv_idx):
self.compiled_flex_attention = torch.compile(att, dynamic=False)

def forward(self, x, x_lens=None):
"""Forward pass of the MultiSelfAttentionHeadVarlenFlex module.

:param x: Input tensor.
:param x_lens: Lengths of the input sequences.

:return out: Output tensor.
"""
if self.with_residual:
x_in = x
x = self.lnorm(x)
@@ -181,6 +241,14 @@ def forward(self, x, x_lens=None):


class MultiSelfAttentionHeadLocal(torch.nn.Module):
"""Multi-head self-attention with local (block-wise) attention.

This module implements local (block-wise) multi-head self-attention. It restricts attention
to local blocks defined by `block_factor`, meaning tokens only attend to other tokens within
the same block. This effectively reduces the computational complexity from quadratic to
linear with respect to the sequence length (for a fixed block size), making it suitable for
processing long sequences where local interactions dominate."""
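
A sketch of how such a block-local pattern can be expressed as a FlexAttention mask predicate; the block size is assumed here for illustration, and the module's actual `mask_block_local` (defined further down) may differ in detail:

import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

seq_len, block_size = 1024, 64  # illustrative values; the module derives its blocks from qkv_len and block_factor

def mask_block_local(b, h, q_idx, kv_idx):
    # A token attends only to tokens that fall in the same block.
    return (q_idx // block_size) == (kv_idx // block_size)

block_mask = create_block_mask(mask_block_local, None, None, seq_len, seq_len)
q = k = v = torch.randn(1, 8, seq_len, 64, device="cuda", dtype=torch.bfloat16)
out = torch.compile(flex_attention, dynamic=False)(q, k, v, block_mask=block_mask)
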

def __init__(
self,
dim_embed,
@@ -198,6 +266,23 @@ def __init__(
norm_eps=1e-5,
attention_dtype=torch.bfloat16,
):
"""Initialize the MultiSelfAttentionHeadLocal module.

:param dim_embed: Embedding dimension.
:param num_heads: Number of attention heads.
:param qkv_len: Sequence length of the query, key and value.
:param block_factor: Factor determining the size of the local attention blocks.
:param dim_head_proj: Projection dimension per attention head.
:param dropout_rate: Dropout rate.
:param with_residual: Whether to use residual connections.
:param with_qk_lnorm: Whether to use layer normalization for query and key.
:param with_flash: Whether to use flash attention.
:param norm_type: Type of normalization.
:param softcap: Softcap for attention.
:param dim_aux: Dimension of the auxiliary data.
:param norm_eps: Epsilon for normalization.
:param attention_dtype: Data type for attention.
"""
super(MultiSelfAttentionHeadLocal, self).__init__()

self.num_heads = num_heads
@@ -243,6 +328,13 @@ def mask_block_local(batch, head, idx_q, idx_kv):
self.flex_attention = torch.compile(flex_attention, dynamic=False)

def forward(self, x, ada_ln_aux=None):
"""Forward pass of the MultiSelfAttentionHeadLocal module.

:param x: Input tensor.
:param ada_ln_aux: Auxiliary data for adaptive layer normalization.

:return out: Output tensor.
"""
if self.with_residual:
x_in = x
x = self.lnorm(x) if ada_ln_aux is None else self.lnorm(x, ada_ln_aux)
@@ -263,6 +355,14 @@ def forward(self, x, ada_ln_aux=None):


class MultiCrossAttentionHeadVarlen(torch.nn.Module):
"""Multi-head cross-attention with variable length sequences.

This module implements multi-head cross-attention for variable length sequences. Similar to
the self-attention variant, it uses FlashAttention (`flash_attn_varlen_func`) to handle
packed sequences of queries and keys/values with different lengths. It ensures correct masking
and efficient computation for cases where both source and target sequences vary in length.
"""

def __init__(
self,
dim_embed_q,
Expand All @@ -279,6 +379,22 @@ def __init__(
norm_eps=1e-5,
attention_dtype=torch.bfloat16,
):
"""Initialize the MultiCrossAttentionHeadVarlen module.

:param dim_embed_q: Embedding dimension of the query.
:param dim_embed_kv: Embedding dimension of the key and value.
:param num_heads: Number of attention heads.
:param dim_head_proj: Projection dimension per attention head.
:param dropout_rate: Dropout rate.
:param with_residual: Whether to use residual connections.
:param with_qk_lnorm: Whether to use layer normalization for query and key.
:param with_flash: Whether to use flash attention.
:param norm_type: Type of normalization.
:param softcap: Softcap for attention.
:param dim_aux: Dimension of the auxiliary data.
:param norm_eps: Epsilon for normalization.
:param attention_dtype: Data type for attention.
"""
super(MultiCrossAttentionHeadVarlen, self).__init__()

self.num_heads = num_heads
@@ -321,6 +437,16 @@ def __init__(
assert with_flash, "Only flash attention supported at the moment"

def forward(self, x_q, x_kv, x_q_lens=None, x_kv_lens=None, ada_ln_aux=None):
"""Forward pass of the MultiCrossAttentionHeadVarlen module.

:param x_q: Query tensor.
:param x_kv: Key and value tensor.
:param x_q_lens: Lengths of the query sequences.
:param x_kv_lens: Lengths of the key and value sequences.
:param ada_ln_aux: Auxiliary data for adaptive layer normalization.

:return outs: Output tensors.
"""
if self.with_residual:
x_q_in = x_q
x_q = self.lnorm_in_q(x_q) if ada_ln_aux is None else self.lnorm_in_q(x_q, ada_ln_aux)
Expand Down Expand Up @@ -362,6 +488,14 @@ def forward(self, x_q, x_kv, x_q_lens=None, x_kv_lens=None, ada_ln_aux=None):


class MultiCrossAttentionHeadVarlenSlicedQ(torch.nn.Module):
"""Multi-head cross-attention with variable length sequences and sliced queries.

This module implements a memory-efficient variant of multi-head cross-attention where the
query projection is sliced into chunks. This allows processing extremely large query sets
(e.g., global queries against local latents) by computing attention for subsets of queries
sequentially. This approach reduces peak memory usage significantly, enabling the model to
scale to higher resolutions or larger query counts."""
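
To see why slicing the queries lowers peak memory, consider a toy sketch in plain PyTorch (not the flash-attention code path this module actually uses); the helper name and tensor shapes are assumptions for illustration:

import torch

def chunked_cross_attention(q, k, v, num_slices_q=4):
    # Illustrative helper. q: (num_queries, heads, head_dim); k, v: (num_kv, heads, head_dim).
    # Only a (num_queries / num_slices_q) x num_kv score matrix is alive at any time,
    # instead of the full num_queries x num_kv matrix.
    scale = q.shape[-1] ** -0.5
    outs = []
    for q_chunk in q.chunk(num_slices_q, dim=0):
        scores = torch.einsum("qhd,khd->hqk", q_chunk, k) * scale
        outs.append(torch.einsum("hqk,khd->qhd", scores.softmax(dim=-1), v))
    return torch.cat(outs, dim=0)

# e.g. out = chunked_cross_attention(q, k, v, num_slices_q=4)
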

def __init__(
self,
dim_embed_q,
Expand All @@ -379,6 +513,23 @@ def __init__(
norm_eps=1e-5,
attention_dtype=torch.bfloat16,
):
"""Initialize the MultiCrossAttentionHeadVarlenSlicedQ module.

:param dim_embed_q: Embedding dimension of the query.
:param dim_embed_kv: Embedding dimension of the key and value.
:param num_slices_q: Number of slices for the query.
:param num_heads: Number of attention heads.
:param dim_head_proj: Projection dimension per attention head.
:param dropout_rate: Dropout rate.
:param with_residual: Whether to use residual connections.
:param with_qk_lnorm: Whether to use layer normalization for query and key.
:param with_flash: Whether to use flash attention.
:param norm_type: Type of normalization.
:param softcap: Softcap for attention.
:param dim_aux: Dimension of the auxiliary data.
:param norm_eps: Epsilon for normalization.
:param attention_dtype: Data type for attention.
"""
super(MultiCrossAttentionHeadVarlenSlicedQ, self).__init__()

self.num_slices_q = num_slices_q
@@ -428,6 +579,16 @@ def __init__(
assert with_flash, "Only flash attention supported at the moment"

def forward(self, x_q, x_kv, x_q_lens=None, x_kv_lens=None, ada_ln_aux=None):
"""Forward pass of the MultiCrossAttentionHeadVarlenSlicedQ module.

:param x_q: Query tensor.
:param x_kv: Key and value tensor.
:param x_q_lens: Lengths of the query sequences.
:param x_kv_lens: Lengths of the key and value sequences.
:param ada_ln_aux: Auxiliary data for adaptive layer normalization.

:return outs: Output tensors.
"""
if self.with_residual:
x_q_in = x_q
x_q = self.lnorm_in_q(x_q) if ada_ln_aux is None else self.lnorm_in_q(x_q, ada_ln_aux)
Expand Down Expand Up @@ -473,6 +634,8 @@ def forward(self, x_q, x_kv, x_q_lens=None, x_kv_lens=None, ada_ln_aux=None):


class MultiSelfAttentionHead(torch.nn.Module):
"""Multi-head self-attention."""

def __init__(
self,
dim_embed,
Expand All @@ -488,6 +651,21 @@ def __init__(
norm_eps=1e-5,
attention_dtype=torch.bfloat16,
):
"""Initialize the MultiSelfAttentionHead module.

:param dim_embed: Embedding dimension.
:param num_heads: Number of attention heads.
:param dim_head_proj: Projection dimension per attention head.
:param dropout_rate: Dropout rate.
:param with_residual: Whether to use residual connections.
:param with_qk_lnorm: Whether to use layer normalization for query and key.
:param with_flash: Whether to use flash attention.
:param softcap: Softcap for attention.
:param norm_type: Type of normalization.
:param dim_aux: Dimension of the auxiliary data.
:param norm_eps: Epsilon for normalization.
:param attention_dtype: Data type for attention.
"""
super(MultiSelfAttentionHead, self).__init__()

self.num_heads = num_heads
@@ -528,6 +706,13 @@ def __init__(
self.softmax = torch.nn.Softmax(dim=-1)

def forward(self, x, ada_ln_aux=None):
"""Forward pass of the MultiSelfAttentionHead module.

:param x: Input tensor.
:param ada_ln_aux: Auxiliary data for adaptive layer normalization.

:return out: Output tensor.
"""
if self.with_residual:
x_in = x
x = self.lnorm(x) if ada_ln_aux is None else self.lnorm(x, ada_ln_aux)
Expand All @@ -553,6 +738,8 @@ def forward(self, x, ada_ln_aux=None):


class MultiCrossAttentionHead(torch.nn.Module):
"""Multi-head cross-attention."""

def __init__(
self,
dim_embed_q,
Expand All @@ -567,6 +754,20 @@ def __init__(
norm_eps=1e-5,
attention_dtype=torch.bfloat16,
):
"""Initialize the MultiCrossAttentionHead module.

:param dim_embed_q: Embedding dimension of the query.
:param dim_embed_kv: Embedding dimension of the key and value.
:param num_heads: Number of attention heads.
:param dim_head_proj: Projection dimension per attention head.
:param dropout_rate: Dropout rate.
:param with_residual: Whether to use residual connections.
:param with_qk_lnorm: Whether to use layer normalization for query and key.
:param with_flash: Whether to use flash attention.
:param norm_type: Type of normalization.
:param norm_eps: Epsilon for normalization.
:param attention_dtype: Data type for attention.
"""
super(MultiCrossAttentionHead, self).__init__()

self.num_heads = num_heads
@@ -607,6 +808,13 @@ def __init__(

#########################################
def forward(self, x_q, x_kv):
"""Forward pass of the MultiCrossAttentionHead module.

:param x_q: Query tensor.
:param x_kv: Key and value tensor.

:return outs: Output tensors.
"""
if self.with_residual:
x_q_in = x_q
x_q, x_kv = self.lnorm_in_q(x_q), self.lnorm_in_kv(x_kv)