@@ -147,12 +147,13 @@ def plot_recovery(
147147 if fig_size is None :
148148 fig_size = (int (4 * n_col ), int (4 * n_row ))
149149 f , axarr = plt .subplots (n_row , n_col , figsize = fig_size )
150+
150151 # turn axarr into 1D list
152+ axarr = np .atleast_1d (axarr )
151153 if n_col > 1 or n_row > 1 :
152154 axarr_it = axarr .flat
153155 else :
154- # for 1x1, axarr is not a list -> turn it into one for use with enumerate
155- axarr_it = [axarr ]
156+ axarr_it = axarr
156157
157158 for i , ax in enumerate (axarr_it ):
158159 if i >= n_params :
@@ -337,12 +338,13 @@ def plot_z_score_contraction(
337338 if fig_size is None :
338339 fig_size = (int (4 * n_col ), int (4 * n_row ))
339340 f , axarr = plt .subplots (n_row , n_col , figsize = fig_size )
341+
340342 # turn axarr into 1D list
343+ axarr = np .atleast_1d (axarr )
341344 if n_col > 1 or n_row > 1 :
342345 axarr_it = axarr .flat
343346 else :
344- # for 1x1, axarr is not a list -> turn it into one for use with enumerate
345- axarr_it = [axarr ]
347+ axarr_it = axarr
346348
347349 # Loop and plot
348350 for i , ax in enumerate (axarr_it ):
@@ -480,6 +482,7 @@ def plot_sbc_ecdf(
480482
481483 # Initialize figure
482484 f , ax = plt .subplots (n_row , n_col , figsize = fig_size )
485+ ax = np .atleast_1d (ax )
483486
484487 # Plot individual ecdf of parameters
485488 for j in range (ranks .shape [- 1 ]):
@@ -657,7 +660,8 @@ def plot_sbc_histograms(
657660 if fig_size is None :
658661 fig_size = (int (5 * n_col ), int (5 * n_row ))
659662 f , axarr = plt .subplots (n_row , n_col , figsize = fig_size )
660-
663+ axarr = np .atleast_1d (axarr )
664+
661665 # Compute ranks (using broadcasting)
662666 ranks = np .sum (post_samples < prior_samples [:, np .newaxis , :], axis = 1 )
663667
0 commit comments