1- from collections .abc import Sequence , Mapping
1+ from collections .abc import Sequence , Mapping , Callable
22
33import matplotlib .pyplot as plt
44import numpy as np
55
6- from scipy .stats import median_abs_deviation
7-
86from bayesflow .utils import prepare_plot_data , prettify_subplots , make_quadratic , add_titles_and_labels , add_metric
7+ from bayesflow .utils .numpy_utils import credible_interval
98
109
1110def recovery (
1211 estimates : Mapping [str , np .ndarray ] | np .ndarray ,
1312 targets : Mapping [str , np .ndarray ] | np .ndarray ,
1413 variable_keys : Sequence [str ] = None ,
1514 variable_names : Sequence [str ] = None ,
16- point_agg = np .median ,
17- uncertainty_agg = median_abs_deviation ,
15+ point_agg : Callable = np .median ,
16+ uncertainty_agg : Callable = credible_interval ,
17+ point_agg_kwargs : dict = None ,
18+ uncertainty_agg_kwargs : dict = None ,
1819 add_corr : bool = True ,
1920 figsize : Sequence [int ] = None ,
2021 label_fontsize : int = 16 ,
@@ -57,8 +58,17 @@ def recovery(
5758 By default, select all keys.
5859 variable_names : list or None, optional, default: None
5960 The individual parameter names for nice plot titles. Inferred if None
60- point_agg : function to compute point estimates. Default: median
61- uncertainty_agg : function to compute uncertainty estimates. Default: MAD
61+ point_agg : callable, optional, default: median
62+ Function to compute point estimates.
63+ uncertainty_agg : callable, optional, default: credible_interval with coverage probability 95%
64+ Function to compute a measure of uncertainty. Can either be the lower and upper
65+ uncertainty bounds provided with the shape (2, num_datasets, num_params) or a
66+ scalar measure of uncertainty (e.g., the median absolute deviation) with shape
67+ (num_datasets, num_params).
68+ point_agg_kwargs : Optional dictionary of further arguments passed to point_agg.
69+ uncertainty_agg_kwargs : Optional dictionary of further arguments passed to uncertainty_agg.
70+ For example, to change the coverage probability of credible_interval to 50%,
71+ use uncertainty_agg_kwargs = dict(prob=0.5)
6272 add_corr : boolean, default: True
6373 Should correlations between estimates and ground truth values be shown?
6474 figsize : tuple or None, optional, default : None
@@ -106,11 +116,18 @@ def recovery(
106116 estimates = plot_data .pop ("estimates" )
107117 targets = plot_data .pop ("targets" )
108118
119+ point_agg_kwargs = point_agg_kwargs or {}
120+ uncertainty_agg_kwargs = uncertainty_agg_kwargs or {}
121+
109122 # Compute point estimates and uncertainties
110- point_estimate = point_agg (estimates , axis = 1 )
123+ point_estimate = point_agg (estimates , axis = 1 , ** point_agg_kwargs )
111124
112125 if uncertainty_agg is not None :
113- u = uncertainty_agg (estimates , axis = 1 )
126+ u = uncertainty_agg (estimates , axis = 1 , ** uncertainty_agg_kwargs )
127+ if u .ndim == 3 :
128+ # compute lower and upper error
129+ u [0 , :, :] = point_estimate - u [0 , :, :]
130+ u [1 , :, :] = u [1 , :, :] - point_estimate
114131
115132 for i , ax in enumerate (plot_data ["axes" ].flat ):
116133 if i >= plot_data ["num_variables" ]:
@@ -121,7 +138,7 @@ def recovery(
121138 _ = ax .errorbar (
122139 targets [:, i ],
123140 point_estimate [:, i ],
124- yerr = u [: , i ],
141+ yerr = u [... , i ],
125142 fmt = "o" ,
126143 alpha = 0.5 ,
127144 color = color ,
0 commit comments