Skip to content

Commit 24e70aa

Browse files
arrjonjerrymhuang
andauthored
Adding Diagnostics to Ecdf Plots in the spirit of TARP (#261)
* ecdf with random points * single axis * single axis * add comments * clean * clean * title fix * docstring * posterior 2d * posterior 2d fix * posterior 2d fix * posterior 2d fix * fix reference * clean up * add comment * add tests * pass kwargs * fix title * make more customizable * fix conflict --------- Co-authored-by: Jerry <[email protected]>
1 parent 0537f2a commit 24e70aa

File tree

8 files changed

+255
-39
lines changed

8 files changed

+255
-39
lines changed

bayesflow/diagnostics/plot_posterior_2d.py

Lines changed: 53 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import matplotlib.pyplot as plt
2+
13
import numpy as np
24
import pandas as pd
35
import seaborn as sns
@@ -8,10 +10,11 @@
810

911

1012
def plot_posterior_2d(
11-
post_samples: dict[str, np.ndarray] | np.ndarray,
12-
prior_samples: dict[str, np.ndarray] | np.ndarray,
13+
post_samples: np.ndarray,
14+
prior_samples: np.ndarray = None,
1315
prior=None,
14-
param_names: list = None,
16+
variable_names: list = None,
17+
true_params: np.ndarray = None,
1518
height: int = 3,
1619
label_fontsize: int = 14,
1720
legend_fontsize: int = 16,
@@ -24,15 +27,17 @@ def plot_posterior_2d(
2427
) -> sns.PairGrid:
2528
"""Generates a bivariate pairplot given posterior draws and optional prior or prior draws.
2629
27-
posterior_draws : np.ndarray of shape (n_post_draws, n_params)
30+
post_samples : np.ndarray of shape (n_post_draws, n_params)
2831
The posterior draws obtained for a SINGLE observed data set.
29-
prior : bayesflow.forward_inference.Prior instance or None, optional, default: None
30-
The optional prior object having an input-output signature as given by ayesflow.forward_inference.Prior
31-
prior_draws : np.ndarray of shape (n_prior_draws, n_params) or None, optonal (default: None)
32-
The optional prior draws obtained from the prior. If both prior and prior_draws are provided, prior_draws
32+
prior_samples : np.ndarray of shape (n_prior_draws, n_params) or None, optional (default: None)
33+
The optional prior samples obtained from the prior. If both prior and prior_samples are provided, prior_samples
3334
will be used.
34-
param_names : list or None, optional, default: None
35+
prior : bayesflow.forward_inference.Prior instance or None, optional, default: None
36+
The optional prior object having an input-output signature as given by bayesflow.forward_inference.Prior
37+
variable_names : list or None, optional, default: None
3538
The parameter names for nice plot titles. Inferred if None
39+
true_params : np.ndarray of shape (n_params,) or None, optional, default: None
40+
The true parameter values to be plotted on the diagonal.
3641
height : float, optional, default: 3
3742
The height of the pairplot
3843
label_fontsize : int, optional, default: 14
@@ -41,7 +46,7 @@ def plot_posterior_2d(
4146
The font size of the legend text
4247
tick_fontsize : int, optional, default: 12
4348
The font size of the axis ticklabels
44-
post_color : str, optional, default: '#8f2727'
49+
post_color : str, optional, default: '#132a70'
4550
The color for the posterior histograms and KDEs
4651
priors_color : str, optional, default: gray
4752
The color for the optional prior histograms and KDEs
@@ -64,7 +69,10 @@ def plot_posterior_2d(
6469
assert (len(post_samples.shape)) == 2, "Shape of `posterior_samples` for a single data set should be 2 dimensional!"
6570

6671
# Plot posterior first
67-
g = plot_samples_2d(post_samples, context="\\theta", param_names=param_names, render=False, height=height, **kwargs)
72+
context = ""
73+
g = plot_samples_2d(
74+
post_samples, context=context, variable_names=variable_names, render=False, height=height, **kwargs
75+
)
6876

6977
# Obtain n_draws and n_params
7078
n_draws, n_params = post_samples.shape
@@ -73,34 +81,54 @@ def plot_posterior_2d(
7381
if prior is not None and prior_samples is None:
7482
draws = prior(n_draws)
7583
if isinstance(draws, dict):
76-
prior_draws = draws["prior_draws"]
84+
prior_samples = draws["prior_draws"]
7785
else:
78-
prior_draws = draws
86+
prior_samples = draws
87+
elif prior_samples is not None:
88+
# trim to the same number of draws as posterior
89+
prior_samples = prior_samples[:n_draws]
7990

8091
# Attempt to determine parameter names
81-
if param_names is None:
92+
if variable_names is None:
8293
if hasattr(prior, "param_names"):
83-
if prior.param_names is not None:
84-
param_names = prior.param_names
94+
if prior.variable_names is not None:
95+
variable_names = prior.variable_names
8596
else:
86-
param_names = [f"$\\theta_{{{i}}}$" for i in range(1, n_params + 1)]
97+
variable_names = [f"{context} $\\theta_{{{i}}}$" for i in range(1, n_params + 1)]
8798
else:
88-
param_names = [f"$\\theta_{{{i}}}$" for i in range(1, n_params + 1)]
99+
variable_names = [f"{context} $\\theta_{{{i}}}$" for i in range(1, n_params + 1)]
100+
else:
101+
variable_names = [f"{context} {p}" for p in variable_names]
89102

90103
# Add prior, if given
91-
if prior_draws is not None:
92-
prior_draws_df = pd.DataFrame(prior_draws, columns=param_names)
93-
g.data = prior_draws_df
104+
if prior_samples is not None:
105+
prior_samples_df = pd.DataFrame(prior_samples, columns=variable_names)
106+
g.data = prior_samples_df
94107
g.map_diag(sns.histplot, fill=True, color=prior_color, alpha=prior_alpha, kde=True, zorder=-1)
95108
g.map_lower(sns.kdeplot, fill=True, color=prior_color, alpha=prior_alpha, zorder=-1)
96109

110+
# Add true parameters
111+
if true_params is not None:
112+
# Custom function to plot true_params on the diagonal
113+
def plot_true_params(x, **kwargs):
114+
param = x.iloc[0] # Get the single true value for the diagonal
115+
plt.axvline(param, color="black", linestyle="--") # Add vertical line
116+
117+
# only plot on the diagonal a vertical line for the true parameter
118+
g.data = pd.DataFrame(true_params[np.newaxis], columns=variable_names)
119+
g.map_diag(plot_true_params)
120+
97121
# Add legend, if prior also given
98-
if prior_draws is not None or prior is not None:
122+
if prior_samples is not None or prior is not None:
99123
handles = [
100124
Line2D(xdata=[], ydata=[], color=post_color, lw=3, alpha=post_alpha),
101125
Line2D(xdata=[], ydata=[], color=prior_color, lw=3, alpha=prior_alpha),
102126
]
103-
g.legend(handles, ["Posterior", "Prior"], fontsize=legend_fontsize, loc="center right")
127+
handles_names = ["Posterior", "Prior"]
128+
if true_params is not None:
129+
handles.append(Line2D(xdata=[], ydata=[], color="black", lw=3, linestyle="--"))
130+
handles_names.append("True Parameter")
131+
plt.legend(handles=handles, labels=handles_names, fontsize=legend_fontsize, loc="center right")
104132

105133
n_row, n_col = g.axes.shape
106134

@@ -115,9 +143,9 @@ def plot_posterior_2d(
115143
g.axes[i, j].tick_params(axis="both", which="minor", labelsize=tick_fontsize)
116144

117145
# Add nice labels
118-
for i, param_name in enumerate(param_names):
146+
for i, param_name in enumerate(variable_names):
119147
g.axes[i, 0].set_ylabel(param_name, fontsize=label_fontsize)
120-
g.axes[len(param_names) - 1, i].set_xlabel(param_name, fontsize=label_fontsize)
148+
g.axes[len(variable_names) - 1, i].set_xlabel(param_name, fontsize=label_fontsize)
121149

122150
# Add grids
123151
for i in range(n_params):

bayesflow/diagnostics/plot_sbc_ecdf.py

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import Sequence
55
from ..utils.plot_utils import preprocess, add_titles_and_labels, prettify_subplots
66
from ..utils.ecdf import simultaneous_ecdf_bands
7+
from ..utils.ecdf.ranks import fractional_ranks, distance_ranks
78

89

910
def plot_sbc_ecdf(
@@ -13,6 +14,7 @@ def plot_sbc_ecdf(
1314
variable_names: Sequence[str] = None,
1415
difference: bool = False,
1516
stacked: bool = False,
17+
rank_type: str | np.ndarray = "fractional",
1618
figsize: Sequence[float] = None,
1719
label_fontsize: int = 16,
1820
legend_fontsize: int = 14,
@@ -33,11 +35,20 @@ def plot_sbc_ecdf(
3335
For models with many parameters, use `stacked=True` to obtain an idea
3436
of the overall calibration of a posterior approximator.
3537
38+
To compute ranks based on the Euclidean distance to the origin or a reference, use `rank_type='distance'` (and
39+
pass a reference array, respectively). This can be used to check the joint calibration of the posterior approximator
40+
and might show potential biases in the posterior approximation which are not detected by the fractional ranks (e.g.,
41+
when the prior equals the posterior). This is motivated by [2].
42+
3643
[1] Säilynoja, T., Bürkner, P. C., & Vehtari, A. (2022). Graphical test
3744
for discrete uniformity and its applications in goodness-of-fit evaluation
3845
and multiple sample comparison. Statistics and Computing, 32(2), 1-21.
3946
https://arxiv.org/abs/2103.10522
4047
48+
[2] Lemos, Pablo, et al. "Sampling-based accuracy testing of posterior estimators
49+
for general inference." International Conference on Machine Learning. PMLR, 2023.
50+
https://proceedings.mlr.press/v202/lemos23a.html
51+
4152
Parameters
4253
----------
4354
post_samples : np.ndarray of shape (n_data_sets, n_post_draws, n_params)
@@ -51,6 +62,11 @@ def plot_sbc_ecdf(
5162
If `True`, all ECDFs will be plotted on the same plot.
5263
If `False`, each ECDF will have its own subplot,
5364
similar to the behavior of `plot_sbc_histograms`.
65+
rank_type : str, optional, default: 'fractional'
66+
If `fractional` (default), the ranks are computed as the fraction of posterior samples that are smaller than
67+
the prior. If `distance`, the ranks are computed as the fraction of posterior samples that are closer to
68+
a reference point than the prior sample is (the default reference is the origin). You can pass a reference array in the same shape as the
69+
`prior_samples` array by setting `references` in the ``ranks_kwargs``. This is motivated by [2].
5470
variable_names : list or None, optional, default: None
5571
The parameter names for nice plot titles.
5672
Inferred if None. Only relevant if `stacked=False`.
@@ -79,7 +95,9 @@ def plot_sbc_ecdf(
7995
**kwargs : dict, optional, default: {}
8096
Keyword arguments can be passed to control the behavior of
8197
ECDF simultaneous band computation through the ``ecdf_bands_kwargs``
82-
dictionary. See `simultaneous_ecdf_bands` for keyword arguments
98+
dictionary. See `simultaneous_ecdf_bands` for keyword arguments.
99+
Moreover, additional keyword arguments can be passed to control the behavior of
100+
the rank computation through the ``ranks_kwargs`` dictionary.
83101
84102
Returns
85103
-------
@@ -90,6 +108,8 @@ def plot_sbc_ecdf(
90108
ShapeError
91109
If there is a deviation from the expected shapes of `post_samples`
92110
and `prior_samples`.
111+
ValueError
112+
If an unknown `rank_type` is passed.
93113
"""
94114

95115
# Preprocessing
@@ -99,8 +119,16 @@ def plot_sbc_ecdf(
99119
plot_data["post_samples"] = plot_data.pop("post_variables")
100120
plot_data["prior_samples"] = plot_data.pop("prior_variables")
101121

102-
# Compute fractional ranks (using broadcasting)
103-
ranks = np.mean(plot_data["post_samples"] < plot_data["prior_samples"][:, np.newaxis, :], axis=1)
122+
if rank_type == "fractional":
123+
# Compute fractional ranks
124+
ranks = fractional_ranks(plot_data["post_samples"], plot_data["prior_samples"])
125+
elif rank_type == "distance":
126+
# Compute ranks based on distance to the origin
127+
ranks = distance_ranks(
128+
plot_data["post_samples"], plot_data["prior_samples"], stacked=stacked, **kwargs.pop("ranks_kwargs", {})
129+
)
130+
else:
131+
raise ValueError(f"Unknown rank type: {rank_type}. Use 'fractional' or 'distance'.")
104132

105133
# Plot individual ecdf of parameters
106134
for j in range(ranks.shape[-1]):
@@ -114,6 +142,8 @@ def plot_sbc_ecdf(
114142

115143
if stacked:
116144
if j == 0:
145+
if not isinstance(plot_data["axes"], np.ndarray):
146+
plot_data["axes"] = np.array([plot_data["axes"]]) # in case of single axis
117147
plot_data["axes"][0].plot(xx, yy, color=rank_ecdf_color, alpha=0.95, label="Rank ECDFs")
118148
else:
119149
plot_data["axes"][0].plot(xx, yy, color=rank_ecdf_color, alpha=0.95)
@@ -132,7 +162,13 @@ def plot_sbc_ecdf(
132162
ylab = "ECDF"
133163

134164
# Add simultaneous bounds
135-
titles = plot_data["variable_names"] if not stacked else ["Stacked ECDFs"]
165+
if not stacked:
166+
titles = plot_data["variable_names"]
167+
elif rank_type in ["distance", "random"]:
168+
titles = ["Joint ECDFs"]
169+
else:
170+
titles = ["Stacked ECDFs"]
171+
136172
for ax, title in zip(plot_data["axes"].flat, titles):
137173
ax.fill_between(z, L, H, color=fill_color, alpha=0.2, label=rf"{int((1-alpha) * 100)}$\%$ Confidence Bands")
138174
ax.legend(fontsize=legend_fontsize)
@@ -145,7 +181,7 @@ def plot_sbc_ecdf(
145181
plot_data["axes"],
146182
plot_data["num_row"],
147183
plot_data["num_col"],
148-
xlabel="Fractional rank statistic",
184+
xlabel=f"{rank_type.capitalize()} rank statistic",
149185
ylabel=ylab,
150186
label_fontsize=label_fontsize,
151187
)

bayesflow/utils/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
split_tensors,
1212
)
1313
from .dispatch import find_distribution, find_network, find_permutation, find_pooling, find_recurrent_net
14-
from .ecdf import simultaneous_ecdf_bands
14+
from .ecdf import simultaneous_ecdf_bands, ranks
1515
from .functional import batched_call
1616
from .git import (
1717
issue_url,

bayesflow/utils/ecdf/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
from .simultaneous_ecdf_bands import simultaneous_ecdf_bands
2+
from .ranks import fractional_ranks, distance_ranks

bayesflow/utils/ecdf/ranks.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
import numpy as np
2+
3+
4+
def fractional_ranks(post_samples: np.ndarray, prior_samples: np.ndarray) -> np.ndarray:
    """Compute the fractional rank of each prior draw among its posterior draws.

    For every data set and every parameter, the rank is the fraction of
    posterior draws that are strictly smaller than the corresponding prior draw.

    Parameters
    ----------
    post_samples : np.ndarray of shape (n_data_sets, n_post_draws, n_params)
        The posterior draws. Array-likes are converted with ``np.asarray``.
    prior_samples : np.ndarray of shape (n_data_sets, n_params)
        The prior draws. Array-likes are converted with ``np.asarray``.

    Returns
    -------
    ranks : np.ndarray of shape (n_data_sets, n_params)
        The fractional ranks, each in the interval [0, 1].
    """
    post_samples = np.asarray(post_samples)
    prior_samples = np.asarray(prior_samples)

    # Insert a draw axis into the prior samples so that each prior value is
    # broadcast against all posterior draws of the same data set and parameter.
    is_below_prior = post_samples < prior_samples[:, np.newaxis, :]

    # Averaging the boolean mask over the draw axis yields the fraction.
    return np.mean(is_below_prior, axis=1)
7+
8+
9+
def _helper_distance_ranks(
    post_samples: np.ndarray,
    prior_samples: np.ndarray,
    stacked: bool,
    references: np.ndarray,
    distance: callable,
    p_norm: int,
) -> np.ndarray:
    """Compute distance-based ranks of the prior draws with respect to the posterior draws.

    The rank is the fraction of posterior draws whose distance to the reference
    is strictly smaller than the distance of the prior draw to the same reference.

    Parameters
    ----------
    post_samples : np.ndarray of shape (n_data_sets, n_post_draws, n_params)
        The posterior draws.
    prior_samples : np.ndarray of shape (n_data_sets, n_params)
        The prior draws.
    stacked : bool
        If True, one joint rank per data set is computed over all parameters.
        Otherwise, one marginal rank per parameter is computed.
    references : np.ndarray of shape (n_data_sets, n_params)
        The reference points the distances are measured to.
    distance : callable or None
        Custom distance function called as ``distance(samples, reference)``.
        If None, the distance induced by ``p_norm`` is used.
    p_norm : int
        The order of the norm used when ``distance`` is None and ``stacked`` is True.

    Returns
    -------
    ranks : np.ndarray
        Shape (n_data_sets, 1) if ``stacked`` is True, else (n_data_sets, n_params).
    """
    if distance is None:
        # Element-wise absolute deviations from the references
        dist_post = np.abs(references[:, np.newaxis, :] - post_samples)
        dist_prior = np.abs(references - prior_samples)

        if stacked:
            # Collapse the parameter axis with the p-norm to get one joint distance per draw
            samples_distances = np.sum(dist_post**p_norm, axis=-1) ** (1 / p_norm)
            theta_distances = np.sum(dist_prior**p_norm, axis=-1) ** (1 / p_norm)
            closer = samples_distances < theta_distances[:, np.newaxis]
            return np.mean(closer, axis=1)[:, np.newaxis]

        # Marginal ranks: compare per parameter
        return np.mean(dist_post < dist_prior[:, np.newaxis], axis=1)

    if stacked:
        # The custom distance receives the full parameter vectors of one data set at a time
        dist_post = np.array([distance(post, ref) for post, ref in zip(post_samples, references)])
        dist_prior = np.array([distance(prior, ref) for prior, ref in zip(prior_samples, references)])
        return np.mean(dist_post < dist_prior[:, np.newaxis], axis=1)[:, np.newaxis]

    # The custom distance receives a single parameter at a time.
    # The buffers are explicitly floating point: inheriting the dtype of the
    # samples would truncate the distances whenever the samples are integer-typed.
    dist_post = np.zeros(post_samples.shape, dtype=float)
    dist_prior = np.zeros(prior_samples.shape, dtype=float)
    n_data_sets, n_params = references.shape[0], references.shape[1]
    for i in range(n_data_sets):
        for j in range(n_params):
            dist_post[i, :, j] = distance(post_samples[i, :, j], references[i, j])
            dist_prior[i, j] = distance(prior_samples[i, j], references[i, j])

    return np.mean(dist_post < dist_prior[:, np.newaxis], axis=1)


def distance_ranks(
    post_samples: np.ndarray,
    prior_samples: np.ndarray,
    stacked: bool,
    references: np.ndarray = None,
    distance: callable = None,
    p_norm: int = 2,
) -> np.ndarray:
    """
    Compute ranks of true parameter wrt posterior samples based on distances between samples and optional references.

    Parameters
    ----------
    post_samples  : np.ndarray of shape (n_data_sets, n_post_draws, n_params)
        The posterior samples.
    prior_samples : np.ndarray of shape (n_data_sets, n_params)
        The prior samples.
    stacked       : bool
        If True, compute ranks for all parameters jointly. Otherwise, compute marginal ranks.
    references    : np.ndarray of shape (n_data_sets, n_params) or (n_params,), optional
        The references to compute the ranks. If None, the origin is used. A one-dimensional
        array is interpreted as a single reference point shared by all data sets.
    distance      : callable, optional
        The distance function to compute the ranks. If None, the distance defined by the p_norm is used. Must be
        a function that takes two arrays (if stacked, it gets the full parameter vectors, if not only the single
        parameters) and returns an array with the distances. This could be based on the log-posterior, for example.
    p_norm        : int, optional
        The norm to compute the distance if no distance is passed. Default is L2-norm.

    Returns
    -------
    ranks : np.ndarray
        Shape (n_data_sets, 1) if ``stacked`` is True, else (n_data_sets, n_params).

    Raises
    ------
    ValueError
        If the shape of ``references`` does not match the shape of ``prior_samples``.
    """
    post_samples = np.asarray(post_samples)
    prior_samples = np.asarray(prior_samples)
    n_data_sets, n_params = prior_samples.shape[0], prior_samples.shape[1]

    if references is None:
        # Reference is the origin
        references = np.zeros((n_data_sets, n_params))
    else:
        references = np.asarray(references)
        if references.ndim == 1:
            # A single reference point: reuse it for every data set
            if references.shape[0] != n_params:
                raise ValueError("The dimension of references must match the dimension of the parameters.")
            references = np.broadcast_to(references, (n_data_sets, n_params))
        else:
            # Validate reference
            if references.shape[0] != n_data_sets:
                raise ValueError("The number of references must match the number of prior samples.")
            if references.shape[1] != n_params:
                raise ValueError("The dimension of references must match the dimension of the parameters.")

    return _helper_distance_ranks(
        post_samples=post_samples,
        prior_samples=prior_samples,
        stacked=stacked,
        references=references,
        distance=distance,
        p_norm=p_norm,
    )

0 commit comments

Comments
 (0)