Skip to content

Commit 0b6caea

Browse files
committed
make point_arg_kwargs and uncertainty_agg_kwargs explicit arguments
1 parent 89007ce commit 0b6caea

File tree

1 file changed

+15
-3
lines changed

1 file changed

+15
-3
lines changed

bayesflow/diagnostics/plots/recovery.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ def recovery(
1414
variable_names: Sequence[str] = None,
1515
point_agg=np.median,
1616
uncertainty_agg=credible_interval,
17+
point_agg_kwargs=None,
18+
uncertainty_agg_kwargs=None,
1719
add_corr: bool = True,
1820
figsize: Sequence[int] = None,
1921
label_fontsize: int = 16,
@@ -58,7 +60,11 @@ def recovery(
5860
The individual parameter names for nice plot titles. Inferred if None
5961
point_agg : function to compute point estimates. Default: median
6062
uncertainty_agg : function to compute uncertainty interval bounds.
61-
Default: credible_interval
63+
Default: credible_interval with coverage probability 95%.
64+
point_agg_kwargs : Optional dictionary of further arguments passed to point_agg.
65+
uncertainty_agg_kwargs : Optional dictionary of further arguments passed to uncertainty_agg.
66+
For example, to change the coverage probability of credible_interval to 50%,
67+
use uncertainty_agg_kwargs = dict(prob = 0.5)
6268
add_corr : boolean, default: True
6369
Should correlations between estimates and ground truth values be shown?
6470
figsize : tuple or None, optional, default : None
@@ -106,11 +112,17 @@ def recovery(
106112
estimates = plot_data.pop("estimates")
107113
targets = plot_data.pop("targets")
108114

115+
if point_agg_kwargs is None:
116+
point_agg_kwargs = {}
117+
118+
if uncertainty_agg_kwargs is None:
119+
uncertainty_agg_kwargs = {}
120+
109121
# Compute point estimates and uncertainties
110-
point_estimate = point_agg(estimates, axis=1, **kwargs.get("point_agg_kwargs", {}))
122+
point_estimate = point_agg(estimates, axis=1, **point_agg_kwargs)
111123

112124
if uncertainty_agg is not None:
113-
u = uncertainty_agg(estimates, axis=1, **kwargs.get("uncertainty_agg_kwargs", {}))
125+
u = uncertainty_agg(estimates, axis=1, **uncertainty_agg_kwargs)
114126
# compute lower and upper error
115127
u[0, :, :] = point_estimate - u[0, :, :]
116128
u[1, :, :] = u[1, :, :] - point_estimate

0 commit comments

Comments
 (0)