Skip to content

Commit c0c136f

Browse files
committed
Update docs
1 parent 800ed2c commit c0c136f

File tree

3 files changed

+63
-21
lines changed

3 files changed

+63
-21
lines changed

bayesflow/diagnostics/plots/pairs_posterior.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,25 +39,27 @@ def pairs_posterior(
3939
Optional true parameter values that have generated the observed dataset.
4040
priors : np.ndarray of shape (n_prior_draws, n_params) or None, optional (default: None)
4141
Optional prior samples obtained from the prior.
42-
dataset_id: Optional ID of the dataset for whose posterior the pairs plot shall be generated.
43-
Should only be specified if estimates contains posterior draws from multiple datasets.
42+
dataset_id: Optional ID of the dataset for whose posterior the pair plots shall be generated.
43+
Should only be specified if estimates contain posterior draws from multiple datasets.
4444
variable_keys : list or None, optional, default: None
4545
Select keys from the dictionary provided in samples.
4646
By default, select all keys.
4747
variable_names : list or None, optional, default: None
4848
The parameter names for nice plot titles. Inferred if None
4949
height : float, optional, default: 3
50-
The height of the pairplot
50+
The height of the pair plots
5151
label_fontsize : int, optional, default: 14
5252
The font size of the x and y-label texts (parameter names)
5353
tick_fontsize : int, optional, default: 12
54-
The font size of the axis ticklabels
54+
The font size of the axis tick labels
5555
legend_fontsize : int, optional, default: 16
5656
The font size of the legend text
5757
post_color : str, optional, default: '#132a70'
5858
The color for the posterior histograms and KDEs
5959
prior_color : str, optional, default: gray
6060
The color for the optional prior histograms and KDEs
61+
target_color : str, optional, default: red
62+
The color for the optional true parameter lines and points
6163
alpha : float in [0, 1], optional, default: 0.9
6264
The opacity of the posterior plots
6365
@@ -83,7 +85,7 @@ def pairs_posterior(
8385
variable_names=variable_names,
8486
)
8587

86-
# dicts_to_arrays will keep dataset axis even if it is of length 1
88+
# dicts_to_arrays will keep the dataset axis even if it is of length 1
8789
# however, pairs plotting requires the dataset axis to be removed
8890
estimates_shape = plot_data["estimates"].shape
8991
if len(estimates_shape) == 3 and estimates_shape[0] == 1:

bayesflow/diagnostics/plots/pairs_samples.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -40,13 +40,18 @@ def pairs_samples(
4040
height : float, optional, default: 2.5
4141
The height of the pair plot
4242
color : str, optional, default : '#8f2727'
43-
The color of the plot
43+
The primary color of the plot
4444
alpha : float in [0, 1], optional, default: 0.9
4545
The opacity of the plot
46+
label : str, optional, default: "Posterior"
47+
Label for the dataset to plot
4648
label_fontsize : int, optional, default: 14
4749
The font size of the x and y-label texts (parameter names)
4850
tick_fontsize : int, optional, default: 12
49-
The font size of the axis ticklabels
51+
The font size of the axis tick labels
52+
show_single_legend : bool, optional, default: False
53+
Optional toggle for the user to choose whether a single dataset
54+
should also display legend
5055
**kwargs : dict, optional
5156
Additional keyword arguments passed to the sns.PairGrid constructor
5257
"""
@@ -85,12 +90,20 @@ def _pairs_samples(
8590
show_single_legend: bool = False,
8691
**kwargs,
8792
) -> sns.PairGrid:
88-
# internal version of pairs_samples creating the seaborn plot
93+
"""
94+
Internal version of pairs_samples creating the seaborn PairPlot
95+
for both a single dataset and multiple datasets.
8996
90-
# Parameters
91-
# ----------
92-
# plot_data : output of bayesflow.utils.dict_utils.dicts_to_arrays
93-
# other arguments are documented in pairs_samples
97+
Parameters
98+
----------
99+
plot_data : output of bayesflow.utils.dict_utils.dicts_to_arrays
100+
Formatted data to plot from the sample dataset
101+
color2 : str, optional, default: 'gray'
102+
Secondary color for the pair plots.
103+
This is the color used for the prior draws.
104+
105+
Other arguments are documented in pairs_samples
106+
"""
94107

95108
estimates_shape = plot_data["estimates"].shape
96109
if len(estimates_shape) != 2:
@@ -144,7 +157,7 @@ def _pairs_samples(
144157
common_norm=False,
145158
)
146159

147-
# add scatterplots to the upper diagonal
160+
# add scatter plots to the upper diagonal
148161
g.map_upper(sns.scatterplot, alpha=0.6, s=40, edgecolor="k", color=color, lw=0)
149162

150163
# add KDEs to the lower diagonal
@@ -168,7 +181,7 @@ def _pairs_samples(
168181
g.axes[i, j].tick_params(axis="both", which="major", labelsize=tick_fontsize)
169182
g.axes[i, j].tick_params(axis="both", which="minor", labelsize=tick_fontsize)
170183

171-
# adjust font size of labels
184+
# adjust the font size of labels
172185
# the labels themselves remain the same as before, i.e., variable_names
173186
g.axes[i, 0].set_ylabel(variable_names[i], fontsize=label_fontsize)
174187
g.axes[dim - 1, i].set_xlabel(variable_names[i], fontsize=label_fontsize)
@@ -196,7 +209,7 @@ def _pairs_samples(
196209

197210
def histplot_twinx(x, **kwargs):
198211
"""
199-
# create a histogram plot on a twin y axis
212+
# create a histogram plot on a twin y-axis
200213
# this ensures that the y scaling of the diagonal plots
201214
# in independent of the y scaling of the off-diagonal plots
202215

bayesflow/utils/plot_utils.py

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def prepare_plot_data(
6767
)
6868
check_estimates_prior_shapes(plot_data["estimates"], plot_data["targets"])
6969

70-
# store variable information at top level for easy access
70+
# store variable information at the top level for easy access
7171
variable_names = plot_data["estimates"].variable_names
7272
num_variables = len(variable_names)
7373
plot_data["variable_names"] = variable_names
@@ -249,7 +249,7 @@ def prettify_subplots(axes: np.ndarray, num_subplots: int, tick: bool = True, ti
249249

250250
def make_quadratic(ax: plt.Axes, x_data: np.ndarray, y_data: np.ndarray):
251251
"""
252-
Utility to make a subplots quadratic in order to avoid visual illusions
252+
Utility to make subplots quadratic to avoid visual illusions
253253
in, e.g., recovery plots.
254254
"""
255255

@@ -269,7 +269,7 @@ def make_quadratic(ax: plt.Axes, x_data: np.ndarray, y_data: np.ndarray):
269269

270270
def gradient_line(x, y, c=None, cmap: str = "viridis", lw: float = 2.0, alpha: float = 1, ax=None):
271271
"""
272-
Plot a 1D line with color gradient determined by `c` (same shape as x and y).
272+
Plot a 1D line with a color gradient determined by `c` (same shape as x and y).
273273
"""
274274
if ax is None:
275275
ax = plt.gca()
@@ -304,7 +304,7 @@ def gradient_legend(ax, label, cmap, norm, loc="upper right"):
304304
- loc: legend location (default 'upper right')
305305
"""
306306

307-
# Custom dummy handle to represent the gradient
307+
# Custom placeholder handle to represent the gradient
308308
class _GradientSwatch(Rectangle):
309309
pass
310310

@@ -361,8 +361,35 @@ def add_gradient_plot(
361361

362362

363363
def create_legends(
364-
g, plot_data, color, color2, label: str = "Posterior", show_single_legend: bool = False, fontsize: int = 14
364+
g,
365+
plot_data: dict,
366+
color: str | tuple = "#132a70",
367+
color2: str | tuple = "gray",
368+
label: str = "Posterior",
369+
show_single_legend: bool = False,
370+
legend_fontsize: int = 14,
365371
):
372+
"""
373+
Helper function to create legends for pairplots.
374+
375+
Parameters
376+
----------
377+
g : sns.PairGrid
378+
Seaborn object for the pair plots
379+
plot_data : output of bayesflow.utils.dict_utils.dicts_to_arrays
380+
Formatted data to plot from the sample dataset
381+
color : str, optional, default : '#8f2727'
382+
The primary color of the plot
383+
color2 : str, optional, default: 'gray'
384+
The secondary color for the plot
385+
label : str, optional, default: "Posterior"
386+
Label for the dataset to plot
387+
show_single_legend : bool, optional, default: False
388+
Optional toggle for the user to choose whether a single dataset
389+
should also display legend
390+
legend_fontsize : int, optional, default: 14
391+
fontsize for the legend
392+
"""
366393
handles = []
367394
labels = []
368395

@@ -391,5 +418,5 @@ def create_legends(
391418
loc="center left",
392419
bbox_to_anchor=(1, 0.5),
393420
frameon=False,
394-
fontsize=fontsize,
421+
fontsize=legend_fontsize,
395422
)

0 commit comments

Comments
 (0)