Skip to content

Commit 4d3130b

Browse files
committed
Add predict wrapper method
1 parent 71d8382 commit 4d3130b

File tree

1 file changed

+14
-0
lines changed

1 file changed

+14
-0
lines changed

bayesflow/approximators/approximator_ensemble.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99

1010
from .approximator import Approximator
11+
from .model_comparison_approximator import ModelComparisonApproximator
1112

1213

1314
class ApproximatorEnsemble(Approximator):
@@ -115,6 +116,19 @@ def estimate(
115116
estimates[approx_name] = approximator.estimate(conditions=conditions, split=split, **kwargs)
116117
return estimates
117118

119+
def predict(
120+
self,
121+
*,
122+
conditions: Mapping[str, np.ndarray],
123+
probs: bool = True,
124+
**kwargs,
125+
) -> dict[str, np.ndarray]:
126+
predictions = {}
127+
for approx_name, approximator in self.approximators.items():
128+
if isinstance(approximator, ModelComparisonApproximator):
129+
predictions[approx_name] = approximator.predict(conditions=conditions, probs=probs, **kwargs)
130+
return predictions
131+
118132
def _has_obj_method(self, obj, name):
119133
method = getattr(obj, name, None)
120134
return callable(method)

0 commit comments

Comments
 (0)