@@ -49,6 +49,7 @@ def __init__(
4949 self .inference_network = inference_network
5050 self .summary_network = summary_network
5151 self .standardize = standardize
52+
5253 self .inference_variables_norm = None
5354 self .summary_variables_norm = None
5455 self .inference_conditions_norm = None
@@ -59,7 +60,6 @@ def build_adapter(
5960 inference_variables : Sequence [str ],
6061 inference_conditions : Sequence [str ] = None ,
6162 summary_variables : Sequence [str ] = None ,
62- standardize : bool = True ,
6363 sample_weight : str = None ,
6464 ) -> Adapter :
6565 """Create an :py:class:`~bayesflow.adapters.Adapter` suited for the approximator.
@@ -72,8 +72,6 @@ def build_adapter(
7272 Names of the inference conditions in the data
7373 summary_variables : Sequence of str, optional
7474 Names of the summary variables in the data
75- standardize : bool, optional
76- Decide whether to standardize all variables, default is True
7775 sample_weight : str, optional
7876 Name of the sample weights
7977 """
@@ -95,9 +93,6 @@ def build_adapter(
9593
9694 adapter .keep (["inference_variables" , "inference_conditions" , "summary_variables" , "sample_weight" ])
9795
98- if standardize :
99- adapter .standardize (exclude = "sample_weight" )
100-
10196 return adapter
10297
10398 def compile (
@@ -118,7 +113,7 @@ def compile(
118113
119114 return super ().compile (* args , ** kwargs )
120115
121- def build_from_data (self , adapted_data : dict [str , any ]) -> None :
116+ def build_from_data (self , adapted_data : dict [str , any ]):
122117 # Determine input standardization
123118 if self .standardize == "all" :
124119 keys = ["inference_variables" , "summary_variables" , "inference_conditions" ]
@@ -129,13 +124,15 @@ def build_from_data(self, adapted_data: dict[str, any]) -> None:
129124 else :
130125 keys = []
131126
132- if "inference_variables" in keys :
127+ if "inference_variables" in adapted_data and "inference_variables" in keys :
133128 self .inference_variables_norm = Standardization ()
134129 self .inference_variables_norm (adapted_data ["inference_variables" ])
135- if "summary_variables" in keys and self .summary_network :
130+
131+ if "summary_variables" in adapted_data and "summary_variables" in keys and self .summary_network :
136132 self .summary_variables_norm = Standardization ()
137133 self .summary_variables_norm (adapted_data ["summary_variables" ])
138- if "inference_conditions" in keys :
134+
135+ if "inference_conditions" in adapted_data and "inference_conditions" in keys :
139136 self .inference_conditions_norm = Standardization ()
140137 self .inference_conditions_norm (adapted_data ["inference_conditions" ])
141138
@@ -394,21 +391,18 @@ def sample(
394391
395392 # Optionally standardize conditions
396393 if "summary_variables" in conditions and self .summary_variables_norm :
397- conditions ["summary_variables" ] = self .summary_variables_norm (
398- conditions ["summary_variables" ], stage = "inference"
399- )
394+ conditions ["summary_variables" ] = self .summary_variables_norm (conditions ["summary_variables" ])
400395
401396 if "inference_conditions" in conditions and self .inference_conditions_norm :
402- conditions ["inference_conditions" ] = self .inference_conditions_norm (
403- conditions ["inference_conditions" ], stage = "inference"
404- )
397+ conditions ["inference_conditions" ] = self .inference_conditions_norm (conditions ["inference_conditions" ])
398+
405399 conditions = keras .tree .map_structure (keras .ops .convert_to_tensor , conditions )
406400
407401 # Sample and undo optional standardization
408402 samples = self ._sample (num_samples = num_samples , ** conditions , ** kwargs )
409403
410404 if self .inference_variables_norm :
411- samples = self .inference_variables_norm (samples , stage = "inference" , forward = False )
405+ samples = self .inference_variables_norm (samples , forward = False )
412406
413407 samples = {"inference_variables" : samples }
414408 samples = keras .tree .map_structure (keras .ops .convert_to_numpy , samples )
@@ -512,16 +506,14 @@ def log_prob(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray | dic
512506
513507 # Optionally standardize conditions and variables
514508 if "summary_variables" in data and self .summary_variables_norm :
515- data ["summary_variables" ] = self .summary_variables_norm (data ["summary_variables" ], stage = "inference" )
509+ data ["summary_variables" ] = self .summary_variables_norm (data ["summary_variables" ])
516510
517511 if "inference_conditions" in data and self .inference_conditions_norm :
518- data ["inference_conditions" ] = self .inference_conditions_norm (
519- data ["inference_conditions" ], stage = "inference"
520- )
512+ data ["inference_conditions" ] = self .inference_conditions_norm (data ["inference_conditions" ])
521513
522514 if self .inference_variables_norm :
523515 data ["inference_variables" ], log_det_jac = self .summary_variables_norm (
524- data ["inference_variables" ], stage = "inference" , log_det_jac = True
516+ data ["inference_variables" ], log_det_jac = True
525517 )
526518 log_det_jac = keras .ops .convert_to_numpy (log_det_jac )
527519 else :
0 commit comments