Skip to content

Commit 3bb93e6

Browse files
authored
Merge pull request #113 from voetberg/36-coverage
36 coverage
2 parents 72cd8f9 + 6dcab8c commit 3bb93e6

File tree

4 files changed

+146
-36
lines changed

4 files changed

+146
-36
lines changed

src/deepdiagnostics/metrics/coverage_fraction.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -51,34 +51,35 @@ def calculate(self) -> tuple[Sequence, Sequence]:
5151
Returns:
5252
tuple[Sequence, Sequence]: A tuple of the samples tested (M samples, Samples per inference, N parameters) and the coverage over those samples.
5353
"""
54+
5455
all_samples = np.empty(
5556
(self.number_simulations, self.samples_per_inference, np.shape(self.thetas)[1])
5657
)
57-
count_array = []
5858
iterator = range(self.number_simulations)
5959
if self.use_progress_bar:
6060
iterator = tqdm(
6161
iterator,
6262
desc="Sampling from the posterior for each observation",
6363
unit=" observation",
6464
)
65+
n_theta_samples = self.thetas.shape[0]
66+
count_array = np.zeros((self.number_simulations, len(self.percentiles), self.thetas.shape[1]))
67+
6568
for sample_index in iterator:
6669
context_sample = self.context[self.data.rng.integers(0, len(self.context))]
6770
samples = self._run_model_inference(self.samples_per_inference, context_sample)
6871

6972
all_samples[sample_index] = samples
7073

71-
count_vector = []
7274
# step through the percentile list
73-
for cov in self.percentiles:
75+
for index, cov in enumerate(self.percentiles):
7476
percentile_lower = 50.0 - cov / 2
7577
percentile_upper = 50.0 + cov / 2
7678

7779
# find the percentile for the posterior for this observation
7880
# this is n_params dimensional
7981
# the units are in parameter space
8082
confidence_lower = np.percentile(samples, percentile_lower, axis=0)
81-
8283
confidence_upper = np.percentile(samples, percentile_upper, axis=0)
8384

8485

@@ -87,22 +88,25 @@ def calculate(self) -> tuple[Sequence, Sequence]:
8788
# upper and lower confidence intervals
8889
# checks separately for each side of the 50th percentile
8990

90-
count = np.logical_and(
91-
confidence_upper - self.thetas[sample_index, :].numpy() > 0,
92-
self.thetas[sample_index, :].numpy() - confidence_lower > 0,
91+
c = np.logical_and(
92+
confidence_upper - self.thetas.numpy() > 0,
93+
self.thetas.numpy() - confidence_lower > 0,
9394
)
94-
count_vector.append(count)
95+
count_array[sample_index, index] = np.sum(c.astype(int), axis=0)/n_theta_samples
9596

9697
# each time the above is > 0, adds a count
97-
count_array.append(count_vector)
98+
#count_array[sample_index] = count_vector
99+
100+
coverage_mean = np.mean(count_array, axis=0)
101+
coverage_std = np.std(count_array, axis=0)
98102

99-
count_sum_array = np.sum(count_array, axis=0)
100-
frac_lens_within_vol = np.array(count_sum_array)
101-
coverage = frac_lens_within_vol / len(self.context)
103+
self.output = {
104+
"coverage": coverage_mean,
105+
"coverage_std": coverage_std,
102106

103-
self.output = coverage
107+
}
104108

105-
return all_samples, coverage
109+
return all_samples, (coverage_mean, coverage_std)
106110

107111
def __call__(self, **kwds: Any) -> Any:
108112
self.calculate()

src/deepdiagnostics/plots/coverage_fraction.py

Lines changed: 125 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Union
1+
from typing import Optional, Union
22
import numpy as np
33
import matplotlib.pyplot as plt
44
from matplotlib.axes import Axes as ax
@@ -49,53 +49,130 @@ def __init__(
4949
def plot_name(self):
5050
return "coverage_fraction.png"
5151

52-
def _data_setup(self):
53-
_, coverage = coverage_fraction_metric(
54-
self.model, self.data, self.run_id, out_dir=None
52+
def _data_setup(self, percentile_step_size:float=1) -> DataDisplay:
53+
_, (coverage_mean, coverage_std) = coverage_fraction_metric(
54+
self.model, self.data, self.run_id, out_dir=None, percentiles=np.arange(0, 100, percentile_step_size), use_progress_bar=self.use_progress_bar
5555
).calculate()
5656
return DataDisplay(
57-
coverage_fractions=coverage
57+
coverage_fractions=coverage_mean,
58+
coverage_percentiles=np.arange(0, 100, percentile_step_size),
59+
coverage_std=coverage_std
5860
)
5961

62+
def _plot_residual(self, data_display, ax, figure_alpha, line_width, reference_line_style, include_coverage_residual_std, include_ideal_range):
63+
color_cycler = iter(plt.cycler("color", self.parameter_colors))
64+
line_style_cycler = iter(plt.cycler("line_style", self.line_cycle))
65+
percentile_array = data_display.coverage_percentiles / 100.0
66+
67+
ax.plot([0,1], [0, 0], reference_line_style, lw=line_width, zorder=1000)
68+
69+
for i in range(self.n_parameters):
70+
color = next(color_cycler)["color"]
71+
line_style = next(line_style_cycler)["line_style"]
72+
73+
residual = data_display.coverage_fractions[:, i] - np.linspace(0, 1, len(data_display.coverage_fractions[:,i]))
74+
75+
ax.plot(
76+
percentile_array,
77+
residual,
78+
alpha=figure_alpha,
79+
lw=line_width*.8,
80+
linestyle=line_style,
81+
color=color,
82+
label=self.parameter_names[i],
83+
)
84+
if include_coverage_residual_std:
85+
86+
ax.fill_between(
87+
percentile_array,
88+
residual - data_display.coverage_std[:, i],
89+
residual + data_display.coverage_std[:, i],
90+
color=color,
91+
alpha=0.2,
92+
)
93+
94+
if include_ideal_range:
95+
96+
ax.fill_between(
97+
[0, 1],
98+
[-0.2]*2,
99+
[0.2]*2,
100+
color="gray",
101+
alpha=0.1,
102+
)
103+
ax.fill_between(
104+
[0, 1],
105+
[-0.1]*2,
106+
[0.1]*2,
107+
color="gray",
108+
alpha=0.2,
109+
)
110+
60111
def plot(
61112
self,
62113
data_display: Union[DataDisplay, str],
63114
figure_alpha=1.0,
64115
line_width=3,
65-
legend_loc="lower right",
116+
legend_loc:Optional[str]=None,
117+
include_coverage_std:bool = False,
118+
include_coverage_residual:bool = False,
119+
include_coverage_residual_std:bool = False,
120+
include_ideal_range: bool=True,
66121
reference_line_label="Reference Line",
67122
reference_line_style="k--",
68123
x_label="Confidence Interval of the Posterior Volume",
69-
y_label="Fraction of Lenses within Posterior Volume",
70-
title="NPE") -> tuple["fig", "ax"]:
124+
y_label="Coverage fraction within posterior volume",
125+
residual_y_label="Coverage Fraction Residual",
126+
title=""
127+
) -> tuple["fig", "ax"]:
71128
"""
72-
Args:
129+
Plot the coverage fraction and residuals if specified.
130+
131+
Args:
132+
data_display (Union[DataDisplay, str]): DataDisplay object or path to h5 file containing the data. If str, it will be loaded and requires the fields "coverage_fractions", "coverage_percentiles", and optionally "coverage_std".
73133
figure_alpha (float, optional): Opacity of parameter lines. Defaults to 1.0.
74134
line_width (int, optional): Width of parameter lines. Defaults to 3.
75-
legend_loc (str, optional): Location of the legend, str based on `matplotlib <https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.legend.html>`_. Defaults to "lower right".
135+
legend_loc (str, optional): Location of the legend. Defaults to matplotlib specified.
136+
include_coverage_std (bool, optional): Whether to include the standard deviation shading for coverage fractions . Defaults to False.
137+
include_coverage_residual (bool, optional): Whether to include the residual plot (coverage fraction - diagonal). Creates an additional subplot under the original plot. Defaults to False.
138+
include_coverage_residual_std (bool, optional): Whether to include the standard deviation shading for residuals. Defaults to False.
139+
include_ideal_range (bool, optional): Whether to include the ideal range shading (0.1/0.2 around the diagonal). Defaults to True.
76140
reference_line_label (str, optional): Label name for the diagonal ideal line. Defaults to "Reference Line".
77141
reference_line_style (str, optional): Line style for the reference line. Defaults to "k--".
78142
x_label (str, optional): y label. Defaults to "Confidence Interval of the Posterior Volume".
79143
y_label (str, optional): y label. Defaults to "Fraction of Lenses within Posterior Volume".
144+
residual_y_label (str, optional): y label for the residual plot. Defaults to "Coverage Fraction Residual".
80145
title (str, optional): plot title. Defaults to "NPE".
146+
81147
"""
82-
148+
83149
if not isinstance(data_display, DataDisplay):
84150
data_display = DataDisplay().from_h5(data_display, self.plot_name)
85151

86-
n_steps = data_display.coverage_fractions.shape[0]
87-
percentile_array = np.linspace(0, 1, n_steps)
152+
153+
percentile_array = data_display.coverage_percentiles / 100.0
88154
color_cycler = iter(plt.cycler("color", self.parameter_colors))
89155
line_style_cycler = iter(plt.cycler("line_style", self.line_cycle))
90156

91157
# Plotting
92-
fig, ax = plt.subplots(1, 1, figsize=self.figure_size)
158+
if include_coverage_residual:
159+
fig, subplots = plt.subplots(2, 1, figsize=(self.figure_size[0], self.figure_size[1]*1.2), height_ratios=[3, 1], sharex=True)
160+
ax = subplots[0]
161+
162+
self._plot_residual(
163+
data_display, subplots[1], figure_alpha, line_width, reference_line_style, include_coverage_residual_std, include_ideal_range
164+
)
165+
subplots[1].set_ylabel(residual_y_label)
166+
subplots[1].set_xlabel(x_label)
167+
168+
else:
169+
fig, ax = plt.subplots(1, 1, figsize=self.figure_size)
170+
ax.set_xlabel(x_label)
93171

94172
# Iterate over the number of parameters in the model
95173
for i in range(self.n_parameters):
96174
color = next(color_cycler)["color"]
97175
line_style = next(line_style_cycler)["line_style"]
98-
99176
ax.plot(
100177
percentile_array,
101178
data_display.coverage_fractions[:, i],
@@ -105,6 +182,14 @@ def plot(
105182
color=color,
106183
label=self.parameter_names[i],
107184
)
185+
if include_coverage_std:
186+
ax.fill_between(
187+
percentile_array,
188+
data_display.coverage_fractions[:, i] - data_display.coverage_std[:, i],
189+
data_display.coverage_fractions[:, i] + data_display.coverage_std[:, i],
190+
color=color,
191+
alpha=0.2,
192+
)
108193

109194
ax.plot(
110195
[0, 0.5, 1],
@@ -115,15 +200,36 @@ def plot(
115200
label=reference_line_label,
116201
)
117202

203+
if include_ideal_range:
204+
def add_clearance(ax, clearance=0.1, clearance_alpha=0.2):
205+
x_values = np.linspace(0, 1, 100) # More points for smoother curves
206+
y_lower = np.maximum(0, x_values - clearance) # Lower bound with clearance
207+
y_upper = np.minimum(1, x_values + clearance) # Upper bound with clearance
208+
209+
# Fill the area between the bounds
210+
ax.fill_between(
211+
x_values,
212+
y_lower,
213+
y_upper,
214+
color="gray",
215+
alpha=clearance_alpha,
216+
)
217+
218+
add_clearance(ax, clearance=0.2, clearance_alpha=0.2)
219+
add_clearance(ax, clearance=0.1, clearance_alpha=0.1)
220+
221+
118222
ax.set_xlim([-0.05, 1.05])
119223
ax.set_ylim([-0.05, 1.05])
120224

121-
ax.text(0.03, 0.93, "Under-confident", horizontalalignment="left")
122-
ax.text(0.3, 0.05, "Overconfident", horizontalalignment="left")
225+
# ax.text(-0.03, 0.93, "Under-confident", horizontalalignment="left")
226+
# ax.text(0.3, 0.05, "Overconfident", horizontalalignment="left")
123227

124-
ax.legend(loc=legend_loc)
228+
if legend_loc is not None:
229+
ax.legend(loc=legend_loc)
230+
else:
231+
ax.legend()
125232

126-
ax.set_xlabel(x_label)
127233
ax.set_ylabel(y_label)
128234
ax.set_title(title)
129235

tests/conftest.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,8 @@ def setUp(result_output):
8282
sim_paths = f"{simulator_config_path.strip('/')}/simulators.json"
8383
os.remove(sim_paths)
8484

85-
# out_dir = get_item("common", "out_dir", raise_exception=True)
86-
# shutil.rmtree(out_dir)
85+
out_dir = get_item("common", "out_dir", raise_exception=True)
86+
shutil.rmtree(out_dir)
8787

8888
@pytest.fixture
8989
def model_path():

tests/test_plots.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def plot_config(config_factory):
2020
metrics_settings = {
2121
"use_progress_bar": False,
2222
"samples_per_inference": 10,
23-
"percentiles": [95],
23+
"percentiles": [95, 75, 50],
2424
}
2525
config = config_factory(metrics_settings=metrics_settings)
2626
return config

0 commit comments

Comments
 (0)