1212from bayesflow .utils .serialization import serialize , deserialize , serializable
1313
1414from .approximator import Approximator
15+ from ..networks .standardization import Standardization
1516
1617
1718@serializable ("bayesflow.approximators" )
@@ -44,17 +45,37 @@ def __init__(
4445 classifier_network : keras .Layer ,
4546 adapter : Adapter ,
4647 summary_network : SummaryNetwork = None ,
48+ standardize : str | Sequence [str ] | None = "all" ,
4749 ** kwargs ,
4850 ):
4951 super ().__init__ (** kwargs )
5052 self .classifier_network = classifier_network
5153 self .adapter = adapter
5254 self .summary_network = summary_network
5355 self .num_models = num_models
56+ self .standardize = standardize
5457 self .logits_projector = keras .layers .Dense (num_models )
5558
59+ self .summary_variables_norm = None
60+ self .classifier_conditions_norm = None
61+
5662 def build (self , data_shapes : Mapping [str , Shape ]):
5763 data = {key : keras .ops .zeros (value ) for key , value in data_shapes .items ()}
64+
65+ if self .standardize == "all" :
66+ keys = ModelComparisonApproximator .SAMPLE_KEYS
67+ elif isinstance (self .standardize , str ):
68+ keys = [self .standardize ]
69+ elif isinstance (self .standardize , Sequence ):
70+ keys = self .standardize
71+ else :
72+ keys = []
73+
74+ if "summary_variables" in data_shapes and "summary_variables" in data and self .summary_network :
75+ self .summary_variables_norm = Standardization ()
76+ if "classifier_conditions" in data_shapes and "classifier_conditions" in keys :
77+ self .classifier_conditions_norm = Standardization ()
78+
5879 self .compute_metrics (** data , stage = "training" )
5980
6081 @classmethod
@@ -134,46 +155,97 @@ def compute_metrics(
134155 summary_variables : Tensor = None ,
135156 stage : str = "training" ,
136157 ) -> dict [str , Tensor ]:
137- if self .summary_network is None :
138- summary_metrics = {}
139- else :
140- summary_metrics = self .summary_network .compute_metrics (summary_variables , stage = stage )
141- summary_outputs = summary_metrics .pop ("outputs" )
158+ """
159+ Computes loss and tracks metrics for the classifier and summary networks.
142160
143- if classifier_conditions is None :
144- classifier_conditions = summary_outputs
145- else :
146- classifier_conditions = keras . ops . concatenate ([ classifier_conditions , summary_outputs ], axis = - 1 )
161+ This method coordinates summary metric computation ( if present), combines summary outputs with
162+ classifier conditions, computes classifier logits and cross-entropy loss, and aggregates all
163+ tracked metrics into a single dictionary. Keys are prefixed with "classifier_" or "summary_"
164+ to indicate their origin.
147165
148- # we could move this into its own class
149- logits = self .classifier_network (classifier_conditions )
150- logits = self .logits_projector (logits )
166+ Parameters
167+ ----------
168+ classifier_conditions : Tensor, optional
169+ Conditioning variables for the classifier network (default is None). May be
170+ combined with summary network outputs if present.
171+ model_indices : Tensor
172+ Ground-truth indices or one-hot encoded labels for classification.
173+ summary_variables : Tensor, optional
174+ Input tensor(s) for the summary network (default is None). Required if a summary
175+ network is present.
176+ stage : str, optional
177+ Current training stage (e.g., "training", "validation", "inference"). Controls
178+ certain metric computations (default is "training").
179+
180+ Returns
181+ -------
182+ metrics : dict[str, Tensor]
183+ Dictionary containing the total loss under the key "loss", as well as all tracked
184+ metrics for the classifier and summary networks. Each metric key is prefixed to
185+ indicate its source.
186+ """
187+
188+ summary_metrics , summary_outputs = self ._compute_summary_metrics (summary_variables , stage = stage )
189+
190+ classifier_conditions = self ._combine_conditions (classifier_conditions , summary_outputs , stage = stage )
151191
152- cross_entropy = keras . losses . categorical_crossentropy ( model_indices , logits , from_logits = True )
153- cross_entropy = keras .ops .mean (cross_entropy )
192+ logits = self . _compute_logits ( classifier_conditions )
193+ cross_entropy = keras .ops .mean (keras . losses . categorical_crossentropy ( model_indices , logits , from_logits = True ) )
154194
155195 classifier_metrics = {"loss" : cross_entropy }
156196
157197 if stage != "training" and any (self .classifier_network .metrics ):
158- # compute sample-based metrics
159198 predictions = keras .ops .argmax (logits , axis = - 1 )
160199 classifier_metrics |= {
161200 metric .name : metric (model_indices , predictions ) for metric in self .classifier_network .metrics
162201 }
163202
164- loss = classifier_metrics .get ("loss" , keras . ops . zeros (()) ) + summary_metrics .get ("loss" , keras .ops .zeros (()))
203+ loss = classifier_metrics .get ("loss" ) + summary_metrics .get ("loss" , keras .ops .zeros (()))
165204
166205 classifier_metrics = {f"{ key } /classifier_{ key } " : value for key , value in classifier_metrics .items ()}
167206 summary_metrics = {f"{ key } /summary_{ key } " : value for key , value in summary_metrics .items ()}
168207
169208 metrics = {"loss" : loss } | classifier_metrics | summary_metrics
170-
171209 return metrics
172210
211+ def _compute_summary_metrics (self , summary_variables : Tensor , stage : str ) -> tuple [dict , Tensor | None ]:
212+ """Helper function to compute summary metrics and outputs."""
213+ if self .summary_network is None :
214+ return {}, None
215+ if summary_variables is None :
216+ raise ValueError ("Summary variables are required when a summary network is present." )
217+
218+ if self .summary_variables_norm is not None :
219+ summary_variables = self .summary_variables_norm (summary_variables , stage = stage )
220+
221+ summary_metrics = self .summary_network .compute_metrics (summary_variables , stage = stage )
222+ summary_outputs = summary_metrics .pop ("outputs" )
223+ return summary_metrics , summary_outputs
224+
225+ def _combine_conditions (
226+ self , classifier_conditions : Tensor | None , summary_outputs : Tensor | None , stage
227+ ) -> Tensor :
228+ """Helper to combine classifier conditions and summary outputs, if present."""
229+ if classifier_conditions is None :
230+ return summary_outputs
231+
232+ if self .classifier_conditions_norm :
233+ classifier_conditions = self .classifier_conditions_norm (classifier_conditions , stage = stage )
234+
235+ if summary_outputs is None :
236+ return classifier_conditions
237+ return keras .ops .concatenate ([classifier_conditions , summary_outputs ], axis = - 1 )
238+
239+ def _compute_logits (self , classifier_conditions : Tensor ) -> Tensor :
240+ """Helper to compute projected logits from the classifier network."""
241+ logits = self .classifier_network (classifier_conditions )
242+ logits = self .logits_projector (logits )
243+ return logits
244+
173245 def fit (
174246 self ,
175247 * ,
176- adapter : Adapter = "auto" ,
248+ adapter : Adapter | str = "auto" ,
177249 dataset : keras .utils .PyDataset = None ,
178250 simulator : ModelComparisonSimulator = None ,
179251 simulators : Sequence [Simulator ] = None ,
@@ -182,11 +254,13 @@ def fit(
182254 """
183255 Trains the approximator on the provided dataset or on-demand generated from the given (multi-model) simulator.
184256 If `dataset` is not provided, a dataset is built from the `simulator`.
185- If `simulator` is not provided, it will be build from a list of `simulators`.
257+ If `simulator` is not provided, it will be built from a list of `simulators`.
186258 If the model has not been built, it will be built using a batch from the dataset.
187259
188260 Parameters
189261 ----------
262+ adapter : Adapter or str, optional
263+ The data adapter that will make the simulated / real outputs neural-network friendly.
190264 dataset : keras.utils.PyDataset, optional
191265 A dataset containing simulations for training. If provided, `simulator` must be None.
192266 simulator : ModelComparisonSimulator, optional
@@ -315,6 +389,13 @@ def predict(
315389
316390 conditions = keras .tree .map_structure (keras .ops .convert_to_tensor , conditions )
317391
392+ # Optionally standardize conditions
393+ if "summary_variables" in conditions and self .summary_variables_norm :
394+ conditions ["summary_variables" ] = self .summary_variables_norm (conditions ["summary_variables" ])
395+
396+ if "classifier_conditions" in conditions and self .classifier_conditions_norm :
397+ conditions ["classifier_conditions" ] = self .classifier_conditions_norm (conditions ["classifier_conditions" ])
398+
318399 output = self ._predict (** conditions , ** kwargs )
319400
320401 if not logits :
@@ -346,35 +427,33 @@ def _predict(self, classifier_conditions: Tensor = None, summary_variables: Tens
346427
347428 return output
348429
349- def summaries (self , data : Mapping [str , np .ndarray ], ** kwargs ):
430+ def summaries (self , data : Mapping [str , np .ndarray ], ** kwargs ) -> np . ndarray :
350431 """
351- Computes the summaries of given data .
432+ Computes the learned summary statistics of given inputs .
352433
353434 The `data` dictionary is preprocessed using the `adapter` and passed through the summary network.
354435
355436 Parameters
356437 ----------
357438 data : Mapping[str, np.ndarray]
358- Dictionary of data as NumPy arrays.
439+ Dictionary of simulated or real quantities as NumPy arrays.
359440 **kwargs : dict
360441 Additional keyword arguments for the adapter and the summary network.
361442
362443 Returns
363444 -------
364445 summaries : np.ndarray
365- Log-probabilities of the distribution `p(inference_variables | inference_conditions, h(summary_conditions))`
366-
367- Raises
368- ------
369- ValueError
370- If the approximator does not have a summary network, or the adapter does not produce the output required
371- by the summary network.
446+ The learned summary statistics.
372447 """
373448 if self .summary_network is None :
374449 raise ValueError ("A summary network is required to compute summaries." )
450+
375451 data_adapted = self .adapter (data , strict = False , stage = "inference" , ** kwargs )
376452 if "summary_variables" not in data_adapted or data_adapted ["summary_variables" ] is None :
377453 raise ValueError ("Summary variables are required to compute summaries." )
454+
378455 summary_variables = keras .ops .convert_to_tensor (data_adapted ["summary_variables" ])
379456 summaries = self .summary_network (summary_variables , ** filter_kwargs (kwargs , self .summary_network .call ))
457+ summaries = keras .ops .convert_to_numpy (summaries )
458+
380459 return summaries
0 commit comments