@@ -98,6 +98,10 @@ def plot_recovery(
9898 A flag for adding R^2 between true and estimates to the plot
9999 color : str, optional, default: '#8f2727'
100100 The color for the true vs. estimated scatter points and error bars
101+ n_row : int, optional, default: None
102+ The number of rows for the subplots. Dynamically determined if None.
103+ n_col : int, optional, default: None
104+ The number of columns for the subplots. Dynamically determined if None.
101105 xlabel : str, optional, default: 'Ground truth'
102106 The label on the x-axis of the plot
103107 ylabel : str, optional, default: 'Estimated'
@@ -232,7 +236,7 @@ def plot_z_score_contraction(
232236 tick_fontsize = 12 ,
233237 color = "#8f2727" ,
234238 n_col = None ,
235- n_row = None ,
239+ n_row = None
236240):
237241 """Implements a graphical check for global model sensitivity by plotting the posterior
238242 z-score over the posterior contraction for each set of posterior samples in ``post_samples``
@@ -279,6 +283,10 @@ def plot_z_score_contraction(
279283 The font size of the axis ticklabels
280284 color : str, optional, default: '#8f2727'
281285 The color for the true vs. estimated scatter points and error bars
286+ n_row : int, optional, default: None
287+ The number of rows for the subplots. Dynamically determined if None.
288+ n_col : int, optional, default: None
289+ The number of columns for the subplots. Dynamically determined if None.
282290
283291 Returns
284292 -------
@@ -379,6 +387,8 @@ def plot_sbc_ecdf(
379387 tick_fontsize = 12 ,
380388 rank_ecdf_color = "#a34f4f" ,
381389 fill_color = "grey" ,
390+ n_row = None ,
391+ n_col = None ,
382392 ** kwargs ,
383393):
384394 """Creates the empirical CDFs for each marginal rank distribution and plots it against
@@ -419,6 +429,10 @@ def plot_sbc_ecdf(
419429 The color to use for the rank ECDFs
420430 fill_color : str, optional, default: 'grey'
421431 The color of the fill arguments.
432+ n_row : int, optional, default: None
433+ The number of rows for the subplots. Dynamically determined if None.
434+ n_col : int, optional, default: None
435+ The number of columns for the subplots. Dynamically determined if None.
422436 **kwargs : dict, optional, default: {}
423437 Keyword arguments can be passed to control the behavior of ECDF simultaneous band computation
424438 through the ``ecdf_bands_kwargs`` dictionary. See `simultaneous_ecdf_bands` for keyword arguments
@@ -447,9 +461,14 @@ def plot_sbc_ecdf(
447461 n_row , n_col = 1 , 1
448462 f , ax = plt .subplots (1 , 1 , figsize = fig_size )
449463 else :
450- # Determine n_subplots dynamically
451- n_row = int (np .ceil (n_params / 6 ))
452- n_col = int (np .ceil (n_params / n_row ))
464+ # Determine number of rows and columns for subplots based on inputs
465+ if n_row is None and n_col is None :
466+ n_row = int (np .ceil (n_params / 6 ))
467+ n_col = int (np .ceil (n_params / n_row ))
468+ elif n_row is None and n_col is not None :
469+ n_row = int (np .ceil (n_params / n_col ))
470+ elif n_row is not None and n_col is None :
471+ n_col = int (np .ceil (n_params / n_row ))
453472
454473 # Determine fig_size dynamically, if None
455474 if fig_size is None :
@@ -543,6 +562,8 @@ def plot_sbc_histograms(
543562 title_fontsize = 18 ,
544563 tick_fontsize = 12 ,
545564 hist_color = "#a34f4f" ,
565+ n_row = None ,
566+ n_col = None
546567):
547568 """Creates and plots publication-ready histograms of rank statistics for simulation-based calibration
548569 (SBC) checks according to [1].
@@ -576,6 +597,10 @@ def plot_sbc_histograms(
576597 The font size of the axis ticklabels
577598 hist_color : str, optional, default '#a34f4f'
578599 The color to use for the histogram body
600+ n_row : int, optional, default: None
601+ The number of rows for the subplots. Dynamically determined if None.
602+ n_col : int, optional, default: None
603+ The number of columns for the subplots. Dynamically determined if None.
579604
580605 Returns
581606 -------
@@ -615,9 +640,14 @@ def plot_sbc_histograms(
615640 if param_names is None :
616641 param_names = [f"$\\ theta_{{{ i } }}$" for i in range (1 , n_params + 1 )]
617642
618- # Determine n_subplots dynamically
619- n_row = int (np .ceil (n_params / 6 ))
620- n_col = int (np .ceil (n_params / n_row ))
643+ # Determine number of rows and columns for subplots based on inputs
644+ if n_row is None and n_col is None :
645+ n_row = int (np .ceil (n_params / 6 ))
646+ n_col = int (np .ceil (n_params / n_row ))
647+ elif n_row is None and n_col is not None :
648+ n_row = int (np .ceil (n_params / n_col ))
649+ elif n_row is not None and n_col is None :
650+ n_col = int (np .ceil (n_params / n_row ))
621651
622652 # Initialize figure
623653 if fig_size is None :
@@ -1026,6 +1056,8 @@ def plot_calibration_curves(
10261056 epsilon = 0.02 ,
10271057 fig_size = None ,
10281058 color = "#8f2727" ,
1059+ n_row = None ,
1060+ n_col = None
10291061):
10301062 """Plots the calibration curves, the ECEs and the marginal histograms of predicted posterior model probabilities
10311063 for a model comparison problem. The marginal histograms inform about the fraction of predictions in each bin.
@@ -1055,6 +1087,10 @@ def plot_calibration_curves(
10551087 The figure size passed to the ``matplotlib`` constructor. Inferred if ``None``
10561088 color : str, optional, default: '#8f2727'
10571089 The color of the calibration curves
1090+ n_row : int, optional, default: None
1091+ The number of rows for the subplots. Dynamically determined if None.
1092+ n_col : int, optional, default: None
1093+ The number of columns for the subplots. Dynamically determined if None.
10581094
10591095 Returns
10601096 -------
@@ -1065,9 +1101,15 @@ def plot_calibration_curves(
10651101 if model_names is None :
10661102 model_names = [rf"$M_{{{ m } }}$" for m in range (1 , num_models + 1 )]
10671103
1068- # Determine n_subplots dynamically
1069- n_row = int (np .ceil (num_models / 6 ))
1070- n_col = int (np .ceil (num_models / n_row ))
1104+ # Determine number of rows and columns for subplots based on inputs
1105+ if n_row is None and n_col is None :
1106+ n_row = int (np .ceil (num_models / 6 ))
1107+ n_col = int (np .ceil (num_models / n_row ))
1108+ elif n_row is None and n_col is not None :
1109+ n_row = int (np .ceil (num_models / n_col ))
1110+ elif n_row is not None and n_col is None :
1111+ n_col = int (np .ceil (num_models / n_row ))
1112+
10711113
10721114 # Compute calibration
10731115 cal_errs , probs_true , probs_pred = expected_calibration_error (true_models , pred_models , num_bins )
@@ -1233,6 +1275,7 @@ def plot_confusion_matrix(
12331275 ax .set_title ("Confusion Matrix" , fontsize = title_fontsize )
12341276 return fig
12351277
1278+
12361279def plot_mmd_hypothesis_test (
12371280 mmd_null ,
12381281 mmd_observed = None ,
0 commit comments