2929
3030
3131class EmbeddingEngine :
32+ name : "EmbeddingEngine"
33+
3234 def __init__ (self , cf : Config , sources_size ) -> None :
3335 """
3436 Initialize the EmbeddingEngine with the configuration.
@@ -47,6 +49,8 @@ def create(self) -> torch.nn.ModuleList:
4749 :return: torch.nn.ModuleList containing the embedding layers.
4850 """
4951 for i , si in enumerate (self .cf .streams ):
52+ stream_name = si .get ("name" , i )
53+
5054 if "diagnostic" in si and si ["diagnostic" ]:
5155 self .embeds .append (torch .nn .Identity ())
5256 continue
@@ -66,12 +70,15 @@ def create(self) -> torch.nn.ModuleList:
6670 norm_type = self .cf .norm_type ,
6771 embed_size_centroids = self .cf .embed_size_centroids ,
6872 unembed_mode = self .cf .embed_unembed_mode ,
73+ stream_name = stream_name ,
6974 )
7075 )
7176 elif si ["embed" ]["net" ] == "linear" :
7277 self .embeds .append (
7378 StreamEmbedLinear (
74- self .sources_size [i ] * si ["token_size" ], self .cf .ae_local_dim_embed
79+ self .sources_size [i ] * si ["token_size" ],
80+ self .cf .ae_local_dim_embed ,
81+ stream_name = stream_name ,
7582 )
7683 )
7784 else :
@@ -80,6 +87,8 @@ def create(self) -> torch.nn.ModuleList:
8087
8188
8289class LocalAssimilationEngine :
90+ name : "LocalAssimilationEngine"
91+
8392 def __init__ (self , cf : Config ) -> None :
8493 """
8594 Initialize the LocalAssimilationEngine with the configuration.
@@ -122,6 +131,8 @@ def create(self) -> torch.nn.ModuleList:
122131
123132
124133class Local2GlobalAssimilationEngine :
134+ name : "Local2GlobalAssimilationEngine"
135+
125136 def __init__ (self , cf : Config ) -> None :
126137 """
127138 Initialize the Local2GlobalAssimilationEngine with the configuration.
@@ -183,6 +194,8 @@ def create(self) -> torch.nn.ModuleList:
183194
184195
185196class GlobalAssimilationEngine :
197+ name : "GlobalAssimilationEngine"
198+
186199 def __init__ (self , cf : Config , num_healpix_cells : int ) -> None :
187200 """
188201 Initialize the GlobalAssimilationEngine with the configuration.
@@ -250,6 +263,8 @@ def create(self) -> torch.nn.ModuleList:
250263
251264
252265class ForecastingEngine :
266+ name : "ForecastingEngine"
267+
253268 def __init__ (self , cf : Config , num_healpix_cells : int ) -> None :
254269 """
255270 Initialize the ForecastingEngine with the configuration.
@@ -327,13 +342,13 @@ def init_weights_final(m):
327342
328343
329344class EnsPredictionHead (torch .nn .Module ):
330- #########################################
331345 def __init__ (
332346 self ,
333347 dim_embed ,
334348 dim_out ,
335349 ens_num_layers ,
336350 ens_size ,
351+ stream_name : str ,
337352 norm_type = "LayerNorm" ,
338353 hidden_factor = 2 ,
339354 final_activation : None | str = None ,
@@ -342,6 +357,8 @@ def __init__(
342357
343358 super (EnsPredictionHead , self ).__init__ ()
344359
360+ self .name = f"EnsPredictionHead_{ stream_name } "
361+
345362 dim_internal = dim_embed * hidden_factor
346363 # norm = torch.nn.LayerNorm if norm_type == "LayerNorm" else RMSNorm
347364 enl = ens_num_layers
@@ -390,6 +407,7 @@ def __init__(
390407 tr_mlp_hidden_factor ,
391408 softcap ,
392409 tro_type ,
410+ stream_name : str ,
393411 ):
394412 """
395413 Initialize the TargetPredictionEngine with the configuration.
@@ -403,6 +421,7 @@ def __init__(
403421 :param tro_type: Type of target readout (e.g., "obs_value").
404422 """
405423 super (TargetPredictionEngineClassic , self ).__init__ ()
424+ self .name = f"TargetPredictionEngine_{ stream_name } "
406425
407426 self .cf = cf
408427 self .dims_embed = dims_embed
@@ -496,6 +515,7 @@ def __init__(
496515 tr_mlp_hidden_factor ,
497516 softcap ,
498517 tro_type ,
518+ stream_name : str ,
499519 ):
500520 """
501521 Initialize the TargetPredictionEngine with the configuration.
@@ -519,6 +539,7 @@ def __init__(
519539 LayerNorm that does not scale after the layer is applied
520540 """
521541 super (TargetPredictionEngine , self ).__init__ ()
542+ self .name = f"TargetPredictionEngine_{ stream_name } "
522543
523544 self .cf = cf
524545 self .dims_embed = dims_embed
0 commit comments