Skip to content

Commit 85d8149

Browse files
committed
take batch size into account when aggregating metrics
1 parent 777507e commit 85d8149

File tree

7 files changed

+47
-18
lines changed

7 files changed

+47
-18
lines changed

bayesflow/approximators/approximator.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,3 +137,26 @@ def fit(self, *, dataset: keras.utils.PyDataset = None, simulator: Simulator = N
137137
self.build_from_data(mock_data)
138138

139139
return super().fit(dataset=dataset, **kwargs)
140+
141+
def _batch_size_from_data(self, data: any):
142+
"""Obtain the batch size from a batch of data.
143+
144+
To properly weight the metrics for batches of different sizes, the batch size of a given batch of data is
145+
required. As the data structure differs between approximators, each approximator has to specify this method.
146+
147+
Parameters
148+
----------
149+
data : any
150+
The data that are passed to `compute_metrics` as keyword arguments.
151+
152+
Returns
153+
-------
154+
batch_size : int
155+
The batch size of the given data.
156+
"""
157+
raise NotImplementedError(
158+
"Correct calculation of the metrics requires obtaining the batch size from the supplied data "
159+
"for proper weighting of metrics for batches with different sizes. Please implement the "
160+
"_batch_size_from_data method for your approximator. For a given batch of data, it should "
161+
"return the corresponding batch size."
162+
)

bayesflow/approximators/backend_approximators/jax_approximator.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def stateless_test_step(self, state: tuple, data: dict[str, any]) -> (dict[str,
5555
)
5656
metrics, non_trainable_variables, metrics_variables = aux
5757

58-
metrics_variables = self._update_metrics(loss, metrics_variables)
58+
metrics_variables = self._update_metrics(loss, metrics_variables, self._batch_size_from_data(data))
5959

6060
state = trainable_variables, non_trainable_variables, metrics_variables
6161
return metrics, state
@@ -74,7 +74,7 @@ def stateless_train_step(self, state: tuple, data: dict[str, any]) -> (dict[str,
7474
optimizer_variables, grads, trainable_variables
7575
)
7676

77-
metrics_variables = self._update_metrics(loss, metrics_variables)
77+
metrics_variables = self._update_metrics(loss, metrics_variables, self._batch_size_from_data(data))
7878

7979
state = trainable_variables, non_trainable_variables, optimizer_variables, metrics_variables
8080
return metrics, state
@@ -85,11 +85,11 @@ def test_step(self, *args, **kwargs):
8585
def train_step(self, *args, **kwargs):
8686
return self.stateless_train_step(*args, **kwargs)
8787

88-
def _update_metrics(self, loss: jax.Array, metrics_variables: any) -> any:
88+
def _update_metrics(self, loss: jax.Array, metrics_variables: any, sample_weight: any = None) -> any:
8989
# update the loss progress bar, and possibly metrics variables along with it
9090
state_mapping = list(zip(self.metrics_variables, metrics_variables))
9191
with keras.StatelessScope(state_mapping) as scope:
92-
self._loss_tracker.update_state(loss)
92+
self._loss_tracker.update_state(loss, sample_weight=sample_weight)
9393

9494
metrics_variables = [scope.get_current_value(v) for v in self.metrics_variables]
9595

bayesflow/approximators/backend_approximators/numpy_approximator.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,17 @@ def compute_metrics(self, *args, **kwargs) -> dict[str, np.ndarray]:
1313
def test_step(self, data: dict[str, any]) -> dict[str, np.ndarray]:
1414
kwargs = filter_kwargs(data | {"stage": "validation"}, self.compute_metrics)
1515
metrics = self.compute_metrics(**kwargs)
16-
self._update_metrics(metrics)
16+
self._update_metrics(metrics, self._batch_size_from_data(data))
1717
return metrics
1818

1919
def train_step(self, data: dict[str, any]) -> dict[str, np.ndarray]:
2020
raise NotImplementedError("Numpy backend does not support training.")
2121

22-
def _update_metrics(self, metrics):
22+
def _update_metrics(self, metrics, sample_weight=None):
2323
for name, value in metrics.items():
2424
try:
2525
metric_index = self.metrics_names.index(name)
26-
self.metrics[metric_index].update_state(value)
26+
self.metrics[metric_index].update_state(value, sample_weight=sample_weight)
2727
except ValueError:
2828
self._metrics.append(keras.metrics.Mean(name=name))
29-
self._metrics[-1].update_state(value)
29+
self._metrics[-1].update_state(value, sample_weight=sample_weight)

bayesflow/approximators/backend_approximators/tensorflow_approximator.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def compute_metrics(self, *args, **kwargs) -> dict[str, tf.Tensor]:
1313
def test_step(self, data: dict[str, any]) -> dict[str, tf.Tensor]:
1414
kwargs = filter_kwargs(data | {"stage": "validation"}, self.compute_metrics)
1515
metrics = self.compute_metrics(**kwargs)
16-
self._update_metrics(metrics)
16+
self._update_metrics(metrics, self._batch_size_from_data(data))
1717
return metrics
1818

1919
def train_step(self, data: dict[str, any]) -> dict[str, tf.Tensor]:
@@ -26,14 +26,14 @@ def train_step(self, data: dict[str, any]) -> dict[str, tf.Tensor]:
2626
grads = tape.gradient(loss, self.trainable_variables)
2727
self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
2828

29-
self._update_metrics(metrics)
29+
self._update_metrics(metrics, self._batch_size_from_data(data))
3030
return metrics
3131

32-
def _update_metrics(self, metrics):
32+
def _update_metrics(self, metrics, sample_weight=None):
3333
for name, value in metrics.items():
3434
try:
3535
metric_index = self.metrics_names.index(name)
36-
self.metrics[metric_index].update_state(value)
36+
self.metrics[metric_index].update_state(value, sample_weight=sample_weight)
3737
except ValueError:
3838
self._metrics.append(keras.metrics.Mean(name=name))
39-
self._metrics[-1].update_state(value)
39+
self._metrics[-1].update_state(value, sample_weight=sample_weight)

bayesflow/approximators/backend_approximators/torch_approximator.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def compute_metrics(self, *args, **kwargs) -> dict[str, torch.Tensor]:
1313
def test_step(self, data: dict[str, any]) -> dict[str, torch.Tensor]:
1414
kwargs = filter_kwargs(data | {"stage": "validation"}, self.compute_metrics)
1515
metrics = self.compute_metrics(**kwargs)
16-
self._update_metrics(metrics)
16+
self._update_metrics(metrics, self._batch_size_from_data(data))
1717
return metrics
1818

1919
def train_step(self, data: dict[str, any]) -> dict[str, torch.Tensor]:
@@ -34,14 +34,14 @@ def train_step(self, data: dict[str, any]) -> dict[str, torch.Tensor]:
3434
with torch.no_grad():
3535
self.optimizer.apply(gradients, trainable_weights)
3636

37-
self._update_metrics(metrics)
37+
self._update_metrics(metrics, self._batch_size_from_data(data))
3838
return metrics
3939

40-
def _update_metrics(self, metrics):
40+
def _update_metrics(self, metrics, sample_weight=None):
4141
for name, value in metrics.items():
4242
try:
4343
metric_index = self.metrics_names.index(name)
44-
self.metrics[metric_index].update_state(value)
44+
self.metrics[metric_index].update_state(value, sample_weight=sample_weight)
4545
except ValueError:
4646
self._metrics.append(keras.metrics.Mean(name=name))
47-
self._metrics[-1].update_state(value)
47+
self._metrics[-1].update_state(value, sample_weight=sample_weight)

bayesflow/approximators/continuous_approximator.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,3 +491,6 @@ def _log_prob(
491491
conditions=inference_conditions,
492492
**filter_kwargs(kwargs, self.inference_network.log_prob),
493493
)
494+
495+
def _batch_size_from_data(self, data: Mapping[str, any]):
496+
return keras.ops.shape(data["inference_variables"])[0]

bayesflow/approximators/model_comparison_approximator.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,3 +378,6 @@ def summaries(self, data: Mapping[str, np.ndarray], **kwargs):
378378
summary_variables = keras.ops.convert_to_tensor(data_adapted["summary_variables"])
379379
summaries = self.summary_network(summary_variables, **filter_kwargs(kwargs, self.summary_network.call))
380380
return summaries
381+
382+
def _batch_size_from_data(self, data: Mapping[str, any]):
383+
return keras.ops.shape(data["model_indices"])[0]

0 commit comments

Comments
 (0)