@@ -62,7 +62,7 @@ def plot_recovery(
6262 https://betanalpha.github.io/assets/case_studies/principled_bayesian_workflow.html
6363
6464 Important: Posterior aggregates play no special role in Bayesian inference and should only
65- be used heuristically. For instanec , in the case of multi-modal posteriors, common point
65+ be used heuristically. For instance , in the case of multi-modal posteriors, common point
6666 estimates, such as mean, (geometric) median, or maximum a posteriori (MAP) mean nothing.
6767
6868 Parameters
@@ -71,7 +71,7 @@ def plot_recovery(
7171 The posterior draws obtained from n_data_sets
7272 prior_samples : np.ndarray of shape (n_data_sets, n_params)
7373 The prior draws (true parameters) obtained for generating the n_data_sets
74- point_agg : callable, optional, default: np.median
74+ point_agg : callable, optional, default: `` np.median``
7575 The function to apply to the posterior draws to get a point estimate for each marginal.
7676 The default computes the marginal median for each marginal posterior as a robust
7777 point estimate.
@@ -89,13 +89,13 @@ def plot_recovery(
8989 metric_fontsize : int, optional, default: 16
9090 The font size of the goodness-of-fit metric (if provided)
9191 tick_fontsize : int, optional, default: 12
92- The font size of the axis ticklabels
92+ The font size of the axis tick labels
9393 add_corr : bool, optional, default: True
9494 A flag for adding correlation between true and estimates to the plot
9595 add_r2 : bool, optional, default: True
9696 A flag for adding R^2 between true and estimates to the plot
9797 color : str, optional, default: '#8f2727'
98- The color for the true vs. estimated scatter points and errobars
98+ The color for the true vs. estimated scatter points and error bars
9999
100100 Returns
101101 -------
@@ -144,7 +144,7 @@ def plot_recovery(
144144 if i >= n_params :
145145 break
146146
147- # Add scatter and errorbars
147+ # Add scatter and error bars
148148 if uncertainty_agg is not None :
149149 _ = ax .errorbar (prior_samples [:, i ], est [:, i ], yerr = u [:, i ], fmt = "o" , alpha = 0.5 , color = color )
150150 else :
@@ -242,7 +242,7 @@ def plot_z_score_contraction(
242242
243243 post_contraction = 1 - (posterior_variance / prior_variance)
244244
245- In other words, the posterior is a proxy for the reduction in ucnertainty gained by
245+ In other words, the posterior is a proxy for the reduction in uncertainty gained by
246246 replacing the prior with the posterior. The ideal posterior contraction tends to 1.
247247 Contraction near zero indicates that the posterior variance is almost identical to
248248 the prior variance for the particular marginal parameter distribution.
@@ -253,7 +253,7 @@ def plot_z_score_contraction(
253253 Toward a principled Bayesian workflow in cognitive science.
254254 Psychological methods, 26(1), 103.
255255
256- Also available at https://arxiv.org/abs/1904.12765
256+ Paper also available at https://arxiv.org/abs/1904.12765
257257
258258 Parameters
259259 ----------
@@ -272,7 +272,7 @@ def plot_z_score_contraction(
272272 tick_fontsize : int, optional, default: 12
273273 The font size of the axis ticklabels
274274 color : str, optional, default: '#8f2727'
275- The color for the true vs. estimated scatter points and errobars
275+ The color for the true vs. estimated scatter points and error bars
276276
277277 Returns
278278 -------
@@ -887,21 +887,21 @@ def plot_losses(
887887 lw = lw_val ,
888888 label = "Validation" ,
889889 )
890- # Schmuck
890+ # Schmuck
891891 ax .set_xlabel ("Training step #" , fontsize = label_fontsize )
892892 ax .set_ylabel ("Loss value" , fontsize = label_fontsize )
893893 sns .despine (ax = ax )
894894 ax .grid (alpha = grid_alpha )
895895 ax .set_title (train_losses .columns [i ], fontsize = title_fontsize )
896896 # Only add legend if there is a validation curve
897- if val_losses is not None :
897+ if val_losses is not None or moving_average :
898898 ax .legend (fontsize = legend_fontsize )
899899 f .tight_layout ()
900900 return f
901901
902902
903903def plot_prior2d (prior , param_names = None , n_samples = 2000 , height = 2.5 , color = "#8f2727" , ** kwargs ):
904- """Creates pairplots for a given joint prior.
904+ """Creates pair-plots for a given joint prior.
905905
906906 Parameters
907907 ----------
@@ -913,7 +913,7 @@ def plot_prior2d(prior, param_names=None, n_samples=2000, height=2.5, color="#8f
913913 The number of random draws from the joint prior
914914 height : float, optional, default: 2.5
915915 The height of the pair plot
916- color : str, optional, defailt : '#8f2727'
916+ color : str, optional, default : '#8f2727'
917917 The color of the plot
918918 **kwargs : dict, optional
919919 Additional keyword arguments passed to the sns.PairGrid constructor
@@ -943,14 +943,16 @@ def plot_prior2d(prior, param_names=None, n_samples=2000, height=2.5, color="#8f
943943 # Generate plots
944944 g = sns .PairGrid (data_to_plot , height = height , ** kwargs )
945945 g .map_diag (sns .histplot , fill = True , color = color , alpha = 0.9 , kde = True )
946- # Kernel density estimation (KDE) may not always be possible (e.g. with parameters whose correlation is close to 1 or -1).
946+
947+ # Kernel density estimation (KDE) may not always be possible
948+ # (e.g. with parameters whose correlation is close to 1 or -1).
947949 # In this scenario, a scatter-plot is generated instead.
948950 try :
949951 g .map_lower (sns .kdeplot , fill = True , color = color , alpha = 0.9 )
950952 except Exception as e :
951- logging .warn ("KDE failed due to the following exception:\n " + repr (e ) + "\n Substituting scatter plot." )
952- g .map_lower (plt . scatter , alpha = 0.6 , s = 40 , edgecolor = "k" , color = color )
953- g .map_upper (plt . scatter , alpha = 0.6 , s = 40 , edgecolor = "k" , color = color )
953+ logging .warning ("KDE failed due to the following exception:\n " + repr (e ) + "\n Substituting scatter plot." )
954+ g .map_lower (sns . scatterplot , alpha = 0.6 , s = 40 , edgecolor = "k" , color = color )
955+ g .map_upper (sns . scatterplot , alpha = 0.6 , s = 40 , edgecolor = "k" , color = color )
954956
955957 # Add grids
956958 for i in range (dim ):
@@ -961,8 +963,8 @@ def plot_prior2d(prior, param_names=None, n_samples=2000, height=2.5, color="#8f
961963
962964
963965def plot_latent_space_2d (z_samples , height = 2.5 , color = "#8f2727" , ** kwargs ):
964- """Creates pairplots for the latent space learned by the inference network. Enables
965- visual inspection of the the latent space and whether its structrue corresponds to the
966+ """Creates pair plots for the latent space learned by the inference network. Enables
967+ visual inspection of the latent space and whether its structure corresponds to the
966968 one enforced by the optimization criterion.
967969
968970 Parameters
@@ -971,7 +973,7 @@ def plot_latent_space_2d(z_samples, height=2.5, color="#8f2727", **kwargs):
971973 The latent samples computed through a forward pass of the inference network.
972974 height : float, optional, default: 2.5
973975 The height of the pair plot.
974- color : str, optional, defailt : '#8f2727'
976+ color : str, optional, default : '#8f2727'
975977 The color of the plot
976978 **kwargs : dict, optional
977979 Additional keyword arguments passed to the sns.PairGrid constructor
@@ -996,7 +998,7 @@ def plot_latent_space_2d(z_samples, height=2.5, color="#8f2727", **kwargs):
996998 g = sns .PairGrid (data_to_plot , height = height , ** kwargs )
997999 g .map_diag (sns .histplot , fill = True , color = color , alpha = 0.9 , kde = True )
9981000 g .map_lower (sns .kdeplot , fill = True , color = color , alpha = 0.9 )
999- g .map_upper (plt . scatter , alpha = 0.6 , s = 40 , edgecolor = "k" , color = color )
1001+ g .map_upper (sns . scatterplot , alpha = 0.6 , s = 40 , edgecolor = "k" , color = color )
10001002
10011003 # Add grids
10021004 for i in range (z_dim ):
@@ -1060,6 +1062,8 @@ def plot_calibration_curves(
10601062 # Determine n_subplots dynamically
10611063 n_row = int (np .ceil (num_models / 6 ))
10621064 n_col = int (np .ceil (num_models / n_row ))
1065+
1066+ # Compute calibration
10631067 cal_errs , probs_true , probs_pred = expected_calibration_error (true_models , pred_models , num_bins )
10641068
10651069 # Initialize figure
@@ -1094,8 +1098,6 @@ def plot_calibration_curves(
10941098 ax [j ].spines ["top" ].set_visible (False )
10951099 ax [j ].set_xlim ([0 - epsilon , 1 + epsilon ])
10961100 ax [j ].set_ylim ([0 - epsilon , 1 + epsilon ])
1097- ax [j ].set_xlabel ("Predicted probability" , fontsize = label_fontsize )
1098- ax [j ].set_ylabel ("True probability" , fontsize = label_fontsize )
10991101 ax [j ].set_xticks ([0.0 , 0.2 , 0.4 , 0.6 , 0.8 , 1.0 ])
11001102 ax [j ].set_yticks ([0.0 , 0.2 , 0.4 , 0.6 , 0.8 , 1.0 ])
11011103 ax [j ].grid (alpha = 0.5 )
@@ -1111,6 +1113,18 @@ def plot_calibration_curves(
11111113 size = legend_fontsize ,
11121114 )
11131115
1116+ # Only add x-labels to the bottom row
1117+ bottom_row = axarr if n_row == 1 else axarr [0 ] if n_col == 1 else axarr [n_row - 1 , :]
1118+ for _ax in bottom_row :
1119+ _ax .set_xlabel ("Predicted probability" , fontsize = label_fontsize )
1120+
1121+ # Only add y-labels to left-most row
1122+ if n_row == 1 : # if there is only one row, the ax array is 1D
1123+ ax [0 ].set_ylabel ("True probability" , fontsize = label_fontsize )
1124+ else : # if there is more than one row, the ax array is 2D
1125+ for _ax in axarr [:, 0 ]:
1126+ _ax .set_ylabel ("True probability" , fontsize = label_fontsize )
1127+
11141128 fig .tight_layout ()
11151129 return fig
11161130
@@ -1223,32 +1237,31 @@ def plot_mmd_hypothesis_test(
12231237
12241238 Parameters
12251239 ----------
1226- mmd_null: np.ndarray
1227- samples from the MMD sampling distribution under the null hypothesis "the model is well-specified"
1228- mmd_observed: float
1229- observed MMD value
1230- alpha_level: float
1231- rejection probability (type I error)
1232- null_color: color
1233- color for the H0 sampling distribution
1234- observed_color: color
1235- color for the observed MMD
1236- alpha_color: color
1237- color for the rejection area
1240+ mmd_null : np.ndarray
1241+ The samples from the MMD sampling distribution under the null hypothesis "the model is well-specified"
1242+ mmd_observed : float
1243+ The observed MMD value
1244+ alpha_level : float
1245+ The rejection probability (type I error)
1246+ null_color : str or tuple
1247+ The color of the H0 sampling distribution
1248+ observed_color : str or tuple
1249+ The color of the observed MMD
1250+ alpha_color : str or tuple
1251+ The color of the rejection area
12381252 truncate_vlines_at_kde: bool
12391253 true: cut off the vlines at the kde
12401254 false: continue kde lines across the plot
1241- xmin: float
1242- lower x axis limit
1243- xmax: float
1244- upper x axis limit
1245- bw_factor: float, default: 1.5
1255+ xmin : float
1256+ The lower x- axis limit
1257+ xmax : float
1258+ The upper x- axis limit
1259+ bw_factor : float, optional , default: 1.5
12461260 bandwidth (aka. smoothing parameter) of the kernel density estimate
12471261
12481262 Returns
12491263 -------
12501264 f : plt.Figure - the figure instance for optional saving
1251-
12521265 """
12531266
12541267 def draw_vline_to_kde (x , kde_object , color , label = None , ** kwargs ):
0 commit comments