Skip to content

Commit cc53741

Browse files
committed
adapt docs, minor stylistic changes
1 parent 0b6caea commit cc53741

File tree

1 file changed

+12
-14
lines changed

1 file changed

+12
-14
lines changed

bayesflow/diagnostics/plots/recovery.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from collections.abc import Sequence, Mapping
1+
from collections.abc import Sequence, Mapping, Callable
22

33
import matplotlib.pyplot as plt
44
import numpy as np
@@ -12,10 +12,10 @@ def recovery(
1212
targets: Mapping[str, np.ndarray] | np.ndarray,
1313
variable_keys: Sequence[str] = None,
1414
variable_names: Sequence[str] = None,
15-
point_agg=np.median,
16-
uncertainty_agg=credible_interval,
17-
point_agg_kwargs=None,
18-
uncertainty_agg_kwargs=None,
15+
point_agg: Callable = np.median,
16+
uncertainty_agg: Callable = credible_interval,
17+
point_agg_kwargs: dict = None,
18+
uncertainty_agg_kwargs: dict = None,
1919
add_corr: bool = True,
2020
figsize: Sequence[int] = None,
2121
label_fontsize: int = 16,
@@ -58,13 +58,14 @@ def recovery(
5858
By default, select all keys.
5959
variable_names : list or None, optional, default: None
6060
The individual parameter names for nice plot titles. Inferred if None
61-
point_agg : function to compute point estimates. Default: median
62-
uncertainty_agg : function to compute uncertainty interval bounds.
63-
Default: credible_interval with coverage probability 95%.
61+
point_agg : callable, optional, default: median
62+
Function to compute point estimates.
63+
uncertainty_agg : callable, optional, default: credible_interval with coverage probability 95%
64+
Function to compute uncertainty interval bounds.
6465
point_agg_kwargs : Optional dictionary of further arguments passed to point_agg.
6566
uncertainty_agg_kwargs : Optional dictionary of further arguments passed to uncertainty_agg.
6667
For example, to change the coverage probability of credible_interval to 50%,
67-
use uncertainty_agg_kwargs = dict(prob = 0.5)
68+
use uncertainty_agg_kwargs = dict(prob=0.5)
6869
add_corr : boolean, default: True
6970
Should correlations between estimates and ground truth values be shown?
7071
figsize : tuple or None, optional, default : None
@@ -112,11 +113,8 @@ def recovery(
112113
estimates = plot_data.pop("estimates")
113114
targets = plot_data.pop("targets")
114115

115-
if point_agg_kwargs is None:
116-
point_agg_kwargs = {}
117-
118-
if uncertainty_agg_kwargs is None:
119-
uncertainty_agg_kwargs = {}
116+
point_agg_kwargs = point_agg_kwargs or {}
117+
uncertainty_agg_kwargs = uncertainty_agg_kwargs or {}
120118

121119
# Compute point estimates and uncertainties
122120
point_estimate = point_agg(estimates, axis=1, **point_agg_kwargs)

0 commit comments

Comments
 (0)