Skip to content

Commit 211b20f

Browse files
committed
Draft of ApproximatorEnsemble
1 parent 2a19d32 commit 211b20f

File tree

2 files changed

+90
-0
lines changed

2 files changed

+90
-0
lines changed

bayesflow/approximators/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from .point_approximator import PointApproximator
99
from .model_comparison_approximator import ModelComparisonApproximator
1010

11+
from .approximator_ensemble import ApproximatorEnsemble
12+
1113
from ..utils._docs import _add_imports_to_all
1214

1315
_add_imports_to_all(include_modules=[])
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
from collections.abc import Mapping
from typing import Any

import keras
import numpy as np

from bayesflow.types import Tensor

from .approximator import Approximator
11+
12+
13+
class ApproximatorEnsemble(Approximator):
    """Ensemble of named approximators trained and evaluated jointly.

    Dispatches ``build``, ``compute_metrics`` and ``sample`` to every member
    approximator and aggregates the results under each member's name.
    """

    def __init__(self, approximators: dict[str, Approximator], **kwargs):
        """
        Parameters
        ----------
        approximators : dict[str, Approximator]
            Mapping from a unique member name to the approximator stored
            under that name. The names are used as key prefixes in metrics
            and as keys of the sampling output.
        **kwargs
            Forwarded to the ``Approximator`` base constructor.
        """
        super().__init__(**kwargs)

        self.approximators = approximators
        self.num_approximators = len(self.approximators)

    def build(self, data_shapes: dict[str, tuple[int, ...] | dict[str, dict]]) -> None:
        """Build every member approximator from the shared data shapes."""
        for approximator in self.approximators.values():
            approximator.build(data_shapes)

    def compute_metrics(
        self,
        inference_variables: Tensor,
        inference_conditions: Tensor = None,
        summary_variables: Tensor = None,
        sample_weight: Tensor = None,
        stage: str = "training",
    ) -> dict[str, Tensor]:
        """Compute the metrics of all members and merge them into one flat dict.

        Each member metric ``m`` is reported under the key ``"<name>/m"``.
        The ensemble's total ``"loss"`` is the sum of all flattened entries
        whose key contains ``"loss"``.

        Note: the return annotation was corrected to ``dict[str, Tensor]`` —
        the per-member dicts are flattened before returning.
        """
        per_member: dict[str, dict[str, Tensor]] = {}
        for approx_name, approximator in self.approximators.items():
            # TODO: actually do the slicing
            inference_variables_slice = inference_variables
            inference_conditions_slice = inference_conditions
            summary_variables_slice = summary_variables
            sample_weight_slice = sample_weight

            per_member[approx_name] = approximator.compute_metrics(
                inference_variables=inference_variables_slice,
                inference_conditions=inference_conditions_slice,
                summary_variables=summary_variables_slice,
                sample_weight=sample_weight_slice,
                stage=stage,
            )

        # Flatten {member: {metric: value}} into {"member/metric": value}.
        metrics = {
            f"{approx_name}/{metric_key}": value
            for approx_name, member_metrics in per_member.items()
            for metric_key, value in member_metrics.items()
        }

        # Total loss: sum of every member loss entry. Stack explicitly rather
        # than passing a Python list to keras.ops.sum.
        # NOTE(review): the substring match also picks up any other metric
        # whose name contains "loss" (e.g. "<name>/kl_loss") — confirm this
        # is intended once members report more than a single "loss" entry.
        losses = [value for key, value in metrics.items() if "loss" in key]
        metrics["loss"] = keras.ops.sum(keras.ops.stack(losses))

        return metrics

    def sample(
        self,
        *,
        num_samples: int,
        conditions: Mapping[str, np.ndarray],
        split: bool = False,
        **kwargs,
    ) -> dict[str, np.ndarray]:
        """Draw samples from every member that supports sampling.

        Members without a callable ``sample`` attribute are silently skipped,
        so the returned dict may contain fewer entries than ``approximators``.

        Parameters
        ----------
        num_samples : int
            Number of samples requested from each member.
        conditions : Mapping[str, np.ndarray]
            Conditioning data forwarded unchanged to every member.
        split : bool, optional
            Forwarded to each member's ``sample`` method. Default is False.
        **kwargs
            Additional keyword arguments forwarded to each member.

        Returns
        -------
        dict[str, np.ndarray]
            Per-member sampling results keyed by member name.
        """
        samples = {}
        for approx_name, approximator in self.approximators.items():
            if self._has_obj_method(approximator, "sample"):
                samples[approx_name] = approximator.sample(
                    num_samples=num_samples, conditions=conditions, split=split, **kwargs
                )
        return samples

    @staticmethod
    def _has_obj_method(obj, name: str) -> bool:
        """Return True if ``obj`` exposes a callable attribute ``name``."""
        return callable(getattr(obj, name, None))

    def _batch_size_from_data(self, data: Mapping[str, Any]) -> int:
        """
        Fetches the current batch size from an input dictionary. Can only be used during training when
        inference variables are present.
        """
        return keras.ops.shape(data["inference_variables"])[0]

0 commit comments

Comments
 (0)