@@ -68,27 +68,36 @@ def __init__(
6868 self .standardize_layers = {var : Standardization (trainable = False ) for var in self .standardize }
6969
7070 def build (self , data_shapes : dict [str , tuple [int ] | dict [str , dict ]]) -> None :
71+ # Build summary network and get output shape if present
7172 summary_outputs_shape = None
72- inference_conditions_shape = data_shapes .get ("inference_conditions" , None )
7373 if self .summary_network is not None :
7474 if not self .summary_network .built :
7575 self .summary_network .build (data_shapes ["summary_variables" ])
7676 summary_outputs_shape = self .summary_network .compute_output_shape (data_shapes ["summary_variables" ])
77+
78+ # Compute inference_conditions_shape by combining original and summary outputs
7779 inference_conditions_shape = concatenate_valid_shapes (
78- [inference_conditions_shape , summary_outputs_shape ], axis = - 1
80+ [data_shapes . get ( "inference_conditions" ) , summary_outputs_shape ], axis = - 1
7981 )
82+
83+ # Build inference network if needed
8084 if not self .inference_network .built :
8185 self .inference_network .build (data_shapes ["inference_variables" ], inference_conditions_shape )
86+
87+ # Set up standardization layers if requested
8288 if self .standardize == "all" :
89+ # Only include variables present in data_shapes
8390 self .standardize = [
8491 var
8592 for var in ["inference_variables" , "summary_variables" , "inference_conditions" ]
8693 if var in data_shapes
8794 ]
88-
8995 self .standardize_layers = {var : Standardization (trainable = False ) for var in self .standardize }
90- for var , layer in self .standardize_layers .items ():
96+
97+ # Build all standardization layers, if present
98+ for var , layer in getattr (self , "standardize_layers" , {}).items ():
9199 layer .build (data_shapes [var ])
100+
92101 self .built = True
93102
94103 @classmethod
0 commit comments