Skip to content

Commit 15d00c1

Browse files
committed
Improve diagnostics
1 parent 3c29d9e commit 15d00c1

File tree

7 files changed

+371
-137
lines changed

7 files changed

+371
-137
lines changed

bayesflow/diagnostics/plots/calibration_ecdf.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,13 @@ def calibration_ecdf(
1515
variable_keys: Sequence[str] = None,
1616
variable_names: Sequence[str] = None,
1717
test_quantities: dict[str, Callable] = None,
18-
difference: bool = False,
18+
difference: bool = True,
1919
stacked: bool = False,
2020
rank_type: str | np.ndarray = "fractional",
2121
figsize: Sequence[float] = None,
2222
label_fontsize: int = 16,
2323
legend_fontsize: int = 14,
24-
legend_location: str = "upper right",
24+
legend_location: str = "lower right",
2525
title_fontsize: int = 18,
2626
tick_fontsize: int = 12,
2727
rank_ecdf_color: str = "#132a70",
@@ -59,7 +59,7 @@ def calibration_ecdf(
5959
The posterior draws obtained from n_data_sets
6060
targets : np.ndarray of shape (n_data_sets, n_params)
6161
The prior draws obtained for generating n_data_sets
62-
difference : bool, optional, default: False
62+
difference : bool, optional, default: True
6363
If `True`, plots the ECDF difference.
6464
Enables a more dynamic visualization range.
6565
stacked : bool, optional, default: False
@@ -98,7 +98,9 @@ def calibration_ecdf(
9898
label_fontsize : int, optional, default: 16
9999
The font size of the y-label and y-label texts
100100
legend_fontsize : int, optional, default: 14
101-
The font size of the legend text
101+
The font size of the legend text.
102+
legend_location : str, optional, default: 'lower right
103+
The location of the legend.
102104
title_fontsize : int, optional, default: 18
103105
The font size of the title text.
104106
Only relevant if `stacked=False`
@@ -211,11 +213,13 @@ def calibration_ecdf(
211213
else:
212214
titles = ["Stacked ECDFs"]
213215

214-
for ax, title in zip(plot_data["axes"].flat, titles):
216+
for i, (ax, title) in enumerate(zip(plot_data["axes"].flat, titles)):
215217
ax.fill_between(z, L, U, color=fill_color, alpha=0.2, label=rf"{int((1 - alpha) * 100)}$\%$ Confidence Bands")
216-
ax.legend(fontsize=legend_fontsize, loc=legend_location)
217218
ax.set_title(title, fontsize=title_fontsize)
218219

220+
if i == 0:
221+
ax.legend(fontsize=legend_fontsize, loc=legend_location)
222+
219223
prettify_subplots(plot_data["axes"], num_subplots=plot_data["num_variables"], tick_fontsize=tick_fontsize)
220224

221225
add_titles_and_labels(

bayesflow/diagnostics/plots/calibration_ecdf_from_quantiles.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@ def calibration_ecdf_from_quantiles(
1414
quantiles_key: str = "quantiles",
1515
variable_keys: Sequence[str] = None,
1616
variable_names: Sequence[str] = None,
17-
difference: bool = False,
17+
difference: bool = True,
1818
stacked: bool = False,
1919
figsize: Sequence[float] = None,
2020
label_fontsize: int = 16,
2121
legend_fontsize: int = 14,
22-
legend_location: str = "upper right",
22+
legend_location: str = "lower right",
2323
title_fontsize: int = 18,
2424
tick_fontsize: int = 12,
2525
rank_ecdf_color: str = "#132a70",
@@ -69,7 +69,7 @@ def calibration_ecdf_from_quantiles(
6969
variable_names : list or None, optional, default: None
7070
The parameter names for nice plot titles.
7171
Inferred if None. Only relevant if `stacked=False`.
72-
difference : bool, optional, default: False
72+
difference : bool, optional, default: True
7373
If `True`, plots the ECDF difference.
7474
Enables a more dynamic visualization range.
7575
stacked : bool, optional, default: False
@@ -82,7 +82,9 @@ def calibration_ecdf_from_quantiles(
8282
label_fontsize : int, optional, default: 16
8383
The font size of the y-label and y-label texts
8484
legend_fontsize : int, optional, default: 14
85-
The font size of the legend text
85+
The font size of the legend text.
86+
legend_location : str, optional, default: 'lower right
87+
The location of the legend.
8688
title_fontsize : int, optional, default: 18
8789
The font size of the title text.
8890
Only relevant if `stacked=False`

bayesflow/diagnostics/plots/coverage.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def coverage(
1717
legend_fontsize: int = 14,
1818
title_fontsize: int = 18,
1919
tick_fontsize: int = 12,
20-
legend_location: str = "upper right",
20+
legend_location: str = "lower right",
2121
color: str = "#132a70",
2222
num_col: int = None,
2323
num_row: int = None,
@@ -41,7 +41,7 @@ def coverage(
4141
The posterior draws obtained from num_datasets
4242
targets : np.ndarray of shape (num_datasets, num_params)
4343
The true parameter values used for generating num_datasets
44-
difference : bool, optional, default: False
44+
difference : bool, optional, default: True
4545
If True, plots the difference between empirical coverage and ideal coverage
4646
(coverage - width), making deviations from ideal calibration more visible.
4747
If False, plots the standard coverage plot.
@@ -60,6 +60,8 @@ def coverage(
6060
The font size of the title text
6161
tick_fontsize : int, optional, default: 12
6262
The font size of the axis ticklabels
63+
legend_location : str, optional, default: 'upper right
64+
The location of the legend.
6365
color : str, optional, default: '#132a70'
6466
The color for the coverage line
6567
num_row : int, optional, default: None
@@ -132,7 +134,7 @@ def coverage(
132134
)
133135

134136
# Plot ideal coverage difference line (y = 0)
135-
ax.axhline(y=0, color="skyblue", linewidth=2.0, label="Ideal Coverage")
137+
ax.axhline(y=0, color="black", linestyle="dashed", label="Ideal Coverage")
136138

137139
# Plot empirical coverage difference
138140
ax.plot(width_rep, diff_est, color=color, alpha=1.0, label="Coverage Difference")
@@ -149,23 +151,19 @@ def coverage(
149151
)
150152

151153
# Plot ideal coverage line (y = x)
152-
ax.plot([0, 1], [0, 1], color="skyblue", linewidth=2.0, label="Ideal Coverage")
154+
ax.plot([0, 1], [0, 1], color="black", linestyle="dashed", label="Ideal Coverage")
153155

154156
# Plot empirical coverage
155157
ax.plot(width_rep, coverage_est, color=color, alpha=1.0, label="Empirical Coverage")
156158

157-
# Set axis limits
158-
ax.set_xlim(0, 1)
159-
ax.set_ylim(0, 1)
160-
161159
# Add legend to first subplot
162160
if i == 0:
163161
ax.legend(fontsize=legend_fontsize, loc=legend_location)
164162

165163
prettify_subplots(plot_data["axes"], num_subplots=plot_data["num_variables"], tick_fontsize=tick_fontsize)
166164

167165
# Add labels, titles, and set font sizes
168-
ylabel = "Observed coverage difference" if difference else "Observed coverage"
166+
ylabel = "Empirical coverage difference" if difference else "Empirical coverage"
169167
add_titles_and_labels(
170168
axes=plot_data["axes"],
171169
num_row=plot_data["num_row"],

bayesflow/diagnostics/plots/loss.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ def loss(
168168
num_col=1,
169169
title=["Loss Trajectory"],
170170
xlabel="Training epoch #",
171-
ylabel="Value",
171+
ylabel="Loss",
172172
title_fontsize=title_fontsize,
173173
label_fontsize=label_fontsize,
174174
)

bayesflow/diagnostics/plots/recovery.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,10 @@ def recovery(
8787
The number of rows for the subplots. Dynamically determined if None.
8888
num_col : int, optional, default: None
8989
The number of columns for the subplots. Dynamically determined if None.
90-
xlabel :
91-
ylabel :
90+
xlabel : str, optional, default: "Ground truth"
91+
The label shown on the x-axis.
92+
ylabel : str, optional, default: "Estimate"
93+
The label shown on the y-axis.
9294
markersize : float, optional, default: None
9395
The marker size in points.
9496

examples/Multimodal_Data.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -524,7 +524,7 @@
524524
},
525525
{
526526
"cell_type": "code",
527-
"execution_count": 16,
527+
"execution_count": null,
528528
"id": "2415fd0b-f5d6-4fc9-83d7-8952e6270186",
529529
"metadata": {},
530530
"outputs": [

examples/SIR_Posterior_Estimation.ipynb

Lines changed: 342 additions & 114 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)