Skip to content

Commit d8fa571

Browse files
committed
Slice training batch, giving every ensemble member independent training signal
1 parent 211b20f commit d8fa571

File tree

1 file changed

+26
-11
lines changed

1 file changed

+26
-11
lines changed

bayesflow/approximators/approximator_ensemble.py

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ def __init__(self, approximators: dict[str, Approximator], **kwargs):
1919
self.num_approximators = len(self.approximators)
2020

2121
def build(self, data_shapes: dict[str, tuple[int] | dict[str, dict]]) -> None:
22+
# Remove the ensemble dimension from data_shapes. This expects data_shapes are the shapes of a
23+
# batch of training data, where the second axis corresponds to different approximators.
24+
data_shapes = {k: v[:1] + v[2:] for k, v in data_shapes.items()}
2225
for approximator in self.approximators.values():
2326
approximator.build(data_shapes)
2427

@@ -30,19 +33,32 @@ def compute_metrics(
3033
sample_weight: Tensor = None,
3134
stage: str = "training",
3235
) -> dict[str, dict[str, Tensor]]:
36+
# Prepare empty dict for metrics
3337
metrics = {}
34-
for approx_name, approximator in self.approximators.items():
35-
# TODO: actually do the slicing
36-
inference_variables_slice = inference_variables
37-
inference_conditions_slice = inference_conditions
38-
summary_variables_slice = summary_variables
39-
sample_weight_slice = sample_weight
38+
39+
# Define the variable slices as None (default) or respective input
40+
_inference_variables = inference_variables
41+
_inference_conditions = inference_conditions
42+
_summary_variables = summary_variables
43+
_sample_weight = sample_weight
44+
45+
for i, (approx_name, approximator) in enumerate(self.approximators.items()):
46+
# During training each approximator receives its own separate slice
47+
if stage == "training":
48+
# Pick out the correct slice for each ensemble member
49+
_inference_variables = inference_variables[:, i]
50+
if inference_conditions is not None:
51+
_inference_conditions = inference_conditions[:, i]
52+
if summary_variables is not None:
53+
_summary_variables = summary_variables[:, i]
54+
if sample_weight is not None:
55+
_sample_weight = sample_weight[:, i]
4056

4157
metrics[approx_name] = approximator.compute_metrics(
42-
inference_variables=inference_variables_slice,
43-
inference_conditions=inference_conditions_slice,
44-
summary_variables=summary_variables_slice,
45-
sample_weight=sample_weight_slice,
58+
inference_variables=_inference_variables,
59+
inference_conditions=_inference_conditions,
60+
summary_variables=_summary_variables,
61+
sample_weight=_sample_weight,
4662
stage=stage,
4763
)
4864

@@ -51,7 +67,6 @@ def compute_metrics(
5167
for approx_name in metrics.keys():
5268
for metric_key, value in metrics[approx_name].items():
5369
joint_metrics[f"{approx_name}/{metric_key}"] = value
54-
5570
metrics = joint_metrics
5671

5772
# Sum over losses

0 commit comments

Comments
 (0)