Skip to content

Commit 217ea69

Browse files
committed
add support for symmetric uncertainty measures
1 parent cc53741 commit 217ea69

File tree

1 file changed

+9
-5
lines changed

1 file changed

+9
-5
lines changed

bayesflow/diagnostics/plots/recovery.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,10 @@ def recovery(
6161
point_agg : callable, optional, default: median
6262
Function to compute point estimates.
6363
uncertainty_agg : callable, optional, default: credible_interval with coverage probability 95%
64-
Function to compute uncertainty interval bounds.
64+
Function to compute a measure of uncertainty. Can either be the lower and upper
65+
uncertainty bounds provided with the shape (2, num_datasets, num_params) or a
66+
scalar measure of uncertainty (e.g., the median absolute deviation) with shape
67+
(num_datasets, num_params).
6568
point_agg_kwargs : Optional dictionary of further arguments passed to point_agg.
6669
uncertainty_agg_kwargs : Optional dictionary of further arguments passed to uncertainty_agg.
6770
For example, to change the coverage probability of credible_interval to 50%,
@@ -121,9 +124,10 @@ def recovery(
121124

122125
if uncertainty_agg is not None:
123126
u = uncertainty_agg(estimates, axis=1, **uncertainty_agg_kwargs)
124-
# compute lower and upper error
125-
u[0, :, :] = point_estimate - u[0, :, :]
126-
u[1, :, :] = u[1, :, :] - point_estimate
127+
if u.ndim == 3:
128+
# compute lower and upper error
129+
u[0, :, :] = point_estimate - u[0, :, :]
130+
u[1, :, :] = u[1, :, :] - point_estimate
127131

128132
for i, ax in enumerate(plot_data["axes"].flat):
129133
if i >= plot_data["num_variables"]:
@@ -134,7 +138,7 @@ def recovery(
134138
_ = ax.errorbar(
135139
targets[:, i],
136140
point_estimate[:, i],
137-
yerr=u[:, :, i],
141+
yerr=u[..., i],
138142
fmt="o",
139143
alpha=0.5,
140144
color=color,

0 commit comments

Comments
 (0)