Skip to content

Commit e4ab602

Browse files
committed
Update plot styling, add residuals
1 parent 6ca13ce commit e4ab602

File tree

3 files changed

+113
-17
lines changed

3 files changed

+113
-17
lines changed

src/deepdiagnostics/plots/coverage_fraction.py

Lines changed: 110 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Union
1+
from typing import Optional, Union
22
import numpy as np
33
import matplotlib.pyplot as plt
44
from matplotlib.axes import Axes as ax
@@ -49,25 +49,82 @@ def __init__(
4949
def plot_name(self):
5050
return "coverage_fraction.png"
5151

52-
def _data_setup(self):
53-
_, coverage = coverage_fraction_metric(
54-
self.model, self.data, self.run_id, out_dir=None
52+
def _data_setup(self, percentile_step_size:float=1) -> DataDisplay:
53+
_, (coverage_mean, coverage_std) = coverage_fraction_metric(
54+
self.model, self.data, self.run_id, out_dir=None, percentiles=np.arange(0, 100, percentile_step_size), use_progress_bar=self.use_progress_bar
5555
).calculate()
5656
return DataDisplay(
57-
coverage_fractions=coverage
57+
coverage_fractions=coverage_mean,
58+
coverage_percentiles=np.arange(0, 100, percentile_step_size),
59+
coverage_std=coverage_std
5860
)
5961

62+
def _plot_residual(self, data_display, ax, figure_alpha, line_width, reference_line_style, include_coverage_residual_std, include_ideal_range):
63+
color_cycler = iter(plt.cycler("color", self.parameter_colors))
64+
line_style_cycler = iter(plt.cycler("line_style", self.line_cycle))
65+
percentile_array = data_display.coverage_percentiles / 100.0
66+
67+
ax.plot([0,1], [0, 0], reference_line_style, lw=line_width, zorder=1000)
68+
69+
for i in range(self.n_parameters):
70+
color = next(color_cycler)["color"]
71+
line_style = next(line_style_cycler)["line_style"]
72+
73+
residual = data_display.coverage_fractions[:, i] - np.linspace(0, 1, len(data_display.coverage_fractions[:,i]))
74+
75+
ax.plot(
76+
percentile_array,
77+
residual,
78+
alpha=figure_alpha,
79+
lw=line_width*.8,
80+
linestyle=line_style,
81+
color=color,
82+
label=self.parameter_names[i],
83+
)
84+
if include_coverage_residual_std:
85+
86+
ax.fill_between(
87+
percentile_array,
88+
residual - data_display.coverage_std[:, i],
89+
residual + data_display.coverage_std[:, i],
90+
color=color,
91+
alpha=0.2,
92+
)
93+
94+
if include_ideal_range:
95+
96+
ax.fill_between(
97+
[0, 1],
98+
[-0.2]*2,
99+
[0.2]*2,
100+
color="gray",
101+
alpha=0.1,
102+
)
103+
ax.fill_between(
104+
[0, 1],
105+
[-0.1]*2,
106+
[0.1]*2,
107+
color="gray",
108+
alpha=0.2,
109+
)
110+
60111
def plot(
61112
self,
62113
data_display: Union[DataDisplay, str],
63114
figure_alpha=1.0,
64115
line_width=3,
65-
legend_loc="lower right",
116+
legend_loc:Optional[str]=None,
117+
include_coverage_std:bool = False,
118+
include_coverage_residual:bool = False,
119+
include_coverage_residual_std:bool = False,
120+
include_ideal_range: bool=True,
66121
reference_line_label="Reference Line",
67122
reference_line_style="k--",
68123
x_label="Confidence Interval of the Posterior Volume",
69124
y_label="Fraction of Lenses within Posterior Volume",
70-
title="NPE") -> tuple["fig", "ax"]:
125+
residual_y_label="Coverage Fraction Residual",
126+
title="NPE"
127+
) -> tuple["fig", "ax"]:
71128
"""
72129
Args:
73130
figure_alpha (float, optional): Opacity of parameter lines. Defaults to 1.0.
@@ -83,19 +140,28 @@ def plot(
83140
if not isinstance(data_display, DataDisplay):
84141
data_display = DataDisplay().from_h5(data_display, self.plot_name)
85142

86-
n_steps = data_display.coverage_fractions.shape[0]
87-
percentile_array = np.linspace(0, 1, n_steps)
143+
144+
percentile_array = data_display.coverage_percentiles / 100.0
88145
color_cycler = iter(plt.cycler("color", self.parameter_colors))
89146
line_style_cycler = iter(plt.cycler("line_style", self.line_cycle))
90147

91148
# Plotting
92-
fig, ax = plt.subplots(1, 1, figsize=self.figure_size)
149+
if include_coverage_residual:
150+
fig, subplots = plt.subplots(2, 1, figsize=(self.figure_size[0], self.figure_size[1]*1.2), height_ratios=[3, 1], sharex=True)
151+
ax = subplots[0]
152+
153+
self._plot_residual(
154+
data_display, subplots[1], figure_alpha, line_width, reference_line_style, include_coverage_residual_std, include_ideal_range
155+
)
156+
subplots[1].set_ylabel(residual_y_label)
157+
158+
else:
159+
fig, ax = plt.subplots(1, 1, figsize=self.figure_size)
93160

94161
# Iterate over the number of parameters in the model
95162
for i in range(self.n_parameters):
96163
color = next(color_cycler)["color"]
97164
line_style = next(line_style_cycler)["line_style"]
98-
99165
ax.plot(
100166
percentile_array,
101167
data_display.coverage_fractions[:, i],
@@ -105,6 +171,14 @@ def plot(
105171
color=color,
106172
label=self.parameter_names[i],
107173
)
174+
if include_coverage_std:
175+
ax.fill_between(
176+
percentile_array,
177+
data_display.coverage_fractions[:, i] - data_display.coverage_std[:, i],
178+
data_display.coverage_fractions[:, i] + data_display.coverage_std[:, i],
179+
color=color,
180+
alpha=0.2,
181+
)
108182

109183
ax.plot(
110184
[0, 0.5, 1],
@@ -115,13 +189,35 @@ def plot(
115189
label=reference_line_label,
116190
)
117191

192+
if include_ideal_range:
193+
def add_clearance(ax, clearance=0.1, clearance_alpha=0.2):
194+
x_values = np.linspace(0, 1, 100) # More points for smoother curves
195+
y_lower = np.maximum(0, x_values - clearance) # Lower bound with clearance
196+
y_upper = np.minimum(1, x_values + clearance) # Upper bound with clearance
197+
198+
# Fill the area between the bounds
199+
ax.fill_between(
200+
x_values,
201+
y_lower,
202+
y_upper,
203+
color="gray",
204+
alpha=clearance_alpha,
205+
)
206+
207+
add_clearance(ax, clearance=0.2, clearance_alpha=0.2)
208+
add_clearance(ax, clearance=0.1, clearance_alpha=0.1)
209+
210+
118211
ax.set_xlim([-0.05, 1.05])
119212
ax.set_ylim([-0.05, 1.05])
120213

121-
ax.text(0.03, 0.93, "Under-confident", horizontalalignment="left")
122-
ax.text(0.3, 0.05, "Overconfident", horizontalalignment="left")
214+
# ax.text(-0.03, 0.93, "Under-confident", horizontalalignment="left")
215+
# ax.text(0.3, 0.05, "Overconfident", horizontalalignment="left")
123216

124-
ax.legend(loc=legend_loc)
217+
if legend_loc is not None:
218+
ax.legend(loc=legend_loc)
219+
else:
220+
ax.legend()
125221

126222
ax.set_xlabel(x_label)
127223
ax.set_ylabel(y_label)

tests/conftest.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,8 @@ def setUp(result_output):
8282
sim_paths = f"{simulator_config_path.strip('/')}/simulators.json"
8383
os.remove(sim_paths)
8484

85-
# out_dir = get_item("common", "out_dir", raise_exception=True)
86-
# shutil.rmtree(out_dir)
85+
out_dir = get_item("common", "out_dir", raise_exception=True)
86+
shutil.rmtree(out_dir)
8787

8888
@pytest.fixture
8989
def model_path():

tests/test_plots.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def plot_config(config_factory):
2020
metrics_settings = {
2121
"use_progress_bar": False,
2222
"samples_per_inference": 10,
23-
"percentiles": [95],
23+
"percentiles": [95, 75, 50],
2424
}
2525
config = config_factory(metrics_settings=metrics_settings)
2626
return config

0 commit comments

Comments
 (0)