1111from bayesflow .utils .serialization import serialize , deserialize , serializable
1212
1313from .approximator import Approximator
14+ from ..networks .standardization import Standardization
1415
1516
1617@serializable ("bayesflow.approximators" )
@@ -40,12 +41,17 @@ def __init__(
4041 adapter : Adapter ,
4142 inference_network : InferenceNetwork ,
4243 summary_network : SummaryNetwork = None ,
44+ standardize : str | Sequence [str ] | None = "all" ,
4345 ** kwargs ,
4446 ):
4547 super ().__init__ (** kwargs )
4648 self .adapter = adapter
4749 self .inference_network = inference_network
4850 self .summary_network = summary_network
51+ self .standardize = standardize
52+ self .inference_variables_norm = None
53+ self .summary_variables_norm = None
54+ self .inference_conditions_norm = None
4955
5056 @classmethod
5157 def build_adapter (
@@ -112,6 +118,31 @@ def compile(
112118
113119 return super ().compile (* args , ** kwargs )
114120
def build_from_data(self, adapted_data: Mapping[str, object]) -> None:
    """
    Sets up the requested standardization layers and builds the inner networks from one batch of adapted data.

    Which variables get a standardization layer is controlled by `self.standardize`:
    `"all"` selects every supported variable that is present in `adapted_data`, a single string or a
    sequence of strings selects exactly the named variables, and any other value disables standardization.

    Parameters
    ----------
    adapted_data : Mapping[str, object]
        Adapter output keyed by variable name. The keys read here are `"inference_variables"`,
        `"summary_variables"` and `"inference_conditions"`; the full mapping is forwarded
        (after keyword filtering) to `compute_metrics`.

    Raises
    ------
    KeyError
        If `standardize` explicitly names one of the supported variables and it is missing from
        `adapted_data` (for summary variables, only when a summary network is present).
    """
    supported = ("inference_variables", "summary_variables", "inference_conditions")

    # Determine input standardization
    if self.standardize == "all":
        # Summary variables and inference conditions are optional inputs, so "all" must mean
        # "all that are actually there"; indexing a missing key would raise a KeyError.
        keys = [key for key in supported if key in adapted_data and adapted_data[key] is not None]
    elif isinstance(self.standardize, str):
        keys = [self.standardize]
    elif isinstance(self.standardize, Sequence):
        keys = self.standardize
    else:
        keys = []

    # Each layer is called once on the batch right after creation
    # (presumably so it is built before the networks below; confirm against Standardization).
    if "inference_variables" in keys:
        self.inference_variables_norm = Standardization()
        self.inference_variables_norm(adapted_data["inference_variables"])
    if "summary_variables" in keys and self.summary_network is not None:
        self.summary_variables_norm = Standardization()
        self.summary_variables_norm(adapted_data["summary_variables"])
    if "inference_conditions" in keys:
        self.inference_conditions_norm = Standardization()
        self.inference_conditions_norm(adapted_data["inference_conditions"])

    # Call compute metrics once to build inner networks
    self.compute_metrics(**filter_kwargs(adapted_data, self.compute_metrics), stage="training")
    self.built = True
145+
115146 def compile_from_config (self , config ):
116147 self .compile (** deserialize (config ))
117148 if hasattr (self , "optimizer" ) and self .built :
@@ -126,6 +157,10 @@ def compute_metrics(
126157 sample_weight : Tensor = None ,
127158 stage : str = "training" ,
128159 ) -> dict [str , Tensor ]:
160+ # Optionally standardize optional inference conditions
161+ if inference_conditions and self .inference_conditions_norm :
162+ inference_conditions = self .inference_conditions_norm (inference_conditions , stage = stage )
163+
129164 if self .summary_network is None :
130165 if summary_variables is not None :
131166 raise ValueError ("Cannot compute summary metrics without a summary network." )
@@ -135,6 +170,9 @@ def compute_metrics(
135170 if summary_variables is None :
136171 raise ValueError ("Summary variables are required when a summary network is present." )
137172
173+ if self .summary_variables_norm is not None :
174+ summary_variables = self .summary_variables_norm (summary_variables , stage = stage )
175+
138176 summary_metrics = self .summary_network .compute_metrics (summary_variables , stage = stage )
139177 summary_outputs = summary_metrics .pop ("outputs" )
140178
@@ -146,6 +184,10 @@ def compute_metrics(
146184
147185 # Force a conversion to Tensor
148186 inference_variables = keras .tree .map_structure (keras .ops .convert_to_tensor , inference_variables )
187+
188+ if self .inference_variables_norm is not None :
189+ inference_variables = self .inference_variables_norm (inference_variables , stage = stage )
190+
149191 inference_metrics = self .inference_network .compute_metrics (
150192 inference_variables , conditions = inference_conditions , sample_weight = sample_weight , stage = stage
151193 )
@@ -223,6 +265,7 @@ def get_config(self):
223265 "adapter" : self .adapter ,
224266 "inference_network" : self .inference_network ,
225267 "summary_network" : self .summary_network ,
268+ "standardize" : self .standardize ,
226269 }
227270
228271 return base_config | serialize (config )
@@ -349,16 +392,33 @@ def sample(
349392 # Ensure only keys relevant for sampling are present in the conditions dictionary
350393 conditions = {k : v for k , v in conditions .items () if k in ContinuousApproximator .SAMPLE_KEYS }
351394
395+ # Optionally standardize conditions
396+ if "summary_variables" in conditions and self .summary_variables_norm :
397+ conditions ["summary_variables" ] = self .summary_variables_norm (
398+ conditions ["summary_variables" ], stage = "inference"
399+ )
400+
401+ if "inference_conditions" in conditions and self .inference_conditions_norm :
402+ conditions ["inference_conditions" ] = self .inference_conditions_norm (
403+ conditions ["inference_conditions" ], stage = "inference"
404+ )
352405 conditions = keras .tree .map_structure (keras .ops .convert_to_tensor , conditions )
353- conditions = {"inference_variables" : self ._sample (num_samples = num_samples , ** conditions , ** kwargs )}
354- conditions = keras .tree .map_structure (keras .ops .convert_to_numpy , conditions )
406+
407+ # Sample and undo optional standardization
408+ samples = self ._sample (num_samples = num_samples , ** conditions , ** kwargs )
409+
410+ if self .inference_variables_norm :
411+ samples = self .inference_variables_norm (samples , stage = "inference" , forward = False )
412+
413+ samples = {"inference_variables" : samples }
414+ samples = keras .tree .map_structure (keras .ops .convert_to_numpy , samples )
355415
356416 # Back-transform quantities and samples
357- conditions = self .adapter (conditions , inverse = True , strict = False , ** kwargs )
417+ samples = self .adapter (samples , inverse = True , strict = False , ** kwargs )
358418
359419 if split :
360- conditions = split_arrays (conditions , axis = - 1 )
361- return conditions
420+ samples = split_arrays (samples , axis = - 1 )
421+ return samples
362422
363423 def _sample (
364424 self ,
@@ -400,37 +460,35 @@ def _sample(
400460 ** filter_kwargs (kwargs , self .inference_network .sample ),
401461 )
402462
def summaries(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray:
    """
    Computes the learned summary statistics of given inputs.

    The `data` dictionary is preprocessed using the `adapter`, standardized with the summary-variable
    standardization layer (if one was created), and passed through the summary network.

    Parameters
    ----------
    data : Mapping[str, np.ndarray]
        Dictionary of simulated or real quantities as NumPy arrays.
    **kwargs : dict
        Additional keyword arguments for the adapter and the summary network.

    Returns
    -------
    summaries : np.ndarray
        The learned summary statistics.

    Raises
    ------
    ValueError
        If the approximator does not have a summary network, or the adapter does not produce the
        summary variables required by the summary network.
    """
    if self.summary_network is None:
        raise ValueError("A summary network is required to compute summaries.")

    data_adapted = self.adapter(data, strict=False, stage="inference", **kwargs)
    if "summary_variables" not in data_adapted or data_adapted["summary_variables"] is None:
        raise ValueError("Summary variables are required to compute summaries.")

    summary_variables = keras.ops.convert_to_tensor(data_adapted["summary_variables"])

    # The summary network sees standardized inputs everywhere else (training, sampling,
    # log-prob), so the same transform has to be applied here for the outputs to agree.
    if self.summary_variables_norm is not None:
        summary_variables = self.summary_variables_norm(summary_variables, stage="inference")

    summaries = self.summary_network(summary_variables, **filter_kwargs(kwargs, self.summary_network.call))
    summaries = keras.ops.convert_to_numpy(summaries)

    return summaries
435493
436494 def log_prob (self , data : Mapping [str , np .ndarray ], ** kwargs ) -> np .ndarray | dict [str , np .ndarray ]:
@@ -451,6 +509,24 @@ def log_prob(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray | dic
451509 Log-probabilities of the distribution `p(inference_variables | inference_conditions, h(summary_conditions))`
452510 """
453511 data , log_det_jac = self .adapter (data , strict = False , stage = "inference" , log_det_jac = True , ** kwargs )
512+
513+ # Optionally standardize conditions and variables
514+ if "summary_variables" in data and self .summary_variables_norm :
515+ data ["summary_variables" ] = self .summary_variables_norm (data ["summary_variables" ], stage = "inference" )
516+
517+ if "inference_conditions" in data and self .inference_conditions_norm :
518+ data ["inference_conditions" ] = self .inference_conditions_norm (
519+ data ["inference_conditions" ], stage = "inference"
520+ )
521+
522+ if self .inference_variables_norm :
523+ data ["inference_variables" ], log_det_jac = self .summary_variables_norm (
524+ data ["inference_variables" ], stage = "inference" , log_det_jac = True
525+ )
526+ log_det_jac = keras .ops .convert_to_numpy (log_det_jac )
527+ else :
528+ log_det_jac = 0.0
529+
454530 data = keras .tree .map_structure (keras .ops .convert_to_tensor , data )
455531 log_prob = self ._log_prob (** data , ** kwargs )
456532 log_prob = keras .tree .map_structure (keras .ops .convert_to_numpy , log_prob )
0 commit comments