@@ -49,17 +49,23 @@ def __init__(
4949 self .inference_network = inference_network
5050 self .summary_network = summary_network
5151
52- if standardize == "all" :
53- standardize = ["inference_variables" , "summary_variables" , "inference_conditions" ]
54- elif isinstance (standardize , str ):
55- standardize = [standardize ]
56- elif isinstance (standardize , Sequence ):
57- standardize = standardize
58- else :
59- standardize = []
52+ # if standardize == "all":
53+ # standardize = ["inference_variables", "summary_variables", "inference_conditions"]
54+ # elif isinstance(standardize, str):
55+ # standardize = [standardize]
56+ # elif isinstance(standardize, Sequence):
57+ # standardize = standardize
58+ # else:
59+ # standardize = []
6060
6161 self .standardize = standardize
62- self .standardize_layers = {s : Standardization () for s in standardize }
62+
63+ if standardize == "all" :
64+ # we have to lazily initialize these
65+ self .standardize_layers = None
66+ else :
67+ print ("eager init" )
68+ self .standardize_layers = {s : Standardization (trainable = False ) for s in self .standardize }
6369
6470 @classmethod
6571 def build_adapter (
@@ -121,7 +127,16 @@ def compile(
121127 return super ().compile (* args , ** kwargs )
122128
123129 def build_from_data (self , adapted_data : dict [str , any ]):
130+ if self .standardize == "all" :
131+ self .standardize = list (adapted_data .keys ())
132+ self .standardize = ["inference_variables" , "summary_variables" , "inference_conditions" ]
133+ self .standardize = list (filter (lambda x : x in adapted_data , self .standardize ))
134+
135+ if self .standardize_layers is None :
136+ self .standardize_layers = {s : Standardization (trainable = False ) for s in self .standardize }
137+
124138 self .compute_metrics (** filter_kwargs (adapted_data , self .compute_metrics ), stage = "training" )
139+
125140 self .built = True
126141
127142 def compile_from_config (self , config ):
@@ -207,11 +222,11 @@ def _compute_summary_metrics(self, summary_variables: Tensor | None, stage: str)
207222 summary_outputs = summary_metrics .pop ("outputs" )
208223 return summary_metrics , summary_outputs
209224
210- def _prepare_inference_variables (self , inference_variables , stage ) :
225+ def _prepare_inference_variables (self , inference_variables : Tensor , stage : str ) -> Tensor :
211226 """Helper function to convert inference variables to tensors and optionally standardize them."""
212- inference_variables = keras .tree .map_structure (keras .ops .convert_to_tensor , inference_variables )
213227 if "inference_variables" in self .standardize :
214228 inference_variables = self .standardize_layers ["inference_variables" ](inference_variables , stage = stage )
229+
215230 return inference_variables
216231
217232 def _combine_conditions (
0 commit comments