@@ -61,7 +61,10 @@ def recovery(
6161 point_agg : callable, optional, default: median
6262 Function to compute point estimates.
6363 uncertainty_agg : callable, optional, default: credible_interval with coverage probability 95%
64- Function to compute uncertainty interval bounds.
64+ Function to compute a measure of uncertainty. Can either be the lower and upper
65+ uncertainty bounds provided with the shape (2, num_datasets, num_params) or a
66+ scalar measure of uncertainty (e.g., the median absolute deviation) with shape
67+ (num_datasets, num_params).
6568 point_agg_kwargs : Optional dictionary of further arguments passed to point_agg.
6669 uncertainty_agg_kwargs : Optional dictionary of further arguments passed to uncertainty_agg.
6770 For example, to change the coverage probability of credible_interval to 50%,
@@ -121,9 +124,10 @@ def recovery(
121124
122125 if uncertainty_agg is not None :
123126 u = uncertainty_agg (estimates , axis = 1 , ** uncertainty_agg_kwargs )
124- # compute lower and upper error
125- u [0 , :, :] = point_estimate - u [0 , :, :]
126- u [1 , :, :] = u [1 , :, :] - point_estimate
127+ if u .ndim == 3 :
128+ # compute lower and upper error
129+ u [0 , :, :] = point_estimate - u [0 , :, :]
130+ u [1 , :, :] = u [1 , :, :] - point_estimate
127131
128132 for i , ax in enumerate (plot_data ["axes" ].flat ):
129133 if i >= plot_data ["num_variables" ]:
@@ -134,7 +138,7 @@ def recovery(
134138 _ = ax .errorbar (
135139 targets [:, i ],
136140 point_estimate [:, i ],
137- yerr = u [:, : , i ],
141+ yerr = u [... , i ],
138142 fmt = "o" ,
139143 alpha = 0.5 ,
140144 color = color ,
0 commit comments