@@ -14,8 +14,10 @@ def coverage(
1414 variable_names : Sequence [str ] = None ,
1515 figsize : Sequence [int ] = None ,
1616 label_fontsize : int = 16 ,
17+ legend_fontsize : int = 14 ,
1718 title_fontsize : int = 18 ,
1819 tick_fontsize : int = 12 ,
20+ legend_location : str = "lower right" ,
1921 color : str = "#132a70" ,
2022 num_col : int = None ,
2123 num_row : int = None ,
@@ -39,7 +41,7 @@ def coverage(
3941 The posterior draws obtained from num_datasets
4042 targets : np.ndarray of shape (num_datasets, num_params)
4143 The true parameter values used for generating num_datasets
42- difference : bool, optional, default: False
44+ difference : bool, optional, default: True
4345 If True, plots the difference between empirical coverage and ideal coverage
4446 (coverage - width), making deviations from ideal calibration more visible.
4547 If False, plots the standard coverage plot.
@@ -52,10 +54,14 @@ def coverage(
5254 The figure size passed to the matplotlib constructor. Inferred if None.
5355 label_fontsize : int, optional, default: 16
5456 The font size of the y-label and x-label text
57+ legend_fontsize : int, optional, default: 14
58+ The font size of the legend text
5559 title_fontsize : int, optional, default: 18
5660 The font size of the title text
5761 tick_fontsize : int, optional, default: 12
5862 The font size of the axis ticklabels
63+ legend_location : str, optional, default: 'upper right
64+ The location of the legend.
5965 color : str, optional, default: '#132a70'
6066 The color for the coverage line
6167 num_row : int, optional, default: None
@@ -128,17 +134,11 @@ def coverage(
128134 )
129135
130136 # Plot ideal coverage difference line (y = 0)
131- ax .axhline (y = 0 , color = "skyblue " , linewidth = 2.0 , label = "Ideal Coverage" )
137+ ax .axhline (y = 0 , color = "black " , linestyle = "dashed" , label = "Ideal Coverage" )
132138
133139 # Plot empirical coverage difference
134140 ax .plot (width_rep , diff_est , color = color , alpha = 1.0 , label = "Coverage Difference" )
135141
136- # Set axis limits
137- ax .set_xlim (0 , 1 )
138-
139- # Add legend to first subplot
140- if i == 0 :
141- ax .legend (fontsize = tick_fontsize , loc = "upper right" )
142142 else :
143143 # Plot confidence ribbon
144144 ax .fill_between (
@@ -151,23 +151,19 @@ def coverage(
151151 )
152152
153153 # Plot ideal coverage line (y = x)
154- ax .plot ([0 , 1 ], [0 , 1 ], color = "skyblue " , linewidth = 2.0 , label = "Ideal Coverage" )
154+ ax .plot ([0 , 1 ], [0 , 1 ], color = "black " , linestyle = "dashed" , label = "Ideal Coverage" )
155155
156156 # Plot empirical coverage
157157 ax .plot (width_rep , coverage_est , color = color , alpha = 1.0 , label = "Empirical Coverage" )
158158
159- # Set axis limits
160- ax .set_xlim (0 , 1 )
161- ax .set_ylim (0 , 1 )
162-
163- # Add legend to first subplot
164- if i == 0 :
165- ax .legend (fontsize = tick_fontsize , loc = "upper left" )
159+ # Add legend to first subplot
160+ if i == 0 :
161+ ax .legend (fontsize = legend_fontsize , loc = legend_location )
166162
167163 prettify_subplots (plot_data ["axes" ], num_subplots = plot_data ["num_variables" ], tick_fontsize = tick_fontsize )
168164
169165 # Add labels, titles, and set font sizes
170- ylabel = "Observed coverage difference" if difference else "Observed coverage"
166+ ylabel = "Empirical coverage difference" if difference else "Empirical coverage"
171167 add_titles_and_labels (
172168 axes = plot_data ["axes" ],
173169 num_row = plot_data ["num_row" ],
0 commit comments