1- from collections .abc import Sequence , Mapping
1+ from collections .abc import Sequence , Mapping , Callable
22
33import matplotlib .pyplot as plt
44import numpy as np
@@ -12,10 +12,10 @@ def recovery(
1212 targets : Mapping [str , np .ndarray ] | np .ndarray ,
1313 variable_keys : Sequence [str ] = None ,
1414 variable_names : Sequence [str ] = None ,
15- point_agg = np .median ,
16- uncertainty_agg = credible_interval ,
17- point_agg_kwargs = None ,
18- uncertainty_agg_kwargs = None ,
15+ point_agg : Callable = np .median ,
16+ uncertainty_agg : Callable = credible_interval ,
17+ point_agg_kwargs : dict = None ,
18+ uncertainty_agg_kwargs : dict = None ,
1919 add_corr : bool = True ,
2020 figsize : Sequence [int ] = None ,
2121 label_fontsize : int = 16 ,
@@ -58,13 +58,14 @@ def recovery(
5858 By default, select all keys.
5959 variable_names : list or None, optional, default: None
6060 The individual parameter names for nice plot titles. Inferred if None
61- point_agg : function to compute point estimates. Default: median
62- uncertainty_agg : function to compute uncertainty interval bounds.
63- Default: credible_interval with coverage probability 95%.
61+ point_agg : callable, optional, default: median
62+ Function to compute point estimates.
63+ uncertainty_agg : callable, optional, default: credible_interval with coverage probability 95%
64+ Function to compute uncertainty interval bounds.
6465 point_agg_kwargs : Optional dictionary of further arguments passed to point_agg.
6566 uncertainty_agg_kwargs : Optional dictionary of further arguments passed to uncertainty_agg.
6667 For example, to change the coverage probability of credible_interval to 50%,
67- use uncertainty_agg_kwargs = dict(prob = 0.5)
68+ use uncertainty_agg_kwargs = dict(prob= 0.5)
6869 add_corr : boolean, default: True
6970 Should correlations between estimates and ground truth values be shown?
7071 figsize : tuple or None, optional, default : None
@@ -112,11 +113,8 @@ def recovery(
112113 estimates = plot_data .pop ("estimates" )
113114 targets = plot_data .pop ("targets" )
114115
115- if point_agg_kwargs is None :
116- point_agg_kwargs = {}
117-
118- if uncertainty_agg_kwargs is None :
119- uncertainty_agg_kwargs = {}
116+ point_agg_kwargs = point_agg_kwargs or {}
117+ uncertainty_agg_kwargs = uncertainty_agg_kwargs or {}
120118
121119 # Compute point estimates and uncertainties
122120 point_estimate = point_agg (estimates , axis = 1 , ** point_agg_kwargs )
0 commit comments