diff --git a/packages/common/src/weathergen/common/config.py b/packages/common/src/weathergen/common/config.py index c32732ba7..1b9e1928d 100644 --- a/packages/common/src/weathergen/common/config.py +++ b/packages/common/src/weathergen/common/config.py @@ -225,7 +225,7 @@ def load_config( # use OmegaConf.unsafe_merge if too slow c = OmegaConf.merge(base_config, private_config, *overwrite_configs) assert isinstance(c, Config) - + # Ensure the config has mini-epoch notation if hasattr(c, "samples_per_epoch"): c.samples_per_mini_epoch = c.samples_per_epoch diff --git a/packages/dashboard/atmo_eval.py b/packages/dashboard/atmo_eval.py index a98b32268..3dc077f6e 100644 --- a/packages/dashboard/atmo_eval.py +++ b/packages/dashboard/atmo_eval.py @@ -77,7 +77,9 @@ def get_score_step_48h(score_col: str) -> pl.DataFrame: .sort("start_time") .filter(pl.col(score_col).is_not_null()) ) - _logger.info(f"Getting score data for {score_col} at 48h (step={step_48h}): len={len(score_data)}") + _logger.info( + f"Getting score data for {score_col} at 48h (step={step_48h}): len={len(score_data)}" + ) # Iterate over the runs to get the metric at step 48h scores_dt: list[float | None] = [] diff --git a/src/weathergen/model/attention.py b/src/weathergen/model/attention.py index 39ed1c041..195889a0a 100644 --- a/src/weathergen/model/attention.py +++ b/src/weathergen/model/attention.py @@ -17,6 +17,14 @@ class MultiSelfAttentionHeadVarlen(torch.nn.Module): + """Multi-head self-attention with variable length sequences. + + This module implements multi-head self-attention for variable length sequences packed into a + single tensor. It leverages FlashAttention's variable length API (`flash_attn_varlen_func`) + to efficiently handle batches of sequences with differing lengths without padding, using + cumulative length indices to define sequence boundaries. + """ + def __init__( self, dim_embed, @@ -32,6 +40,21 @@ def __init__( norm_eps=1e-5, attention_dtype=torch.bfloat16, ): + """Initialize the MultiSelfAttentionHeadVarlen module. + + :param dim_embed: Embedding dimension. + :param num_heads: Number of attention heads. + :param dim_head_proj: Dimension of the projection head. + :param dropout_rate: Dropout rate. + :param with_residual: Whether to use residual connections. + :param with_qk_lnorm: Whether to use layer normalization for query and key. + :param with_flash: Whether to use flash attention. + :param norm_type: Type of normalization. + :param softcap: Softcap for attention. + :param dim_aux: Dimension of auxiliary data. + :param norm_eps: Epsilon for normalization. + :param attention_dtype: Data type for attention. + """ super(MultiSelfAttentionHeadVarlen, self).__init__() self.num_heads = num_heads @@ -69,6 +92,14 @@ def __init__( assert with_flash, "Only flash attention supported at the moment" def forward(self, x, x_lens, ada_ln_aux=None): + """Forward pass of the MultiSelfAttentionHeadVarlen module. + + :param x: Input tensor. + :param x_lens: Lengths of the input sequences. + :param ada_ln_aux: Auxiliary data for adaptive layer normalization. + + :return out: Output tensor. + """ if self.with_residual: x_in = x x = self.lnorm(x) if ada_ln_aux is None else self.lnorm(x, ada_ln_aux) @@ -106,6 +137,14 @@ def forward(self, x, x_lens, ada_ln_aux=None): class MultiSelfAttentionHeadVarlenFlex(torch.nn.Module): + """Multi-head self-attention with variable length sequences and flex attention. + + This module implements multi-head self-attention using PyTorch's FlexAttention. 
It allows + for defining custom sparse attention patterns via a score modification function. This is + particularly useful for optimizing attention mechanisms where full NxN interactions are not + required or desired, enabling flexible and efficient attention computations. + """ + def __init__( self, dim_embed, @@ -120,6 +159,20 @@ def __init__( norm_eps=1e-5, attention_dtype=torch.bfloat16, ): + """Initialize the MultiSelfAttentionHeadVarlenFlex module. + + :param dim_embed: Embedding dimension. + :param num_heads: Number of attention heads. + :param dim_head_proj: Dimension of the projection head. + :param dropout_rate: Dropout rate. + :param with_residual: Whether to use residual connections. + :param with_qk_lnorm: Whether to use layer normalization for query and key. + :param with_flash: Whether to use flash attention. + :param norm_type: Type of normalization. + :param softcap: Softcap for attention. + :param norm_eps: Epsilon for normalization. + :param attention_dtype: Data type for attention. + """ super(MultiSelfAttentionHeadVarlenFlex, self).__init__() self.num_heads = num_heads @@ -160,6 +213,13 @@ def sparsity_mask(score, b, h, q_idx, kv_idx): self.compiled_flex_attention = torch.compile(att, dynamic=False) def forward(self, x, x_lens=None): + """Forward pass of the MultiSelfAttentionHeadVarlenFlex module. + + :param x: Input tensor. + :param x_lens: Lengths of the input sequences. + + :return out: Output tensor. + """ if self.with_residual: x_in = x x = self.lnorm(x) @@ -181,6 +241,14 @@ def forward(self, x, x_lens=None): class MultiSelfAttentionHeadLocal(torch.nn.Module): + """Multi-head self-attention with local (block-wise) attention. + + This module implements local (block-wise) multi-head self-attention. It restricts attention + to local blocks defined by `block_factor`, meaning tokens only attend to other tokens within + the same block. This effectively reduces the computational complexity from quadratic to + linear with respect to the sequence length (for a fixed block size), making it suitable for + processing long sequences where local interactions dominate.""" + def __init__( self, dim_embed, @@ -198,6 +266,23 @@ def __init__( norm_eps=1e-5, attention_dtype=torch.bfloat16, ): + """Initialize the MultiSelfAttentionHeadLocal module. + + :param dim_embed: Embedding dimension. + :param num_heads: Number of attention heads. + :param qkv_len: Length of the query, key and value. + :param block_factor: Block factor. + :param dim_head_proj: Dimension of the projection head. + :param dropout_rate: Dropout rate. + :param with_residual: Whether to use residual connections. + :param with_qk_lnorm: Whether to use layer normalization for query and key. + :param with_flash: Whether to use flash attention. + :param norm_type: Type of normalization. + :param softcap: Softcap for attention. + :param dim_aux: Dimension of the auxiliary data. + :param norm_eps: Epsilon for normalization. + :param attention_dtype: Data type for attention. + """ super(MultiSelfAttentionHeadLocal, self).__init__() self.num_heads = num_heads @@ -243,6 +328,13 @@ def mask_block_local(batch, head, idx_q, idx_kv): self.flex_attention = torch.compile(flex_attention, dynamic=False) def forward(self, x, ada_ln_aux=None): + """Forward pass of the MultiSelfAttentionHeadLocal module. + + :param x: Input tensor. + :param ada_ln_aux: Auxiliary data for adaptive layer normalization. + + :return out: Output tensor. 
+ """ if self.with_residual: x_in = x x = self.lnorm(x) if ada_ln_aux is None else self.lnorm(x, ada_ln_aux) @@ -263,6 +355,14 @@ def forward(self, x, ada_ln_aux=None): class MultiCrossAttentionHeadVarlen(torch.nn.Module): + """Multi-head cross-attention with variable length sequences. + + This module implements multi-head cross-attention for variable length sequences. Similar to + the self-attention variant, it uses FlashAttention (`flash_attn_varlen_func`) to handle + packed sequences of queries and keys/values with different lengths. It ensures correct masking + and efficient computation for cases where both source and target sequences vary in length. + """ + def __init__( self, dim_embed_q, @@ -279,6 +379,22 @@ def __init__( norm_eps=1e-5, attention_dtype=torch.bfloat16, ): + """Initialize the MultiCrossAttentionHeadVarlen module. + + :param dim_embed_q: Embedding dimension of the query. + :param dim_embed_kv: Embedding dimension of the key and value. + :param num_heads: Number of attention heads. + :param dim_head_proj: Dimension of the projection head. + :param dropout_rate: Dropout rate. + :param with_residual: Whether to use residual connections. + :param with_qk_lnorm: Whether to use layer normalization for query and key. + :param with_flash: Whether to use flash attention. + :param norm_type: Type of normalization. + :param softcap: Softcap for attention. + :param dim_aux: Dimension of the auxiliary data. + :param norm_eps: Epsilon for normalization. + :param attention_dtype: Data type for attention. + """ super(MultiCrossAttentionHeadVarlen, self).__init__() self.num_heads = num_heads @@ -321,6 +437,16 @@ def __init__( assert with_flash, "Only flash attention supported at the moment" def forward(self, x_q, x_kv, x_q_lens=None, x_kv_lens=None, ada_ln_aux=None): + """Forward pass of the MultiCrossAttentionHeadVarlen module. + + :param x_q: Query tensor. + :param x_kv: Key and value tensor. + :param x_q_lens: Lengths of the query sequences. + :param x_kv_lens: Lengths of the key and value sequences. + :param ada_ln_aux: Auxiliary data for adaptive layer normalization. + + :return outs: Output tensors. + """ if self.with_residual: x_q_in = x_q x_q = self.lnorm_in_q(x_q) if ada_ln_aux is None else self.lnorm_in_q(x_q, ada_ln_aux) @@ -362,6 +488,14 @@ def forward(self, x_q, x_kv, x_q_lens=None, x_kv_lens=None, ada_ln_aux=None): class MultiCrossAttentionHeadVarlenSlicedQ(torch.nn.Module): + """Multi-head cross-attention with variable length sequences and sliced queries. + + This module implements a memory-efficient variant of multi-head cross-attention where the + query projection is sliced into chunks. This allows processing extremely large query sets + (e.g., global queries against local latents) by computing attention for subsets of queries + sequentially. This approach reduces peak memory usage significantly, enabling the model to + scale to higher resolutions or larger query counts.""" + def __init__( self, dim_embed_q, @@ -379,6 +513,23 @@ def __init__( norm_eps=1e-5, attention_dtype=torch.bfloat16, ): + """Initialize the MultiCrossAttentionHeadVarlenSlicedQ module. + + :param dim_embed_q: Embedding dimension of the query. + :param dim_embed_kv: Embedding dimension of the key and value. + :param num_slices_q: Number of slices for the query. + :param num_heads: Number of attention heads. + :param dim_head_proj: Dimension of the projection head. + :param dropout_rate: Dropout rate. + :param with_residual: Whether to use residual connections. 
+ :param with_qk_lnorm: Whether to use layer normalization for query and key. + :param with_flash: Whether to use flash attention. + :param norm_type: Type of normalization. + :param softcap: Softcap for attention. + :param dim_aux: Dimension of the auxiliary data. + :param norm_eps: Epsilon for normalization. + :param attention_dtype: Data type for attention. + """ super(MultiCrossAttentionHeadVarlenSlicedQ, self).__init__() self.num_slices_q = num_slices_q @@ -428,6 +579,16 @@ def __init__( assert with_flash, "Only flash attention supported at the moment" def forward(self, x_q, x_kv, x_q_lens=None, x_kv_lens=None, ada_ln_aux=None): + """Forward pass of the MultiCrossAttentionHeadVarlenSlicedQ module. + + :param x_q: Query tensor. + :param x_kv: Key and value tensor. + :param x_q_lens: Lengths of the query sequences. + :param x_kv_lens: Lengths of the key and value sequences. + :param ada_ln_aux: Auxiliary data for adaptive layer normalization. + + :return outs: Output tensors. + """ if self.with_residual: x_q_in = x_q x_q = self.lnorm_in_q(x_q) if ada_ln_aux is None else self.lnorm_in_q(x_q, ada_ln_aux) @@ -473,6 +634,8 @@ def forward(self, x_q, x_kv, x_q_lens=None, x_kv_lens=None, ada_ln_aux=None): class MultiSelfAttentionHead(torch.nn.Module): + """Multi-head self-attention.""" + def __init__( self, dim_embed, @@ -488,6 +651,21 @@ def __init__( norm_eps=1e-5, attention_dtype=torch.bfloat16, ): + """Initialize the MultiSelfAttentionHead module. + + :param dim_embed: Embedding dimension. + :param num_heads: Number of attention heads. + :param dim_head_proj: Dimension of the projection head. + :param dropout_rate: Dropout rate. + :param with_residual: Whether to use residual connections. + :param with_qk_lnorm: Whether to use layer normalization for query and key. + :param with_flash: Whether to use flash attention. + :param softcap: Softcap for attention. + :param norm_type: Type of normalization. + :param dim_aux: Dimension of the auxiliary data. + :param norm_eps: Epsilon for normalization. + :param attention_dtype: Data type for attention. + """ super(MultiSelfAttentionHead, self).__init__() self.num_heads = num_heads @@ -528,6 +706,13 @@ def __init__( self.softmax = torch.nn.Softmax(dim=-1) def forward(self, x, ada_ln_aux=None): + """Forward pass of the MultiSelfAttentionHead module. + + :param x: Input tensor. + :param ada_ln_aux: Auxiliary data for adaptive layer normalization. + + :return out: Output tensor. + """ if self.with_residual: x_in = x x = self.lnorm(x) if ada_ln_aux is None else self.lnorm(x, ada_ln_aux) @@ -553,6 +738,8 @@ def forward(self, x, ada_ln_aux=None): class MultiCrossAttentionHead(torch.nn.Module): + """Multi-head cross-attention.""" + def __init__( self, dim_embed_q, @@ -567,6 +754,20 @@ def __init__( norm_eps=1e-5, attention_dtype=torch.bfloat16, ): + """Initialize the MultiCrossAttentionHead module. + + :param dim_embed_q: Embedding dimension of the query. + :param dim_embed_kv: Embedding dimension of the key and value. + :param num_heads: Number of attention heads. + :param dim_head_proj: Dimension of the projection head. + :param dropout_rate: Dropout rate. + :param with_residual: Whether to use residual connections. + :param with_qk_lnorm: Whether to use layer normalization for query and key. + :param with_flash: Whether to use flash attention. + :param norm_type: Type of normalization. + :param norm_eps: Epsilon for normalization. + :param attention_dtype: Data type for attention. 
+ """ super(MultiCrossAttentionHead, self).__init__() self.num_heads = num_heads @@ -607,6 +808,13 @@ def __init__( ######################################### def forward(self, x_q, x_kv): + """Forward pass of the MultiCrossAttentionHead module. + + :param x_q: Query tensor. + :param x_kv: Key and value tensor. + + :return outs: Output tensors. + """ if self.with_residual: x_q_in = x_q x_q, x_kv = self.lnorm_in_q(x_q), self.lnorm_in_kv(x_kv) diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index 7359d1403..c4930e291 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -30,18 +30,19 @@ class EmbeddingEngine(torch.nn.Module): + """Embedding engine for the model.""" + name: "EmbeddingEngine" def __init__(self, cf: Config, sources_size) -> None: - """ - Initialize the EmbeddingEngine with the configuration. + """Initialize the EmbeddingEngine with the configuration. :param cf: Configuration object containing parameters for the engine. - :param sources_size: List of source sizes for each stream. + :param sources_size: Tensor of number of channels for each stream """ super(EmbeddingEngine, self).__init__() self.cf = cf - self.sources_size = sources_size # KCT:iss130, what is this? + self.sources_size = sources_size self.embeds = torch.nn.ModuleList() for i, si in enumerate(self.cf.streams): @@ -81,6 +82,15 @@ def __init__(self, cf: Config, sources_size) -> None: raise ValueError("Unsupported embedding network type") def forward(self, streams_data, pe_embed, dtype, device): + """Forward pass of the embedding engine. + + :param streams_data: Tensor of streams data. + :param pe_embed: Positional encoding embeddings. + :param dtype: Data type for the embeddings. + :param device: Device to run the embeddings on. + + :return tokens_all: Embedded tokens. + """ source_tokens_lens = torch.stack( [ torch.stack( @@ -126,11 +136,20 @@ def forward(self, streams_data, pe_embed, dtype, device): class LocalAssimilationEngine(torch.nn.Module): + """Local assimilation engine for the model. + + The LocalAssimilationEngine is responsible for fusing information from different input + streams (e.g., satellite, station data) within each HEALPix cell. It operates locally, + meaning attention is computed only among tokens belonging to the same cell. This step + aggregates high-resolution, heterogeneous input data into a unified cell-level + representation before global interaction takes place. It uses a sequence of self-attention + blocks and MLPs. + """ + name: "LocalAssimilationEngine" def __init__(self, cf: Config) -> None: - """ - Initialize the LocalAssimilationEngine with the configuration. + """Initialize the LocalAssimilationEngine with the configuration. :param cf: Configuration object containing parameters for the engine. """ @@ -163,17 +182,26 @@ def __init__(self, cf: Config) -> None: ) def forward(self, tokens_c, cell_lens_c, use_reentrant): + """Forward pass of the local assimilation engine. + + :param tokens_c: Tokens to be assimilated. + :param cell_lens_c: Cell lengths for the tokens. + :param use_reentrant: Whether to use reentrant mode. + + :return tokens_c: Assimilated tokens. 
+ """ for block in self.ae_local_blocks: tokens_c = checkpoint(block, tokens_c, cell_lens_c, use_reentrant=use_reentrant) return tokens_c class Local2GlobalAssimilationEngine(torch.nn.Module): + """Local2GlobalAssimilationEngine for the model.""" + name: "Local2GlobalAssimilationEngine" def __init__(self, cf: Config) -> None: - """ - Initialize the Local2GlobalAssimilationEngine with the configuration. + """Initialize the Local2GlobalAssimilationEngine with the configuration. :param cf: Configuration object containing parameters for the engine. """ @@ -225,6 +253,16 @@ def __init__(self, cf: Config) -> None: ) def forward(self, tokens_c, tokens_global_c, q_cells_lens_c, cell_lens_c, use_reentrant): + """Forward pass of the local to global assimilation engine. + + :param tokens_c: Tokens to be assimilated. + :param tokens_global_c: Global tokens to be assimilated. + :param q_cells_lens_c: Query cell lengths for the tokens. + :param cell_lens_c: Cell lengths for the tokens. + :param use_reentrant: Whether to use reentrant mode. + + :return tokens_global_c: Assimilated tokens. + """ for block in self.ae_adapter: tokens_global_c = checkpoint( block, @@ -238,11 +276,20 @@ def forward(self, tokens_c, tokens_global_c, q_cells_lens_c, cell_lens_c, use_re class GlobalAssimilationEngine(torch.nn.Module): + """Global assimilation engine for the model. + + The GlobalAssimilationEngine processes the unified cell-level representations generated by + the LocalAssimilationEngine. Its primary role is to model long-range dependencies and + physical interactions across the entire globe. It alternates between local attention + (focusing on neighboring cells) and global attention (fully connected or sparse global + patterns) to efficiently propagate information. This engine transforms the local latents + into a globally consistent state representation. + """ + name: "GlobalAssimilationEngine" def __init__(self, cf: Config, num_healpix_cells: int) -> None: - """ - Initialize the GlobalAssimilationEngine with the configuration. + """Initialize the GlobalAssimilationEngine with the configuration. :param cf: Configuration object containing parameters for the engine. :param num_healpix_cells: Number of healpix cells used for local queries. @@ -300,17 +347,25 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: ) def forward(self, tokens, use_reentrant): + """Forward pass of the global assimilation engine. + + :param tokens: Tokens to be assimilated. + :param use_reentrant: Whether to use reentrant mode. + + :return tokens: Assimilated tokens. + """ for block in self.ae_global_blocks: tokens = checkpoint(block, tokens, use_reentrant=use_reentrant) return tokens class ForecastingEngine(torch.nn.Module): + """Forecasting engine for the model.""" + name: "ForecastingEngine" def __init__(self, cf: Config, num_healpix_cells: int) -> None: - """ - Initialize the ForecastingEngine with the configuration. + """Initialize the ForecastingEngine with the configuration. :param cf: Configuration object containing parameters for the engine. :param num_healpix_cells: Number of healpix cells used for local queries. 
@@ -368,6 +423,7 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: ) def init_weights_final(m): + """Initialize the weights of the forecasting engine.""" if isinstance(m, torch.nn.Linear): torch.nn.init.normal_(m.weight, mean=0, std=0.001) if m.bias is not None: @@ -377,6 +433,13 @@ def init_weights_final(m): block.apply(init_weights_final) def forward(self, tokens, fstep): + """Forward pass of the forecasting engine. + + :param tokens: Tokens to be forecasted. + :param fstep: Forecast step. + + :return tokens: Forecasted tokens. + """ aux_info = torch.tensor([fstep], dtype=torch.float32, device="cuda") for block in self.fe_blocks: tokens = checkpoint(block, tokens, aux_info, use_reentrant=False) @@ -385,6 +448,8 @@ def forward(self, tokens, fstep): class EnsPredictionHead(torch.nn.Module): + """Ensemble prediction head for the model.""" + def __init__( self, dim_embed, @@ -396,7 +461,17 @@ def __init__( hidden_factor=2, final_activation: None | str = None, ): - """Constructor""" + """Initialize the EnsPredictionHead with the configuration. + + :param dim_embed: Dimension of the embedding. + :param dim_out: Dimension of the output. + :param ens_num_layers: Number of layers in the ensemble. + :param ens_size: Size of the ensemble. + :param stream_name: Name of the stream. + :param norm_type: Type of normalization. + :param hidden_factor: Hidden factor to create an internal dimension. + :param final_activation: Optional final activation function. + """ super(EnsPredictionHead, self).__init__() @@ -428,6 +503,12 @@ def __init__( ######################################### def forward(self, toks): + """Forward pass of the EnsPredictionHead. + + :param toks: Tokens to be predicted. + + :return preds: Ensemble predictions. + """ preds = [] for pred_head in self.pred_heads: cpred = toks @@ -440,6 +521,16 @@ def forward(self, toks): class TargetPredictionEngineClassic(nn.Module): + """Target prediction engine for the model. + + The TargetPredictionEngineClassic is a specialized decoding module that projects the global + latent states back to specific target coordinates (e.g., station locations). It typically + employs a PerceiverIO-style architecture where target coordinate embeddings query the + latent state via cross-attention. This engine is "Classic" in the sense that it strictly + follows the original design with coordinate conditioning and optional self-attention, + without the flexible decoder types found in the newer `TargetPredictionEngine`. + """ + def __init__( self, cf, @@ -451,11 +542,10 @@ def __init__( tro_type, stream_name: str, ): - """ - Initialize the TargetPredictionEngine with the configuration. + """Initialize the TargetPredictionEngine with the configuration. :param cf: Configuration object containing parameters for the engine. - :param dims_embed: List of embedding dimensions for each layer. + :param dims_embed: Tensor of embedding dimensions for each layer. :param dim_coord_in: Input dimension for coordinates. :param tr_dim_head_proj: Dimension for head projection. :param tr_mlp_hidden_factor: Hidden factor for the MLP layers. @@ -525,6 +615,16 @@ def __init__( ) def forward(self, latent, output, latent_lens, output_lens, coordinates): + """Forward pass of the TargetPredictionEngineClassic. + + :param latent: Latent tokens. + :param output: Output tokens. + :param latent_lens: Lengths of the latent tokens. + :param output_lens: Lengths of the output tokens. + :param coordinates: Target coordinates for auxiliary information. 
+ + :returns tc_tokens: Output tokens. + """ tc_tokens = output tcs_lens = output_lens tokens_stream = latent @@ -548,6 +648,17 @@ def forward(self, latent, output, latent_lens, output_lens, coordinates): class TargetPredictionEngine(nn.Module): + """TargetPredictionEngine for the model. + + The TargetPredictionEngine handles the decoding of the latent representation into the target + observational space. Unlike the Classic version which solely relies on a fixed + PerceiverIO-like structure with coordinate conditioning, this engine is configurable via + `decoder_type`. It supports various conditioning mechanisms, allowing for experimentation + with how the latent state and auxiliary information (like coordinates) are fused to generate + predictions. It includes normalization, optional positional embeddings and a flexible + sequence of decoding blocks. + """ + def __init__( self, cf, @@ -559,11 +670,10 @@ def __init__( tro_type, stream_name: str, ): - """ - Initialize the TargetPredictionEngine with the configuration. + """Initialize the TargetPredictionEngine with the configuration. :param cf: Configuration object containing parameters for the engine. - :param dims_embed: List of embedding dimensions for each layer. + :param dims_embed: Tensor of embedding dimensions for each layer. :param dim_coord_in: Input dimension for coordinates. :param tr_dim_head_proj: Dimension for head projection. :param tr_mlp_hidden_factor: Hidden factor for the MLP layers. @@ -692,6 +802,16 @@ def __init__( ) def forward(self, latent, output, latent_lens, output_lens, coordinates): + """Forward pass of the TargetPredictionEngine. + + :param latent: Latent tokens. + :param output: Output tokens. + :param latent_lens: Lengths of the latent tokens. + :param output_lens: Lengths of the output tokens. + :param coordinates: Target coordinates for auxiliary information. + + :return output: Output tokens. + """ latent = ( self.dropout(self.latent_in_norm(latent + self.pos_embed)) if self.cf.decoder_type != "PerceiverIOCoordConditioning" diff --git a/src/weathergen/utils/validation_io.py b/src/weathergen/utils/validation_io.py index 355be0e51..f8a5a1cc5 100644 --- a/src/weathergen/utils/validation_io.py +++ b/src/weathergen/utils/validation_io.py @@ -30,7 +30,7 @@ def write_output( sample_idxs, ): stream_names = [stream.name for stream in cf.streams] - analysis_streams_output = cf.get( 'analysis_streams_output', None) + analysis_streams_output = cf.get("analysis_streams_output", None) if cf.streams_output is not None: output_stream_names = cf.streams_output elif analysis_streams_output is not None: # --- to be removed at some point ---
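
For reviewers, a minimal usage sketch of the `MultiSelfAttentionHeadVarlen` API that the new docstrings describe: sequences of different lengths are packed into one tensor and passed together with a per-sequence length tensor, and the module derives the cumulative boundary indices for `flash_attn_varlen_func` internally. This is illustrative only; the constructor defaults beyond `dim_embed` and `num_heads`, the int32 dtype of `x_lens`, and the CUDA / flash-attn requirement are assumptions, not taken from this patch.

    # Illustrative sketch (not part of the patch): calling the varlen
    # self-attention module on packed variable-length sequences.
    # Assumptions: constructor arguments after `num_heads` have usable
    # defaults, flash-attn is installed, and a CUDA device is available.
    import torch

    from weathergen.model.attention import MultiSelfAttentionHeadVarlen

    dim_embed, num_heads = 256, 8
    attn = MultiSelfAttentionHeadVarlen(dim_embed, num_heads).cuda()

    # Three sequences of lengths 5, 3 and 7 packed along dim 0 -- no padding;
    # sequence boundaries are recovered from the per-sequence lengths.
    x_lens = torch.tensor([5, 3, 7], dtype=torch.int32, device="cuda")
    x = torch.randn(int(x_lens.sum()), dim_embed, device="cuda", dtype=torch.bfloat16)

    out = attn(x, x_lens)  # expected to keep the packed shape of `x`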