Skip to content

Commit 89007ce

Browse files
committed
Add kwargs for agg functions
1 parent a4432a2 commit 89007ce

File tree

2 files changed

+4
-5
lines changed

2 files changed

+4
-5
lines changed

bayesflow/diagnostics/plots/recovery.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ def recovery(
1414
variable_names: Sequence[str] = None,
1515
point_agg=np.median,
1616
uncertainty_agg=credible_interval,
17-
prob=0.95,
1817
add_corr: bool = True,
1918
figsize: Sequence[int] = None,
2019
label_fontsize: int = 16,
@@ -108,10 +107,10 @@ def recovery(
108107
targets = plot_data.pop("targets")
109108

110109
# Compute point estimates and uncertainties
111-
point_estimate = point_agg(estimates, axis=1)
110+
point_estimate = point_agg(estimates, axis=1, **kwargs.get("point_agg_kwargs", {}))
112111

113112
if uncertainty_agg is not None:
114-
u = uncertainty_agg(estimates, prob=prob, axis=1)
113+
u = uncertainty_agg(estimates, axis=1, **kwargs.get("uncertainty_agg_kwargs", {}))
115114
# compute lower and upper error
116115
u[0, :, :] = point_estimate - u[0, :, :]
117116
u[1, :, :] = u[1, :, :] - point_estimate

bayesflow/utils/numpy_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def softplus(x: np.ndarray, beta: float = 1.0, threshold: float = 20.0) -> np.nd
4545
return np.where(beta * x > threshold, x, np.log1p(exp_beta_x) / beta)
4646

4747

48-
def credible_interval(x: np.ndarray, prob: float = 0.95, axis: Sequence[int] = None, **kwargs) -> np.ndarray:
48+
def credible_interval(x: np.ndarray, prob: float = 0.95, axis: Sequence[int] | int = None, **kwargs) -> np.ndarray:
4949
"""
5050
Compute credible interval from samples using quantiles.
5151
@@ -69,7 +69,7 @@ def credible_interval(x: np.ndarray, prob: float = 0.95, axis: Sequence[int] = N
6969
--------
7070
>>> import numpy as np
7171
>>> # Simulate posterior samples
72-
>>> samples = np.random.normal(10, 1000, 3)
72+
>>> samples = np.random.normal(size=(10, 1000, 3))
7373
7474
>>> # Different coverage probabilities
7575
>>> credible_interval(samples, prob=0.5, axis=1) # 50% CI

0 commit comments

Comments
 (0)