Skip to content

Commit dcc1dfd

Browse files
Merge pull request #109 from stefanradev93/lukas
add n_row and n_col argument where applicable
2 parents 3ade306 + d22a996 commit dcc1dfd

File tree

1 file changed

+53
-10
lines changed

1 file changed

+53
-10
lines changed

bayesflow/diagnostics.py

Lines changed: 53 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,10 @@ def plot_recovery(
9898
A flag for adding R^2 between true and estimates to the plot
9999
color : str, optional, default: '#8f2727'
100100
The color for the true vs. estimated scatter points and error bars
101+
n_row : int, optional, default: None
102+
The number of rows for the subplots. Dynamically determined if None.
103+
n_col : int, optional, default: None
104+
The number of columns for the subplots. Dynamically determined if None.
101105
xlabel : str, optional, default: 'Ground truth'
102106
The label on the x-axis of the plot
103107
ylabel : str, optional, default: 'Estimated'
@@ -232,7 +236,7 @@ def plot_z_score_contraction(
232236
tick_fontsize=12,
233237
color="#8f2727",
234238
n_col=None,
235-
n_row=None,
239+
n_row=None
236240
):
237241
"""Implements a graphical check for global model sensitivity by plotting the posterior
238242
z-score over the posterior contraction for each set of posterior samples in ``post_samples``
@@ -279,6 +283,10 @@ def plot_z_score_contraction(
279283
The font size of the axis ticklabels
280284
color : str, optional, default: '#8f2727'
281285
The color for the true vs. estimated scatter points and error bars
286+
n_row : int, optional, default: None
287+
The number of rows for the subplots. Dynamically determined if None.
288+
n_col : int, optional, default: None
289+
The number of columns for the subplots. Dynamically determined if None.
282290
283291
Returns
284292
-------
@@ -379,6 +387,8 @@ def plot_sbc_ecdf(
379387
tick_fontsize=12,
380388
rank_ecdf_color="#a34f4f",
381389
fill_color="grey",
390+
n_row=None,
391+
n_col=None,
382392
**kwargs,
383393
):
384394
"""Creates the empirical CDFs for each marginal rank distribution and plots it against
@@ -419,6 +429,10 @@ def plot_sbc_ecdf(
419429
The color to use for the rank ECDFs
420430
fill_color : str, optional, default: 'grey'
421431
The color of the fill arguments.
432+
n_row : int, optional, default: None
433+
The number of rows for the subplots. Dynamically determined if None.
434+
n_col : int, optional, default: None
435+
The number of columns for the subplots. Dynamically determined if None.
422436
**kwargs : dict, optional, default: {}
423437
Keyword arguments can be passed to control the behavior of ECDF simultaneous band computation
424438
through the ``ecdf_bands_kwargs`` dictionary. See `simultaneous_ecdf_bands` for keyword arguments
@@ -447,9 +461,14 @@ def plot_sbc_ecdf(
447461
n_row, n_col = 1, 1
448462
f, ax = plt.subplots(1, 1, figsize=fig_size)
449463
else:
450-
# Determine n_subplots dynamically
451-
n_row = int(np.ceil(n_params / 6))
452-
n_col = int(np.ceil(n_params / n_row))
464+
# Determine number of rows and columns for subplots based on inputs
465+
if n_row is None and n_col is None:
466+
n_row = int(np.ceil(n_params / 6))
467+
n_col = int(np.ceil(n_params / n_row))
468+
elif n_row is None and n_col is not None:
469+
n_row = int(np.ceil(n_params / n_col))
470+
elif n_row is not None and n_col is None:
471+
n_col = int(np.ceil(n_params / n_row))
453472

454473
# Determine fig_size dynamically, if None
455474
if fig_size is None:
@@ -543,6 +562,8 @@ def plot_sbc_histograms(
543562
title_fontsize=18,
544563
tick_fontsize=12,
545564
hist_color="#a34f4f",
565+
n_row=None,
566+
n_col=None
546567
):
547568
"""Creates and plots publication-ready histograms of rank statistics for simulation-based calibration
548569
(SBC) checks according to [1].
@@ -576,6 +597,10 @@ def plot_sbc_histograms(
576597
The font size of the axis ticklabels
577598
hist_color : str, optional, default '#a34f4f'
578599
The color to use for the histogram body
600+
n_row : int, optional, default: None
601+
The number of rows for the subplots. Dynamically determined if None.
602+
n_col : int, optional, default: None
603+
The number of columns for the subplots. Dynamically determined if None.
579604
580605
Returns
581606
-------
@@ -615,9 +640,14 @@ def plot_sbc_histograms(
615640
if param_names is None:
616641
param_names = [f"$\\theta_{{{i}}}$" for i in range(1, n_params + 1)]
617642

618-
# Determine n_subplots dynamically
619-
n_row = int(np.ceil(n_params / 6))
620-
n_col = int(np.ceil(n_params / n_row))
643+
# Determine number of rows and columns for subplots based on inputs
644+
if n_row is None and n_col is None:
645+
n_row = int(np.ceil(n_params / 6))
646+
n_col = int(np.ceil(n_params / n_row))
647+
elif n_row is None and n_col is not None:
648+
n_row = int(np.ceil(n_params / n_col))
649+
elif n_row is not None and n_col is None:
650+
n_col = int(np.ceil(n_params / n_row))
621651

622652
# Initialize figure
623653
if fig_size is None:
@@ -1026,6 +1056,8 @@ def plot_calibration_curves(
10261056
epsilon=0.02,
10271057
fig_size=None,
10281058
color="#8f2727",
1059+
n_row=None,
1060+
n_col=None
10291061
):
10301062
"""Plots the calibration curves, the ECEs and the marginal histograms of predicted posterior model probabilities
10311063
for a model comparison problem. The marginal histograms inform about the fraction of predictions in each bin.
@@ -1055,6 +1087,10 @@ def plot_calibration_curves(
10551087
The figure size passed to the ``matplotlib`` constructor. Inferred if ``None``
10561088
color : str, optional, default: '#8f2727'
10571089
The color of the calibration curves
1090+
n_row : int, optional, default: None
1091+
The number of rows for the subplots. Dynamically determined if None.
1092+
n_col : int, optional, default: None
1093+
The number of columns for the subplots. Dynamically determined if None.
10581094
10591095
Returns
10601096
-------
@@ -1065,9 +1101,15 @@ def plot_calibration_curves(
10651101
if model_names is None:
10661102
model_names = [rf"$M_{{{m}}}$" for m in range(1, num_models + 1)]
10671103

1068-
# Determine n_subplots dynamically
1069-
n_row = int(np.ceil(num_models / 6))
1070-
n_col = int(np.ceil(num_models / n_row))
1104+
# Determine number of rows and columns for subplots based on inputs
1105+
if n_row is None and n_col is None:
1106+
n_row = int(np.ceil(num_models / 6))
1107+
n_col = int(np.ceil(num_models / n_row))
1108+
elif n_row is None and n_col is not None:
1109+
n_row = int(np.ceil(num_models / n_col))
1110+
elif n_row is not None and n_col is None:
1111+
n_col = int(np.ceil(num_models / n_row))
1112+
10711113

10721114
# Compute calibration
10731115
cal_errs, probs_true, probs_pred = expected_calibration_error(true_models, pred_models, num_bins)
@@ -1233,6 +1275,7 @@ def plot_confusion_matrix(
12331275
ax.set_title("Confusion Matrix", fontsize=title_fontsize)
12341276
return fig
12351277

1278+
12361279
def plot_mmd_hypothesis_test(
12371280
mmd_null,
12381281
mmd_observed=None,

0 commit comments

Comments
 (0)