
Commit 38f2228

Refactor and simplify due to standardization
1 parent 905bf05 commit 38f2228

File tree: 3 files changed, +82 −135 lines

bayesflow/approximators/continuous_approximator.py
bayesflow/approximators/model_comparison_approximator.py
bayesflow/approximators/point_approximator.py


bayesflow/approximators/continuous_approximator.py

Lines changed: 15 additions & 43 deletions
@@ -7,7 +7,7 @@
 from bayesflow.adapters import Adapter
 from bayesflow.networks import InferenceNetwork, SummaryNetwork
 from bayesflow.types import Tensor
-from bayesflow.utils import filter_kwargs, logging, split_arrays, squeeze_inner_estimates_dict
+from bayesflow.utils import filter_kwargs, logging, split_arrays, squeeze_inner_estimates_dict, concatenate_valid
 from bayesflow.utils.serialization import serialize, deserialize, serializable

 from .approximator import Approximator
@@ -180,7 +180,9 @@ def compute_metrics(

         summary_metrics, summary_outputs = self._compute_summary_metrics(summary_variables, stage=stage)

-        inference_conditions = self._combine_conditions(inference_conditions, summary_outputs, stage=stage)
+        if "inference_conditions" in self.standardize:
+            inference_conditions = self.standardize_layers["inference_conditions"](inference_conditions, stage=stage)
+        inference_conditions = concatenate_valid((inference_conditions, summary_outputs), axis=-1)

         inference_variables = self._prepare_inference_variables(inference_variables, stage=stage)
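The changed call sites route through the new `concatenate_valid` helper imported above. Its actual implementation lives in `bayesflow.utils`; a minimal sketch of the behavior the call sites rely on (drop `None` entries, return `None` when nothing remains) could look like this, though the real helper may differ in details:

    import keras

    def concatenate_valid(tensors, axis=-1):
        # Keep only the entries that are actually present.
        tensors = [t for t in tensors if t is not None]
        if not tensors:
            return None
        if len(tensors) == 1:
            return tensors[0]
        return keras.ops.concatenate(tensors, axis=axis)

Under these semantics, downstream `if inference_conditions is not None:` guards still matter: with neither direct conditions nor a summary network, the helper yields `None`.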

@@ -192,6 +194,7 @@ def compute_metrics(
             loss = inference_metrics["loss"] + summary_metrics["loss"]
         else:
             loss = inference_metrics.pop("loss")
+
         inference_metrics = {f"{key}/inference_{key}": value for key, value in inference_metrics.items()}
         summary_metrics = {f"{key}/summary_{key}": value for key, value in summary_metrics.items()}

@@ -222,21 +225,6 @@ def _prepare_inference_variables(self, inference_variables: Tensor, stage: str)

         return inference_variables

-    def _combine_conditions(
-        self, inference_conditions: Tensor | None, summary_outputs: Tensor | None, stage: str
-    ) -> Tensor:
-        """Helper function to combine direct (inference) conditions and outputs of the summary network."""
-        if inference_conditions is None:
-            return summary_outputs
-
-        if "inference_conditions" in self.standardize:
-            inference_conditions = self.standardize_layers["inference_conditions"](inference_conditions, stage=stage)
-
-        if summary_outputs is None:
-            return inference_conditions
-
-        return keras.ops.concatenate([inference_conditions, summary_outputs], axis=-1)
-
     def fit(self, *args, **kwargs):
         """
         Trains the approximator on the provided dataset or on-demand data generated from the given simulator.
@@ -457,24 +445,17 @@ def _sample(
         summary_variables: Tensor = None,
         **kwargs,
     ) -> Tensor:
-        if self.summary_network is None:
-            if summary_variables is not None:
-                raise ValueError("Cannot use summary variables without a summary network.")
-        else:
-            if summary_variables is None:
-                raise ValueError("Summary variables are required when a summary network is present.")
+        if (self.summary_network is None) != (summary_variables is None):
+            raise ValueError("Summary variables and summary network must be used together.")

+        if self.summary_network is not None:
             summary_outputs = self.summary_network(
                 summary_variables, **filter_kwargs(kwargs, self.summary_network.call)
             )
-
-        if inference_conditions is None:
-            inference_conditions = summary_outputs
-        else:
-            inference_conditions = keras.ops.concatenate([inference_conditions, summary_outputs], axis=1)
+            inference_conditions = concatenate_valid((inference_conditions, summary_outputs), axis=-1)

         if inference_conditions is not None:
-            # conditions must always have shape (batch_size, dims)
+            # conditions must always have shape (batch_size, ...)
             batch_size = keras.ops.shape(inference_conditions)[0]
             inference_conditions = keras.ops.expand_dims(inference_conditions, axis=1)
             inference_conditions = keras.ops.broadcast_to(
@@ -485,9 +466,7 @@ def _sample(
             batch_shape = (num_samples,)

         return self.inference_network.sample(
-            batch_shape,
-            conditions=inference_conditions,
-            **filter_kwargs(kwargs, self.inference_network.sample),
+            batch_shape, conditions=inference_conditions, **filter_kwargs(kwargs, self.inference_network.sample)
        )

     def summaries(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray:
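The collapsed validation works because `!=` on two booleans acts as an exclusive-or: the error fires exactly when a summary network and summary variables are not supplied together. A standalone truth-table check (illustrative values only):

    for network, variables in [(None, None), ("net", "vars"), (None, "vars"), ("net", None)]:
        mismatch = (network is None) != (variables is None)
        print(network, variables, "->", "error" if mismatch else "ok")
    # Prints "error" only for the two mismatched combinations.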
@@ -567,21 +546,14 @@ def _log_prob(
         summary_variables: Tensor = None,
         **kwargs,
     ) -> Tensor:
-        if self.summary_network is None:
-            if summary_variables is not None:
-                raise ValueError("Cannot use summary variables without a summary network.")
-        else:
-            if summary_variables is None:
-                raise ValueError("Summary variables are required when a summary network is present.")
+        if (self.summary_network is None) != (summary_variables is None):
+            raise ValueError("Summary variables and summary network must be used together.")

+        if self.summary_network is not None:
             summary_outputs = self.summary_network(
                 summary_variables, **filter_kwargs(kwargs, self.summary_network.call)
             )
-
-        if inference_conditions is None:
-            inference_conditions = summary_outputs
-        else:
-            inference_conditions = keras.ops.concatenate([inference_conditions, summary_outputs], axis=-1)
+            inference_conditions = concatenate_valid((inference_conditions, summary_outputs), axis=-1)

         return self.inference_network.log_prob(
             inference_variables,

bayesflow/approximators/model_comparison_approximator.py

Lines changed: 49 additions & 72 deletions
@@ -8,7 +8,7 @@
 from bayesflow.networks import SummaryNetwork
 from bayesflow.simulators import ModelComparisonSimulator, Simulator
 from bayesflow.types import Tensor
-from bayesflow.utils import filter_kwargs, logging
+from bayesflow.utils import filter_kwargs, logging, concatenate_valid
 from bayesflow.utils.serialization import serialize, deserialize, serializable

 from .approximator import Approximator
@@ -180,7 +180,10 @@ def compute_metrics(

         summary_metrics, summary_outputs = self._compute_summary_metrics(summary_variables, stage=stage)

-        classifier_conditions = self._combine_conditions(classifier_conditions, summary_outputs, stage=stage)
+        if classifier_conditions is not None and "classifier_conditions" in self.standardize:
+            classifier_conditions = self.standardize_layers["classifier_conditions"](classifier_conditions, stage=stage)
+
+        classifier_conditions = concatenate_valid((classifier_conditions, summary_outputs), axis=-1)

         logits = self._compute_logits(classifier_conditions)
         cross_entropy = keras.ops.mean(keras.losses.categorical_crossentropy(model_indices, logits, from_logits=True))
@@ -193,49 +196,17 @@ def compute_metrics(
             metric.name: metric(model_indices, predictions) for metric in self.classifier_network.metrics
         }

-        loss = classifier_metrics.get("loss") + summary_metrics.get("loss", keras.ops.zeros(()))
+        if "loss" in summary_metrics:
+            loss = classifier_metrics["loss"] + summary_metrics["loss"]
+        else:
+            loss = classifier_metrics.pop("loss")

         classifier_metrics = {f"{key}/classifier_{key}": value for key, value in classifier_metrics.items()}
         summary_metrics = {f"{key}/summary_{key}": value for key, value in summary_metrics.items()}

         metrics = {"loss": loss} | classifier_metrics | summary_metrics
         return metrics

-    def _compute_summary_metrics(self, summary_variables: Tensor, stage: str) -> tuple[dict, Tensor | None]:
-        """Helper function to compute summary metrics and outputs."""
-        if self.summary_network is None:
-            return {}, None
-        if summary_variables is None:
-            raise ValueError("Summary variables are required when a summary network is present.")
-
-        if "summary_variables" in self.standardize:
-            summary_variables = self.standardize_layers["summary_variables"](summary_variables, stage=stage)
-
-        summary_metrics = self.summary_network.compute_metrics(summary_variables, stage=stage)
-        summary_outputs = summary_metrics.pop("outputs")
-        return summary_metrics, summary_outputs
-
-    def _combine_conditions(
-        self, classifier_conditions: Tensor | None, summary_outputs: Tensor | None, stage
-    ) -> Tensor:
-        """Helper to combine classifier conditions and summary outputs, if present."""
-        if classifier_conditions is None:
-            return summary_outputs
-
-        if "classifier_conditions" in self.standardize:
-            classifier_conditions = self.standardize_layers["inference_conditions"](classifier_conditions, stage=stage)
-
-        if summary_outputs is None:
-            return classifier_conditions
-
-        return keras.ops.concatenate([classifier_conditions, summary_outputs], axis=-1)
-
-    def _compute_logits(self, classifier_conditions: Tensor) -> Tensor:
-        """Helper to compute projected logits from the classifier network."""
-        logits = self.classifier_network(classifier_conditions)
-        logits = self.logits_projector(logits)
-        return logits
-
     def fit(
         self,
         *,
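The explicit branch replaces the `summary_metrics.get("loss", keras.ops.zeros(()))` fallback and mirrors `ContinuousApproximator.compute_metrics`: `_compute_summary_metrics` returns an empty dict when no summary network is attached, so the classifier loss stands alone. A toy check of both paths (plain floats, purely illustrative):

    def combine_loss(classifier_metrics, summary_metrics):
        # Same branching as the new compute_metrics body.
        if "loss" in summary_metrics:
            return classifier_metrics["loss"] + summary_metrics["loss"]
        return classifier_metrics.pop("loss")

    assert combine_loss({"loss": 0.5}, {}) == 0.5                 # classifier only
    assert combine_loss({"loss": 0.5}, {"loss": 0.25}) == 0.75    # joint with summary network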
@@ -352,7 +323,7 @@ def predict(
         self,
         *,
         conditions: Mapping[str, np.ndarray],
-        logits: bool = False,
+        probs: bool = True,
         **kwargs,
     ) -> np.ndarray:
         """
@@ -363,15 +334,14 @@ def predict(
         ----------
         conditions : Mapping[str, np.ndarray]
             Dictionary of conditioning variables as NumPy arrays.
-        logits: bool, default=False
-            Should the posterior model probabilities be on the (unconstrained) logit space?
-            If `False`, the output is a unit simplex instead.
+        probs: bool, optional
+            A flag indicating whether model probabilities (True) or logits (False) are returned. Default is True.
         **kwargs : dict
-            Additional keyword arguments for the adapter and classification process.
+            Additional keyword arguments for the adapter and classifier.

         Returns
         -------
-        np.ndarray
+        outputs: np.ndarray
             Predicted posterior model probabilities given `conditions`.
         """

@@ -389,34 +359,7 @@ def predict(

         output = self._predict(**conditions, **kwargs)

-        if not logits:
-            output = keras.ops.softmax(output)
-
-        output = keras.ops.convert_to_numpy(output)
-
-        return output
-
-    def _predict(self, classifier_conditions: Tensor = None, summary_variables: Tensor = None, **kwargs) -> Tensor:
-        if self.summary_network is None:
-            if summary_variables is not None:
-                raise ValueError("Cannot use summary variables without a summary network.")
-        else:
-            if summary_variables is None:
-                raise ValueError("Summary variables are required when a summary network is present")
-
-            summary_outputs = self.summary_network(
-                summary_variables, **filter_kwargs(kwargs, self.summary_network.call)
-            )
-
-        if classifier_conditions is None:
-            classifier_conditions = summary_outputs
-        else:
-            classifier_conditions = keras.ops.concatenate([classifier_conditions, summary_outputs], axis=1)
-
-        output = self.classifier_network(classifier_conditions)
-        output = self.logits_projector(output)
-
-        return output
+        return keras.ops.convert_to_numpy(keras.ops.softmax(output) if probs else output)

     def summaries(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray:
         """
@@ -449,6 +392,40 @@ def summaries(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray:

         return summaries

+    def _compute_logits(self, classifier_conditions: Tensor) -> Tensor:
+        """Helper to compute projected logits from the classifier network."""
+        logits = self.classifier_network(classifier_conditions)
+        logits = self.logits_projector(logits)
+        return logits
+
+    def _predict(self, classifier_conditions: Tensor = None, summary_variables: Tensor = None, **kwargs) -> Tensor:
+        """Helper method to obtain logits from the internal classifier based on conditions."""
+        if (self.summary_network is None) != (summary_variables is None):
+            raise ValueError("Summary variables and summary network must be used together.")
+
+        if self.summary_network is not None:
+            summary_outputs = self.summary_network(
+                summary_variables, **filter_kwargs(kwargs, self.summary_network.call)
+            )
+            classifier_conditions = concatenate_valid((classifier_conditions, summary_outputs), axis=-1)
+
+        logits = self._compute_logits(classifier_conditions)
+        return logits
+
+    def _compute_summary_metrics(self, summary_variables: Tensor, stage: str) -> tuple[dict, Tensor | None]:
+        """Helper function to compute summary metrics and outputs."""
+        if self.summary_network is None:
+            return {}, None
+        if summary_variables is None:
+            raise ValueError("Summary variables are required when a summary network is present.")
+
+        if "summary_variables" in self.standardize:
+            summary_variables = self.standardize_layers["summary_variables"](summary_variables, stage=stage)
+
+        summary_metrics = self.summary_network.compute_metrics(summary_variables, stage=stage)
+        summary_outputs = summary_metrics.pop("outputs")
+        return summary_metrics, summary_outputs
+
     def _batch_size_from_data(self, data: Mapping[str, any]) -> int:
         """
         Fetches the current batch size from an input dictionary. Can only be used during training when

bayesflow/approximators/point_approximator.py

Lines changed: 18 additions & 20 deletions
@@ -5,7 +5,7 @@
 import keras

 from bayesflow.types import Tensor
-from bayesflow.utils import filter_kwargs, split_arrays, squeeze_inner_estimates_dict, logging
+from bayesflow.utils import filter_kwargs, split_arrays, squeeze_inner_estimates_dict, logging, concatenate_valid
 from bayesflow.utils.serialization import serializable

 from .continuous_approximator import ContinuousApproximator
@@ -57,11 +57,14 @@ def estimate(
         """

         conditions = self._prepare_conditions(conditions, **kwargs)
+
         estimates = self._estimate(**conditions, **kwargs)
         estimates = self._apply_inverse_adapter_to_estimates(estimates, **kwargs)
+
         # Optionally split the arrays along the last axis.
         if split:
             estimates = split_arrays(estimates, axis=-1)
+
         # Reorder the nested dictionary so that original variable names are at the top.
         estimates = PointApproximator._reorder_estimates(estimates)
         # Remove unnecessary nesting.
@@ -108,9 +111,10 @@ def sample(
             of shape (num_datasets, num_samples, variable_block_size).
         """
         conditions = self._prepare_conditions(conditions, **kwargs)
+
         samples = self._sample(num_samples, **conditions, **kwargs)
         samples = self._apply_inverse_adapter_to_samples(samples, **kwargs)
-        # Optionally split the arrays along the last axis.
+
         if split:
             raise NotImplementedError("split=True is currently not supported for `PointApproximator`.")

@@ -148,18 +152,19 @@

             Log-probabilities have shape (num_datasets,).
         """
-        log_prob = super().log_prob(data=data, **kwargs)
-        # Squeeze log probabilities dictionary if there's only one key-value pair.
-        log_prob = PointApproximator._squeeze_parametric_score_major_dict(log_prob)
-
-        return log_prob
+        return super().log_prob(data=data, **kwargs)

     def _prepare_conditions(self, conditions: Mapping[str, np.ndarray], **kwargs) -> dict[str, Tensor]:
-        """Adapts and converts the conditions to tensors."""
+        """Adapts, optionally standardizes, and converts the conditions to tensors."""

         conditions = self.adapter(conditions, strict=False, stage="inference", **kwargs)
         conditions = {k: v for k, v in conditions.items() if k in ContinuousApproximator.CONDITION_KEYS}

+        # Optionally standardize conditions
+        for key, value in conditions.items():
+            if key in self.standardize:
+                conditions[key] = self.standardize_layers[key](value)
+
         return keras.tree.map_structure(keras.ops.convert_to_tensor, conditions)

     def _apply_inverse_adapter_to_estimates(
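The new standardization pass overwrites values while iterating, which is safe here because no keys are added or removed. A self-contained illustration with the attributes stubbed out (`standardize`, `standardize_layers`, and the z-score lambda are stand-ins for the instance state):

    import numpy as np

    standardize = {"inference_conditions"}
    standardize_layers = {"inference_conditions": lambda v: (v - v.mean()) / v.std()}

    conditions = {"inference_conditions": np.array([1.0, 2.0, 3.0])}
    for key, value in conditions.items():
        if key in standardize:
            conditions[key] = standardize_layers[key](value)  # overwrite in place, keys unchanged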
@@ -233,7 +238,7 @@ def _squeeze_estimates(
     def _squeeze_parametric_score_major_dict(samples: Mapping[str, np.ndarray]) -> np.ndarray or dict[str, np.ndarray]:
         """Squeezes the dictionary to just the value if there is only one key-value pair."""
         if len(samples) == 1:
-            return next(iter(samples.values()))  # Extract and return the only item's value
+            return next(iter(samples.values()))
         return samples
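The squeeze helper collapses a dict holding a single parametric score to its array and passes anything larger through untouched. A behavior sketch with the rule inlined (shapes illustrative):

    import numpy as np

    def squeeze_single(samples):
        # Same rule as PointApproximator._squeeze_parametric_score_major_dict.
        return next(iter(samples.values())) if len(samples) == 1 else samples

    q = np.zeros((4, 3, 2))
    assert squeeze_single({"quantiles": q}) is q        # one score: collapses to the array
    multi = {"mean": np.zeros(4), "quantiles": q}
    assert squeeze_single(multi) is multi               # several scores: dict unchanged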
@@ -242,21 +247,14 @@ def _estimate(
         summary_variables: Tensor = None,
         **kwargs,
     ) -> dict[str, dict[str, Tensor]]:
-        if self.summary_network is None:
-            if summary_variables is not None:
-                raise ValueError("Cannot use summary variables without a summary network.")
-        else:
-            if summary_variables is None:
-                raise ValueError("Summary variables are required when a summary network is present.")
+        if (self.summary_network is None) != (summary_variables is None):
+            raise ValueError("Summary variables and summary network must be used together.")

+        if self.summary_network is not None:
             summary_outputs = self.summary_network(
                 summary_variables, **filter_kwargs(kwargs, self.summary_network.call)
             )
-
-        if inference_conditions is None:
-            inference_conditions = summary_outputs
-        else:
-            inference_conditions = keras.ops.concatenate([inference_conditions, summary_outputs], axis=1)
+            inference_conditions = concatenate_valid((inference_conditions, summary_outputs), axis=-1)

         return self.inference_network(
             conditions=inference_conditions,
