Skip to content

Commit 8ea6782

Browse files
committed
Refactor compute metrics and add standardization to model comp
1 parent 4df270a commit 8ea6782

File tree

2 files changed

+182
-56
lines changed

2 files changed

+182
-56
lines changed

bayesflow/approximators/continuous_approximator.py

Lines changed: 74 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,11 @@ def build_adapter(
6767
Parameters
6868
----------
6969
inference_variables : Sequence of str
70-
Names of the inference variables in the data
70+
Names of the inference variables (to be modeled) in the data dict.
7171
inference_conditions : Sequence of str, optional
72-
Names of the inference conditions in the data
72+
Names of the inference conditions (to be used as direct conditions) in the data dict.
7373
summary_variables : Sequence of str, optional
74-
Names of the summary variables in the data
74+
Names of the summary variables (to be passed to a summary network) in the data dict.
7575
sample_weight : str, optional
7676
Name of the sample weights
7777
"""
@@ -151,36 +151,44 @@ def compute_metrics(
151151
sample_weight: Tensor = None,
152152
stage: str = "training",
153153
) -> dict[str, Tensor]:
154-
# Optionally standardize optional inference conditions
155-
if inference_conditions is not None and self.inference_conditions_norm:
156-
inference_conditions = self.inference_conditions_norm(inference_conditions, stage=stage)
154+
"""
155+
Computes loss and tracks metrics for the inference and summary networks.
157156
158-
if self.summary_network is None:
159-
if summary_variables is not None:
160-
raise ValueError("Cannot compute summary metrics without a summary network.")
157+
This method orchestrates the end-to-end computation of metrics and loss for a model
158+
with both inference and optional summary network. It handles standardization of input
159+
variables, combines summary outputs with inference conditions when necessary, and
160+
aggregates loss and all tracked metrics into a unified dictionary. The returned dictionary
161+
includes both the total loss and all individual metrics, with keys indicating their source.
161162
162-
summary_metrics = {}
163-
else:
164-
if summary_variables is None:
165-
raise ValueError("Summary variables are required when a summary network is present.")
166-
167-
if self.summary_variables_norm is not None:
168-
summary_variables = self.summary_variables_norm(summary_variables, stage=stage)
163+
Parameters
164+
----------
165+
inference_variables : Tensor
166+
Input tensor(s) for the inference network. These are typically latent variables to be modeled.
167+
inference_conditions : Tensor, optional
168+
Conditioning variables for the inference network (default is None).
169+
May be combined with outputs from the summary network if present.
170+
summary_variables : Tensor, optional
171+
Input tensor(s) for the summary network (default is None). Required if
172+
a summary network is present.
173+
sample_weight : Tensor, optional
174+
Weighting tensor for metric computation (default is None).
175+
stage : str, optional
176+
Current training stage (e.g., "training", "validation", "inference"). Controls
177+
the behavior of standardization and some metric computations (default is "training").
169178
170-
summary_metrics = self.summary_network.compute_metrics(summary_variables, stage=stage)
171-
summary_outputs = summary_metrics.pop("outputs")
179+
Returns
180+
-------
181+
metrics : dict[str, Tensor]
182+
Dictionary containing the total loss under the key "loss", as well as all tracked
183+
metrics for the inference and summary networks. Each metric key is prefixed with
184+
"inference_" or "summary_" to indicate its source.
185+
"""
172186

173-
# append summary outputs to inference conditions
174-
if inference_conditions is None:
175-
inference_conditions = summary_outputs
176-
else:
177-
inference_conditions = keras.ops.concatenate([inference_conditions, summary_outputs], axis=-1)
187+
summary_metrics, summary_outputs = self._compute_summary_metrics(summary_variables, stage=stage)
178188

179-
# Force a conversion to Tensor
180-
inference_variables = keras.tree.map_structure(keras.ops.convert_to_tensor, inference_variables)
189+
inference_conditions = self._combine_conditions(inference_conditions, summary_outputs, stage=stage)
181190

182-
if self.inference_variables_norm is not None:
183-
inference_variables = self.inference_variables_norm(inference_variables, stage=stage)
191+
inference_variables = self._prepare_inference_variables(inference_variables, stage=stage)
184192

185193
inference_metrics = self.inference_network.compute_metrics(
186194
inference_variables, conditions=inference_conditions, sample_weight=sample_weight, stage=stage
@@ -195,6 +203,45 @@ def compute_metrics(
195203

196204
return metrics
197205

206+
def _compute_summary_metrics(self, summary_variables: Tensor | None, stage: str) -> tuple[dict, Tensor | None]:
    """Run the summary network (if configured) and split metrics from outputs.

    Returns a ``(metrics, outputs)`` pair; when no summary network is present
    the pair is ``({}, None)``.
    """
    if self.summary_network is None:
        # Supplying summary variables without a network to consume them is a user error.
        if summary_variables is not None:
            raise ValueError("Cannot compute summaries from summary_variables without a summary network.")
        return {}, None

    if summary_variables is None:
        raise ValueError("Summary variables are required when a summary network is present.")

    variables = summary_variables
    if self.summary_variables_norm is not None:
        # Standardize inputs before they reach the summary network.
        variables = self.summary_variables_norm(variables, stage=stage)

    metrics = self.summary_network.compute_metrics(variables, stage=stage)
    # The network packs its outputs into the metrics dict; separate them out.
    outputs = metrics.pop("outputs")
    return metrics, outputs
222+
223+
def _prepare_inference_variables(self, inference_variables, stage):
    """Convert inference variables to tensors and standardize them when a normalizer is set."""
    # Force a conversion to Tensor across arbitrary nested structures.
    variables = keras.tree.map_structure(keras.ops.convert_to_tensor, inference_variables)
    if self.inference_variables_norm is None:
        return variables
    return self.inference_variables_norm(variables, stage=stage)
229+
230+
def _combine_conditions(
    self, inference_conditions: Tensor | None, summary_outputs: Tensor | None, stage: str
) -> Tensor:
    """Combine direct (inference) conditions with summary-network outputs.

    Parameters
    ----------
    inference_conditions : Tensor or None
        Direct conditioning variables; standardized here when a normalizer is configured.
    summary_outputs : Tensor or None
        Outputs of the summary network, appended along the last axis.
    stage : str
        Current stage (e.g., "training"), forwarded to the normalizer.

    Returns
    -------
    Tensor or None
        The combined conditions; None when both inputs are None.
    """
    if inference_conditions is None:
        return summary_outputs

    # Explicit None check for consistency with the sibling normalizer checks
    # (relying on layer truthiness is fragile and inconsistent with the rest
    # of the class).
    if self.inference_conditions_norm is not None:
        inference_conditions = self.inference_conditions_norm(inference_conditions, stage=stage)

    if summary_outputs is None:
        return inference_conditions

    return keras.ops.concatenate([inference_conditions, summary_outputs], axis=-1)
244+
198245
def fit(self, *args, **kwargs):
199246
"""
200247
Trains the approximator on the provided dataset or on-demand data generated from the given simulator.

bayesflow/approximators/model_comparison_approximator.py

Lines changed: 108 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from bayesflow.utils.serialization import serialize, deserialize, serializable
1313

1414
from .approximator import Approximator
15+
from ..networks.standardization import Standardization
1516

1617

1718
@serializable("bayesflow.approximators")
@@ -44,17 +45,37 @@ def __init__(
4445
classifier_network: keras.Layer,
4546
adapter: Adapter,
4647
summary_network: SummaryNetwork = None,
48+
standardize: str | Sequence[str] | None = "all",
4749
**kwargs,
4850
):
4951
super().__init__(**kwargs)
5052
self.classifier_network = classifier_network
5153
self.adapter = adapter
5254
self.summary_network = summary_network
5355
self.num_models = num_models
56+
self.standardize = standardize
5457
self.logits_projector = keras.layers.Dense(num_models)
5558

59+
self.summary_variables_norm = None
60+
self.classifier_conditions_norm = None
61+
5662
def build(self, data_shapes: Mapping[str, Shape]):
    """Build the approximator from a dictionary of data shapes.

    Resolves the user-facing ``standardize`` spec into a concrete list of
    keys, creates the optional Standardization layers accordingly, and
    triggers one forward pass on zero tensors so all sublayers are built.
    """
    data = {key: keras.ops.zeros(value) for key, value in data_shapes.items()}

    # Resolve `standardize` ("all", a single key, a sequence of keys, or None).
    if self.standardize == "all":
        keys = ModelComparisonApproximator.SAMPLE_KEYS
    elif isinstance(self.standardize, str):
        keys = [self.standardize]
    elif isinstance(self.standardize, Sequence):
        keys = self.standardize
    else:
        keys = []

    # Bug fix: membership must be tested against the resolved `keys`, not
    # `data` — `data` is built directly from `data_shapes`, so the old check
    # was always true and the `standardize` spec was ignored for summaries.
    if "summary_variables" in data_shapes and "summary_variables" in keys and self.summary_network:
        self.summary_variables_norm = Standardization()
    if "classifier_conditions" in data_shapes and "classifier_conditions" in keys:
        self.classifier_conditions_norm = Standardization()

    self.compute_metrics(**data, stage="training")
5980

6081
@classmethod
@@ -134,46 +155,97 @@ def compute_metrics(
134155
summary_variables: Tensor = None,
135156
stage: str = "training",
136157
) -> dict[str, Tensor]:
137-
if self.summary_network is None:
138-
summary_metrics = {}
139-
else:
140-
summary_metrics = self.summary_network.compute_metrics(summary_variables, stage=stage)
141-
summary_outputs = summary_metrics.pop("outputs")
158+
"""
159+
Computes loss and tracks metrics for the classifier and summary networks.
142160
143-
if classifier_conditions is None:
144-
classifier_conditions = summary_outputs
145-
else:
146-
classifier_conditions = keras.ops.concatenate([classifier_conditions, summary_outputs], axis=-1)
161+
This method coordinates summary metric computation (if present), combines summary outputs with
162+
classifier conditions, computes classifier logits and cross-entropy loss, and aggregates all
163+
tracked metrics into a single dictionary. Keys are prefixed with "classifier_" or "summary_"
164+
to indicate their origin.
147165
148-
# we could move this into its own class
149-
logits = self.classifier_network(classifier_conditions)
150-
logits = self.logits_projector(logits)
166+
Parameters
167+
----------
168+
classifier_conditions : Tensor, optional
169+
Conditioning variables for the classifier network (default is None). May be
170+
combined with summary network outputs if present.
171+
model_indices : Tensor
172+
Ground-truth indices or one-hot encoded labels for classification.
173+
summary_variables : Tensor, optional
174+
Input tensor(s) for the summary network (default is None). Required if a summary
175+
network is present.
176+
stage : str, optional
177+
Current training stage (e.g., "training", "validation", "inference"). Controls
178+
certain metric computations (default is "training").
179+
180+
Returns
181+
-------
182+
metrics : dict[str, Tensor]
183+
Dictionary containing the total loss under the key "loss", as well as all tracked
184+
metrics for the classifier and summary networks. Each metric key is prefixed to
185+
indicate its source.
186+
"""
187+
188+
summary_metrics, summary_outputs = self._compute_summary_metrics(summary_variables, stage=stage)
189+
190+
classifier_conditions = self._combine_conditions(classifier_conditions, summary_outputs, stage=stage)
151191

152-
cross_entropy = keras.losses.categorical_crossentropy(model_indices, logits, from_logits=True)
153-
cross_entropy = keras.ops.mean(cross_entropy)
192+
logits = self._compute_logits(classifier_conditions)
193+
cross_entropy = keras.ops.mean(keras.losses.categorical_crossentropy(model_indices, logits, from_logits=True))
154194

155195
classifier_metrics = {"loss": cross_entropy}
156196

157197
if stage != "training" and any(self.classifier_network.metrics):
158-
# compute sample-based metrics
159198
predictions = keras.ops.argmax(logits, axis=-1)
160199
classifier_metrics |= {
161200
metric.name: metric(model_indices, predictions) for metric in self.classifier_network.metrics
162201
}
163202

164-
loss = classifier_metrics.get("loss", keras.ops.zeros(())) + summary_metrics.get("loss", keras.ops.zeros(()))
203+
loss = classifier_metrics.get("loss") + summary_metrics.get("loss", keras.ops.zeros(()))
165204

166205
classifier_metrics = {f"{key}/classifier_{key}": value for key, value in classifier_metrics.items()}
167206
summary_metrics = {f"{key}/summary_{key}": value for key, value in summary_metrics.items()}
168207

169208
metrics = {"loss": loss} | classifier_metrics | summary_metrics
170-
171209
return metrics
172210

211+
def _compute_summary_metrics(self, summary_variables: Tensor, stage: str) -> tuple[dict, Tensor | None]:
    """Run the summary network (if configured) and split metrics from outputs.

    Returns a ``(metrics, outputs)`` pair; ``({}, None)`` when no summary
    network is present.
    """
    if self.summary_network is None:
        return {}, None
    if summary_variables is None:
        raise ValueError("Summary variables are required when a summary network is present.")

    variables = summary_variables
    if self.summary_variables_norm is not None:
        # Standardize inputs before they reach the summary network.
        variables = self.summary_variables_norm(variables, stage=stage)

    metrics = self.summary_network.compute_metrics(variables, stage=stage)
    # The network packs its outputs into the metrics dict; separate them out.
    outputs = metrics.pop("outputs")
    return metrics, outputs
224+
225+
def _combine_conditions(
    self, classifier_conditions: Tensor | None, summary_outputs: Tensor | None, stage
) -> Tensor:
    """Combine direct classifier conditions with summary-network outputs.

    Parameters
    ----------
    classifier_conditions : Tensor or None
        Direct conditioning variables; standardized here when a normalizer is configured.
    summary_outputs : Tensor or None
        Outputs of the summary network, appended along the last axis.
    stage : str
        Current stage (e.g., "training"), forwarded to the normalizer.

    Returns
    -------
    Tensor or None
        The combined conditions; None when both inputs are None.
    """
    if classifier_conditions is None:
        return summary_outputs

    # Explicit None check for consistency with the sibling normalizer checks
    # (relying on layer truthiness is fragile).
    if self.classifier_conditions_norm is not None:
        classifier_conditions = self.classifier_conditions_norm(classifier_conditions, stage=stage)

    if summary_outputs is None:
        return classifier_conditions
    return keras.ops.concatenate([classifier_conditions, summary_outputs], axis=-1)
238+
239+
def _compute_logits(self, classifier_conditions: Tensor) -> Tensor:
    """Map classifier-network features to one logit per candidate model."""
    features = self.classifier_network(classifier_conditions)
    return self.logits_projector(features)
244+
173245
def fit(
174246
self,
175247
*,
176-
adapter: Adapter = "auto",
248+
adapter: Adapter | str = "auto",
177249
dataset: keras.utils.PyDataset = None,
178250
simulator: ModelComparisonSimulator = None,
179251
simulators: Sequence[Simulator] = None,
@@ -182,11 +254,13 @@ def fit(
182254
"""
183255
Trains the approximator on the provided dataset or on-demand generated from the given (multi-model) simulator.
184256
If `dataset` is not provided, a dataset is built from the `simulator`.
185-
If `simulator` is not provided, it will be build from a list of `simulators`.
257+
If `simulator` is not provided, it will be built from a list of `simulators`.
186258
If the model has not been built, it will be built using a batch from the dataset.
187259
188260
Parameters
189261
----------
262+
adapter : Adapter or str, optional
263+
The data adapter that will make the simulated / real outputs neural-network friendly.
190264
dataset : keras.utils.PyDataset, optional
191265
A dataset containing simulations for training. If provided, `simulator` must be None.
192266
simulator : ModelComparisonSimulator, optional
@@ -315,6 +389,13 @@ def predict(
315389

316390
conditions = keras.tree.map_structure(keras.ops.convert_to_tensor, conditions)
317391

392+
# Optionally standardize conditions
393+
if "summary_variables" in conditions and self.summary_variables_norm:
394+
conditions["summary_variables"] = self.summary_variables_norm(conditions["summary_variables"])
395+
396+
if "classifier_conditions" in conditions and self.classifier_conditions_norm:
397+
conditions["classifier_conditions"] = self.classifier_conditions_norm(conditions["classifier_conditions"])
398+
318399
output = self._predict(**conditions, **kwargs)
319400

320401
if not logits:
@@ -346,35 +427,33 @@ def _predict(self, classifier_conditions: Tensor = None, summary_variables: Tens
346427

347428
return output
348429

349-
def summaries(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray:
    """
    Computes the learned summary statistics of given inputs.

    The `data` dictionary is preprocessed using the `adapter` and passed through the summary network.

    Parameters
    ----------
    data : Mapping[str, np.ndarray]
        Dictionary of simulated or real quantities as NumPy arrays.
    **kwargs : dict
        Additional keyword arguments for the adapter and the summary network.

    Returns
    -------
    summaries : np.ndarray
        The learned summary statistics.
    """
    if self.summary_network is None:
        raise ValueError("A summary network is required to compute summaries.")

    adapted = self.adapter(data, strict=False, stage="inference", **kwargs)
    summary_variables = adapted.get("summary_variables")
    if summary_variables is None:
        raise ValueError("Summary variables are required to compute summaries.")

    tensor = keras.ops.convert_to_tensor(summary_variables)
    outputs = self.summary_network(tensor, **filter_kwargs(kwargs, self.summary_network.call))
    return keras.ops.convert_to_numpy(outputs)

0 commit comments

Comments
 (0)