99def coverage (
1010 estimates : Mapping [str , np .ndarray ] | np .ndarray ,
1111 targets : Mapping [str , np .ndarray ] | np .ndarray ,
12+ difference : bool = False ,
1213 variable_keys : Sequence [str ] = None ,
1314 variable_names : Sequence [str ] = None ,
1415 figsize : Sequence [int ] = None ,
@@ -29,13 +30,19 @@ def coverage(
2930
3031 The coverage is accompanied by credible intervals for the coverage (gray ribbon).
3132 These are computed via the (conjugate) Beta-Binomial model for binomial proportions with a uniform prior.
33+ For more details on the Beta-Binomial model, see Chapter 2 of Bayesian Data Analysis (2013, 3rd ed.) by
34+ Gelman A., Carlin J., Stern H., Dunson D., Vehtari A., & Rubin D.
3235
3336 Parameters
3437 ----------
3538 estimates : np.ndarray of shape (num_datasets, num_post_draws, num_params)
3639 The posterior draws obtained from num_datasets
3740 targets : np.ndarray of shape (num_datasets, num_params)
3841 The true parameter values used for generating num_datasets
42+ difference : bool, optional, default: False
43+ If True, plots the difference between empirical coverage and ideal coverage
44+ (coverage - width), making deviations from ideal calibration more visible.
45+ If False, plots the standard coverage plot.
3946 variable_keys : list or None, optional, default: None
4047 Select keys from the dictionaries provided in estimates and targets.
4148 By default, select all keys.
@@ -104,181 +111,70 @@ def coverage(
104111 coverage_low = coverage_data ["coverage_lower" ][:, i ]
105112 coverage_high = coverage_data ["coverage_upper" ][:, i ]
106113
107- # Plot confidence ribbon
108- ax .fill_between (
109- width_rep ,
110- coverage_low ,
111- coverage_high ,
112- color = "grey" ,
113- alpha = 0.33 ,
114- label = "95% Credible Interval" ,
115- )
116-
117- # Plot ideal coverage line (y = x)
118- ax .plot ([0 , 1 ], [0 , 1 ], color = "skyblue" , linewidth = 2.0 , label = "Ideal Coverage" )
119-
120- # Plot empirical coverage
121- ax .plot (width_rep , coverage_est , color = color , alpha = 1.0 , label = "Empirical Coverage" )
122-
123- # Set axis limits
124- ax .set_xlim (0 , 1 )
125- ax .set_ylim (0 , 1 )
126-
127- # Add legend to first subplot
128- if i == 0 :
129- ax .legend (fontsize = tick_fontsize , loc = "upper left" )
130-
131- prettify_subplots (plot_data ["axes" ], num_subplots = plot_data ["num_variables" ], tick_fontsize = tick_fontsize )
132-
133- # Add labels, titles, and set font sizes
134- add_titles_and_labels (
135- axes = plot_data ["axes" ],
136- num_row = plot_data ["num_row" ],
137- num_col = plot_data ["num_col" ],
138- title = plot_data ["variable_names" ],
139- xlabel = "Central interval width" ,
140- ylabel = "Observed coverage" ,
141- title_fontsize = title_fontsize ,
142- label_fontsize = label_fontsize ,
143- )
144-
145- plot_data ["fig" ].tight_layout ()
146- return plot_data ["fig" ]
147-
148-
149- def coverage_diff (
150- estimates : Mapping [str , np .ndarray ] | np .ndarray ,
151- targets : Mapping [str , np .ndarray ] | np .ndarray ,
152- variable_keys : Sequence [str ] = None ,
153- variable_names : Sequence [str ] = None ,
154- figsize : Sequence [int ] = None ,
155- label_fontsize : int = 16 ,
156- title_fontsize : int = 18 ,
157- tick_fontsize : int = 12 ,
158- color : str = "#132a70" ,
159- num_col : int = None ,
160- num_row : int = None ,
161- ) -> plt .Figure :
162- """
163- Creates coverage difference plots showing the difference between empirical coverage
164- and ideal coverage of posterior credible intervals.
165-
166- This plot shows coverage - width, making deviations from ideal calibration
167- more visible than the standard coverage plot.
168- For more details, see the documentation of the standard coverage plot.
169-
170- Parameters
171- ----------
172- estimates : np.ndarray of shape (num_datasets, num_post_draws, num_params)
173- The posterior draws obtained from num_datasets
174- targets : np.ndarray of shape (num_datasets, num_params)
175- The true parameter values used for generating num_datasets
176- variable_keys : list or None, optional, default: None
177- Select keys from the dictionaries provided in estimates and targets.
178- By default, select all keys.
179- variable_names : list or None, optional, default: None
180- The parameter names for nice plot titles. Inferred if None
181- figsize : tuple or None, optional, default: None
182- The figure size passed to the matplotlib constructor. Inferred if None.
183- label_fontsize : int, optional, default: 16
184- The font size of the y-label and x-label text
185- title_fontsize : int, optional, default: 18
186- The font size of the title text
187- tick_fontsize : int, optional, default: 12
188- The font size of the axis ticklabels
189- color : str, optional, default: '#132a70'
190- The color for the coverage difference line
191- num_row : int, optional, default: None
192- The number of rows for the subplots. Dynamically determined if None.
193- num_col : int, optional, default: None
194- The number of columns for the subplots. Dynamically determined if None.
195-
196- Returns
197- -------
198- f : plt.Figure - the figure instance for optional saving
199-
200- Raises
201- ------
202- ShapeError
203- If there is a deviation from the expected shapes of ``estimates`` and ``targets``.
204-
205- """
206-
207- # Gather plot data and metadata into a dictionary
208- plot_data = prepare_plot_data (
209- estimates = estimates ,
210- targets = targets ,
211- variable_keys = variable_keys ,
212- variable_names = variable_names ,
213- num_col = num_col ,
214- num_row = num_row ,
215- figsize = figsize ,
216- )
217-
218- estimates = plot_data .pop ("estimates" )
219- targets = plot_data .pop ("targets" )
220-
221- # Determine widths to compute coverage for
222- num_draws = estimates .shape [1 ]
223- widths = np .arange (0 , num_draws + 2 ) / (num_draws + 1 )
224-
225- # Compute empirical coverage with default parameters
226- coverage_data = compute_empirical_coverage (
227- estimates = estimates ,
228- targets = targets ,
229- widths = widths ,
230- prob = 0.95 ,
231- interval_type = "central" ,
232- )
233-
234- # Plot coverage difference for each parameter
235- for i , ax in enumerate (plot_data ["axes" ].flat ):
236- if i >= plot_data ["num_variables" ]:
237- break
238-
239- width_rep = coverage_data ["width_represented" ][:, i ]
240- coverage_est = coverage_data ["coverage_estimates" ][:, i ]
241- coverage_low = coverage_data ["coverage_lower" ][:, i ]
242- coverage_high = coverage_data ["coverage_upper" ][:, i ]
243-
244- # Compute differences
245- diff_est = coverage_est - width_rep
246- diff_low = coverage_low - width_rep
247- diff_high = coverage_high - width_rep
248-
249- # Plot confidence ribbon
250- ax .fill_between (
251- width_rep ,
252- diff_low ,
253- diff_high ,
254- color = "grey" ,
255- alpha = 0.33 ,
256- label = "95% Credible Interval" ,
257- )
258-
259- # Plot ideal coverage difference line (y = 0)
260- ax .axhline (y = 0 , color = "skyblue" , linewidth = 2.0 , label = "Ideal Coverage" )
261-
262- # Plot empirical coverage difference
263- ax .plot (width_rep , diff_est , color = color , alpha = 1.0 , label = "Coverage Difference" )
264-
265- # Set axis limits
266- ax .set_xlim (0 , 1 )
267-
268- # Add legend to first subplot
269- if i == 0 :
270- ax .legend (fontsize = tick_fontsize , loc = "upper right" )
114+ if difference :
115+ # Compute differences for coverage difference plot
116+ diff_est = coverage_est - width_rep
117+ diff_low = coverage_low - width_rep
118+ diff_high = coverage_high - width_rep
119+
120+ # Plot confidence ribbon
121+ ax .fill_between (
122+ width_rep ,
123+ diff_low ,
124+ diff_high ,
125+ color = "grey" ,
126+ alpha = 0.33 ,
127+ label = "95% Credible Interval" ,
128+ )
129+
130+ # Plot ideal coverage difference line (y = 0)
131+ ax .axhline (y = 0 , color = "skyblue" , linewidth = 2.0 , label = "Ideal Coverage" )
132+
133+ # Plot empirical coverage difference
134+ ax .plot (width_rep , diff_est , color = color , alpha = 1.0 , label = "Coverage Difference" )
135+
136+ # Set axis limits
137+ ax .set_xlim (0 , 1 )
138+
139+ # Add legend to first subplot
140+ if i == 0 :
141+ ax .legend (fontsize = tick_fontsize , loc = "upper right" )
142+ else :
143+ # Plot confidence ribbon
144+ ax .fill_between (
145+ width_rep ,
146+ coverage_low ,
147+ coverage_high ,
148+ color = "grey" ,
149+ alpha = 0.33 ,
150+ label = "95% Credible Interval" ,
151+ )
152+
153+ # Plot ideal coverage line (y = x)
154+ ax .plot ([0 , 1 ], [0 , 1 ], color = "skyblue" , linewidth = 2.0 , label = "Ideal Coverage" )
155+
156+ # Plot empirical coverage
157+ ax .plot (width_rep , coverage_est , color = color , alpha = 1.0 , label = "Empirical Coverage" )
158+
159+ # Set axis limits
160+ ax .set_xlim (0 , 1 )
161+ ax .set_ylim (0 , 1 )
162+
163+ # Add legend to first subplot
164+ if i == 0 :
165+ ax .legend (fontsize = tick_fontsize , loc = "upper left" )
271166
272167 prettify_subplots (plot_data ["axes" ], num_subplots = plot_data ["num_variables" ], tick_fontsize = tick_fontsize )
273168
274169 # Add labels, titles, and set font sizes
170+ ylabel = "Observed coverage difference" if difference else "Observed coverage"
275171 add_titles_and_labels (
276172 axes = plot_data ["axes" ],
277173 num_row = plot_data ["num_row" ],
278174 num_col = plot_data ["num_col" ],
279175 title = plot_data ["variable_names" ],
280176 xlabel = "Central interval width" ,
281- ylabel = "Coverage difference" ,
177+ ylabel = ylabel ,
282178 title_fontsize = title_fontsize ,
283179 label_fontsize = label_fontsize ,
284180 )
0 commit comments