Skip to content

Commit 711cd7b

Browse files
committed
Add log_prob, estimate and predict wrapper methods; flexible building
1 parent 70bb230 commit 711cd7b

File tree

1 file changed

+28
-4
lines changed

1 file changed

+28
-4
lines changed

bayesflow/approximators/approximator_ensemble.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,15 @@ def __init__(self, approximators: dict[str, Approximator], **kwargs):
1818

1919
self.num_approximators = len(self.approximators)
2020

21+
def build_from_data(self, adapted_data: dict[str, any]):
22+
data_shapes = keras.tree.map_structure(keras.ops.shape, adapted_data)
23+
if len(data_shapes["inference_variables"]) > 2:
24+
# Remove the ensemble dimension from data_shapes. This expects data_shapes are the shapes of a
25+
# batch of training data, where the second axis corresponds to different approximators.
26+
data_shapes = {k: v[:1] + v[2:] for k, v in data_shapes.items()}
27+
self.build(data_shapes)
28+
2129
def build(self, data_shapes: dict[str, tuple[int] | dict[str, dict]]) -> None:
22-
# Remove the ensemble dimension from data_shapes. This expects data_shapes are the shapes of a
23-
# batch of training data, where the second axis corresponds to different approximators.
24-
data_shapes = {k: v[:1] + v[2:] for k, v in data_shapes.items()}
2530
for approximator in self.approximators.values():
2631
approximator.build(data_shapes)
2732

@@ -82,7 +87,7 @@ def sample(
8287
conditions: Mapping[str, np.ndarray],
8388
split: bool = False,
8489
**kwargs,
85-
) -> dict[str, np.ndarray]:
90+
) -> dict[str, dict[str, np.ndarray]]:
8691
samples = {}
8792
for approx_name, approximator in self.approximators.items():
8893
if self._has_obj_method(approximator, "sample"):
@@ -91,6 +96,25 @@ def sample(
9196
)
9297
return samples
9398

99+
def log_prob(self, data: Mapping[str, np.ndarray], **kwargs) -> dict[str, np.ndarray]:
100+
log_prob = {}
101+
for approx_name, approximator in self.approximators.items():
102+
if self._has_obj_method(approximator, "log_prob"):
103+
log_prob[approx_name] = approximator.log_prob(data=data, **kwargs)
104+
return log_prob
105+
106+
def estimate(
107+
self,
108+
conditions: Mapping[str, np.ndarray],
109+
split: bool = False,
110+
**kwargs,
111+
) -> dict[str, dict[str, dict[str, np.ndarray]]]:
112+
estimates = {}
113+
for approx_name, approximator in self.approximators.items():
114+
if self._has_obj_method(approximator, "estimate"):
115+
estimates[approx_name] = approximator.estimate(conditions=conditions, split=split, **kwargs)
116+
return estimates
117+
94118
def _has_obj_method(self, obj, name):
95119
method = getattr(obj, name, None)
96120
return callable(method)

0 commit comments

Comments
 (0)