1- from typing import Union
1+ from typing import Optional , Union
22import numpy as np
33import matplotlib .pyplot as plt
44from matplotlib .axes import Axes as ax
@@ -49,25 +49,82 @@ def __init__(
4949 def plot_name (self ):
5050 return "coverage_fraction.png"
5151
52- def _data_setup (self ) :
53- _ , coverage = coverage_fraction_metric (
54- self .model , self .data , self .run_id , out_dir = None
52+ def _data_setup (self , percentile_step_size : float = 1 ) -> DataDisplay :
53+ _ , ( coverage_mean , coverage_std ) = coverage_fraction_metric (
54+ self .model , self .data , self .run_id , out_dir = None , percentiles = np . arange ( 0 , 100 , percentile_step_size ), use_progress_bar = self . use_progress_bar
5555 ).calculate ()
5656 return DataDisplay (
57- coverage_fractions = coverage
57+ coverage_fractions = coverage_mean ,
58+ coverage_percentiles = np .arange (0 , 100 , percentile_step_size ),
59+ coverage_std = coverage_std
5860 )
5961
62+ def _plot_residual (self , data_display , ax , figure_alpha , line_width , reference_line_style , include_coverage_residual_std , include_ideal_range ):
63+ color_cycler = iter (plt .cycler ("color" , self .parameter_colors ))
64+ line_style_cycler = iter (plt .cycler ("line_style" , self .line_cycle ))
65+ percentile_array = data_display .coverage_percentiles / 100.0
66+
67+ ax .plot ([0 ,1 ], [0 , 0 ], reference_line_style , lw = line_width , zorder = 1000 )
68+
69+ for i in range (self .n_parameters ):
70+ color = next (color_cycler )["color" ]
71+ line_style = next (line_style_cycler )["line_style" ]
72+
73+ residual = data_display .coverage_fractions [:, i ] - np .linspace (0 , 1 , len (data_display .coverage_fractions [:,i ]))
74+
75+ ax .plot (
76+ percentile_array ,
77+ residual ,
78+ alpha = figure_alpha ,
79+ lw = line_width * .8 ,
80+ linestyle = line_style ,
81+ color = color ,
82+ label = self .parameter_names [i ],
83+ )
84+ if include_coverage_residual_std :
85+
86+ ax .fill_between (
87+ percentile_array ,
88+ residual - data_display .coverage_std [:, i ],
89+ residual + data_display .coverage_std [:, i ],
90+ color = color ,
91+ alpha = 0.2 ,
92+ )
93+
94+ if include_ideal_range :
95+
96+ ax .fill_between (
97+ [0 , 1 ],
98+ [- 0.2 ]* 2 ,
99+ [0.2 ]* 2 ,
100+ color = "gray" ,
101+ alpha = 0.1 ,
102+ )
103+ ax .fill_between (
104+ [0 , 1 ],
105+ [- 0.1 ]* 2 ,
106+ [0.1 ]* 2 ,
107+ color = "gray" ,
108+ alpha = 0.2 ,
109+ )
110+
60111 def plot (
61112 self ,
62113 data_display : Union [DataDisplay , str ],
63114 figure_alpha = 1.0 ,
64115 line_width = 3 ,
65- legend_loc = "lower right" ,
116+ legend_loc :Optional [str ]= None ,
117+ include_coverage_std :bool = False ,
118+ include_coverage_residual :bool = False ,
119+ include_coverage_residual_std :bool = False ,
120+ include_ideal_range : bool = True ,
66121 reference_line_label = "Reference Line" ,
67122 reference_line_style = "k--" ,
68123 x_label = "Confidence Interval of the Posterior Volume" ,
69124 y_label = "Fraction of Lenses within Posterior Volume" ,
70- title = "NPE" ) -> tuple ["fig" , "ax" ]:
125+ residual_y_label = "Coverage Fraction Residual" ,
126+ title = "NPE"
127+ ) -> tuple ["fig" , "ax" ]:
71128 """
72129 Args:
73130 figure_alpha (float, optional): Opacity of parameter lines. Defaults to 1.0.
@@ -83,19 +140,28 @@ def plot(
83140 if not isinstance (data_display , DataDisplay ):
84141 data_display = DataDisplay ().from_h5 (data_display , self .plot_name )
85142
86- n_steps = data_display . coverage_fractions . shape [ 0 ]
87- percentile_array = np . linspace ( 0 , 1 , n_steps )
143+
144+ percentile_array = data_display . coverage_percentiles / 100.0
88145 color_cycler = iter (plt .cycler ("color" , self .parameter_colors ))
89146 line_style_cycler = iter (plt .cycler ("line_style" , self .line_cycle ))
90147
91148 # Plotting
92- fig , ax = plt .subplots (1 , 1 , figsize = self .figure_size )
149+ if include_coverage_residual :
150+ fig , subplots = plt .subplots (2 , 1 , figsize = (self .figure_size [0 ], self .figure_size [1 ]* 1.2 ), height_ratios = [3 , 1 ], sharex = True )
151+ ax = subplots [0 ]
152+
153+ self ._plot_residual (
154+ data_display , subplots [1 ], figure_alpha , line_width , reference_line_style , include_coverage_residual_std , include_ideal_range
155+ )
156+ subplots [1 ].set_ylabel (residual_y_label )
157+
158+ else :
159+ fig , ax = plt .subplots (1 , 1 , figsize = self .figure_size )
93160
94161 # Iterate over the number of parameters in the model
95162 for i in range (self .n_parameters ):
96163 color = next (color_cycler )["color" ]
97164 line_style = next (line_style_cycler )["line_style" ]
98-
99165 ax .plot (
100166 percentile_array ,
101167 data_display .coverage_fractions [:, i ],
@@ -105,6 +171,14 @@ def plot(
105171 color = color ,
106172 label = self .parameter_names [i ],
107173 )
174+ if include_coverage_std :
175+ ax .fill_between (
176+ percentile_array ,
177+ data_display .coverage_fractions [:, i ] - data_display .coverage_std [:, i ],
178+ data_display .coverage_fractions [:, i ] + data_display .coverage_std [:, i ],
179+ color = color ,
180+ alpha = 0.2 ,
181+ )
108182
109183 ax .plot (
110184 [0 , 0.5 , 1 ],
@@ -115,13 +189,35 @@ def plot(
115189 label = reference_line_label ,
116190 )
117191
192+ if include_ideal_range :
193+ def add_clearance (ax , clearance = 0.1 , clearance_alpha = 0.2 ):
194+ x_values = np .linspace (0 , 1 , 100 ) # More points for smoother curves
195+ y_lower = np .maximum (0 , x_values - clearance ) # Lower bound with clearance
196+ y_upper = np .minimum (1 , x_values + clearance ) # Upper bound with clearance
197+
198+ # Fill the area between the bounds
199+ ax .fill_between (
200+ x_values ,
201+ y_lower ,
202+ y_upper ,
203+ color = "gray" ,
204+ alpha = clearance_alpha ,
205+ )
206+
207+ add_clearance (ax , clearance = 0.2 , clearance_alpha = 0.2 )
208+ add_clearance (ax , clearance = 0.1 , clearance_alpha = 0.1 )
209+
210+
118211 ax .set_xlim ([- 0.05 , 1.05 ])
119212 ax .set_ylim ([- 0.05 , 1.05 ])
120213
121- ax .text (0.03 , 0.93 , "Under-confident" , horizontalalignment = "left" )
122- ax .text (0.3 , 0.05 , "Overconfident" , horizontalalignment = "left" )
214+ # ax.text(- 0.03, 0.93, "Under-confident", horizontalalignment="left")
215+ # ax.text(0.3, 0.05, "Overconfident", horizontalalignment="left")
123216
124- ax .legend (loc = legend_loc )
217+ if legend_loc is not None :
218+ ax .legend (loc = legend_loc )
219+ else :
220+ ax .legend ()
125221
126222 ax .set_xlabel (x_label )
127223 ax .set_ylabel (y_label )
0 commit comments