 from bayesflow.networks import SummaryNetwork
 from bayesflow.simulators import ModelComparisonSimulator, Simulator
 from bayesflow.types import Tensor
-from bayesflow.utils import filter_kwargs, logging
+from bayesflow.utils import filter_kwargs, logging, concatenate_valid
 from bayesflow.utils.serialization import serialize, deserialize, serializable
 
 from .approximator import Approximator
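Note: `concatenate_valid` is imported above but its implementation is not part of this diff. Judging from its call sites below, it presumably concatenates only the non-None inputs and returns `None` when all inputs are `None`; a minimal sketch under that assumption:

```python
# Hypothetical sketch of concatenate_valid, inferred from its call sites in
# this diff; the actual bayesflow.utils implementation may differ.
import keras


def concatenate_valid(tensors, axis: int = -1):
    """Concatenate the non-None entries of `tensors`; return None if all are None."""
    valid = [t for t in tensors if t is not None]
    if not valid:
        return None
    if len(valid) == 1:
        return valid[0]  # nothing to concatenate
    return keras.ops.concatenate(valid, axis=axis)
```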
@@ -180,7 +180,10 @@ def compute_metrics(
 
         summary_metrics, summary_outputs = self._compute_summary_metrics(summary_variables, stage=stage)
 
-        classifier_conditions = self._combine_conditions(classifier_conditions, summary_outputs, stage=stage)
+        if classifier_conditions is not None and "classifier_conditions" in self.standardize:
+            classifier_conditions = self.standardize_layers["classifier_conditions"](classifier_conditions, stage=stage)
+
+        classifier_conditions = concatenate_valid((classifier_conditions, summary_outputs), axis=-1)
 
         logits = self._compute_logits(classifier_conditions)
         cross_entropy = keras.ops.mean(keras.losses.categorical_crossentropy(model_indices, logits, from_logits=True))
@@ -193,49 +196,17 @@ def compute_metrics(
                 metric.name: metric(model_indices, predictions) for metric in self.classifier_network.metrics
             }
 
-        loss = classifier_metrics.get("loss") + summary_metrics.get("loss", keras.ops.zeros(()))
+        if "loss" in summary_metrics:
+            loss = classifier_metrics["loss"] + summary_metrics["loss"]
+        else:
+            loss = classifier_metrics.pop("loss")
 
         classifier_metrics = {f"{key}/classifier_{key}": value for key, value in classifier_metrics.items()}
         summary_metrics = {f"{key}/summary_{key}": value for key, value in summary_metrics.items()}
 
         metrics = {"loss": loss} | classifier_metrics | summary_metrics
         return metrics
 
-    def _compute_summary_metrics(self, summary_variables: Tensor, stage: str) -> tuple[dict, Tensor | None]:
-        """Helper function to compute summary metrics and outputs."""
-        if self.summary_network is None:
-            return {}, None
-        if summary_variables is None:
-            raise ValueError("Summary variables are required when a summary network is present.")
-
-        if "summary_variables" in self.standardize:
-            summary_variables = self.standardize_layers["summary_variables"](summary_variables, stage=stage)
-
-        summary_metrics = self.summary_network.compute_metrics(summary_variables, stage=stage)
-        summary_outputs = summary_metrics.pop("outputs")
-        return summary_metrics, summary_outputs
-
-    def _combine_conditions(
-        self, classifier_conditions: Tensor | None, summary_outputs: Tensor | None, stage
-    ) -> Tensor:
-        """Helper to combine classifier conditions and summary outputs, if present."""
-        if classifier_conditions is None:
-            return summary_outputs
-
-        if "classifier_conditions" in self.standardize:
-            classifier_conditions = self.standardize_layers["inference_conditions"](classifier_conditions, stage=stage)
-
-        if summary_outputs is None:
-            return classifier_conditions
-
-        return keras.ops.concatenate([classifier_conditions, summary_outputs], axis=-1)
-
-    def _compute_logits(self, classifier_conditions: Tensor) -> Tensor:
-        """Helper to compute projected logits from the classifier network."""
-        logits = self.classifier_network(classifier_conditions)
-        logits = self.logits_projector(logits)
-        return logits
-
     def fit(
         self,
         *,
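The metric-renaming scheme in the hunk above prefixes each entry with its source network, and the combined `loss` is re-added at the top level. An illustrative shape of the resulting dict when both networks report a loss (values are made up):

```python
# Illustrative result of compute_metrics with a trainable summary network;
# keys follow the f"{key}/classifier_{key}" / f"{key}/summary_{key}" scheme.
metrics = {
    "loss": 1.37,                          # classifier loss + summary loss
    "loss/classifier_loss": 0.92,
    "loss/summary_loss": 0.45,
    "accuracy/classifier_accuracy": 0.81,  # sample-based metrics appear only outside the training stage
}
```

When the summary network reports no loss, `classifier_metrics.pop("loss")` removes the duplicate entry, so only the top-level `"loss"` remains.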
@@ -352,7 +323,7 @@ def predict(
         self,
         *,
         conditions: Mapping[str, np.ndarray],
-        logits: bool = False,
+        probs: bool = True,
         **kwargs,
     ) -> np.ndarray:
         """
@@ -363,15 +334,14 @@ def predict(
         ----------
         conditions : Mapping[str, np.ndarray]
             Dictionary of conditioning variables as NumPy arrays.
-        logits: bool, default=False
-            Should the posterior model probabilities be on the (unconstrained) logit space?
-            If `False`, the output is a unit simplex instead.
+        probs: bool, optional
+            Whether to return model probabilities (True) or logits (False). Default is True.
         **kwargs : dict
-            Additional keyword arguments for the adapter and classification process.
+            Additional keyword arguments for the adapter and classifier.
 
         Returns
         -------
-        np.ndarray
+        outputs: np.ndarray
             Predicted posterior model probabilities given `conditions`.
         """
 
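A usage sketch for the new flag (hypothetical names: `approximator` stands for a fitted instance and `data` for a dictionary matching its adapter's expected keys):

```python
# Hypothetical usage of the revised predict(); names are placeholders.
probabilities = approximator.predict(conditions=data)        # default: softmax over models
logits = approximator.predict(conditions=data, probs=False)  # raw, unconstrained logits
```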
@@ -389,34 +359,7 @@ def predict(
 
         output = self._predict(**conditions, **kwargs)
 
-        if not logits:
-            output = keras.ops.softmax(output)
-
-        output = keras.ops.convert_to_numpy(output)
-
-        return output
-
-    def _predict(self, classifier_conditions: Tensor = None, summary_variables: Tensor = None, **kwargs) -> Tensor:
-        if self.summary_network is None:
-            if summary_variables is not None:
-                raise ValueError("Cannot use summary variables without a summary network.")
-        else:
-            if summary_variables is None:
-                raise ValueError("Summary variables are required when a summary network is present")
-
-            summary_outputs = self.summary_network(
-                summary_variables, **filter_kwargs(kwargs, self.summary_network.call)
-            )
-
-            if classifier_conditions is None:
-                classifier_conditions = summary_outputs
-            else:
-                classifier_conditions = keras.ops.concatenate([classifier_conditions, summary_outputs], axis=1)
-
-        output = self.classifier_network(classifier_conditions)
-        output = self.logits_projector(output)
-
-        return output
+        return keras.ops.convert_to_numpy(keras.ops.softmax(output) if probs else output)
 
     def summaries(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray:
         """
@@ -449,6 +392,40 @@ def summaries(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray:
 
         return summaries
 
+    def _compute_logits(self, classifier_conditions: Tensor) -> Tensor:
+        """Helper to compute projected logits from the classifier network."""
+        logits = self.classifier_network(classifier_conditions)
+        logits = self.logits_projector(logits)
+        return logits
+
+    def _predict(self, classifier_conditions: Tensor = None, summary_variables: Tensor = None, **kwargs) -> Tensor:
+        """Helper method to obtain logits from the internal classifier based on conditions."""
+        if (self.summary_network is None) != (summary_variables is None):
+            raise ValueError("Summary variables and summary network must be used together.")
+
+        if self.summary_network is not None:
+            summary_outputs = self.summary_network(
+                summary_variables, **filter_kwargs(kwargs, self.summary_network.call)
+            )
+            classifier_conditions = concatenate_valid((classifier_conditions, summary_outputs), axis=-1)
+
+        logits = self._compute_logits(classifier_conditions)
+        return logits
+
+    def _compute_summary_metrics(self, summary_variables: Tensor, stage: str) -> tuple[dict, Tensor | None]:
+        """Helper function to compute summary metrics and outputs."""
+        if self.summary_network is None:
+            return {}, None
+        if summary_variables is None:
+            raise ValueError("Summary variables are required when a summary network is present.")
+
+        if "summary_variables" in self.standardize:
+            summary_variables = self.standardize_layers["summary_variables"](summary_variables, stage=stage)
+
+        summary_metrics = self.summary_network.compute_metrics(summary_variables, stage=stage)
+        summary_outputs = summary_metrics.pop("outputs")
+        return summary_metrics, summary_outputs
+
     def _batch_size_from_data(self, data: Mapping[str, any]) -> int:
         """
         Fetches the current batch size from an input dictionary. Can only be used during training when
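The XOR-style guard in the new `_predict` above rejects exactly the mismatched configurations; a standalone illustration of the pattern:

```python
# Standalone illustration of the guard in _predict: exactly one of the two
# arguments being None is the error case.
def check(summary_network, summary_variables):
    if (summary_network is None) != (summary_variables is None):
        raise ValueError("Summary variables and summary network must be used together.")


check(None, None)      # ok: neither is given
check("net", "vars")   # ok: both are given
# check("net", None)   # raises: a summary network without summary variables
# check(None, "vars")  # raises: summary variables without a summary network
```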