Skip to content

Commit 8464741

Browse files
Add PSIS k-hat diagnostic for variational inference (#2139)
Implement Pareto Smoothed Importance Sampling (PSIS) diagnostic to evaluate variational approximation quality, as requested in #1804. The k-hat statistic is the shape parameter of a Generalized Pareto Distribution fitted to the upper tail of importance weights. It indicates whether the guide is a reliable approximation: k < 0.5 is good (finite variance); 0.5 <= k < 0.7 is marginal (finite mean); k >= 0.7 is unreliable. GPD fitting uses Zhang & Stephens (2009) with prior regularization from Vehtari et al. (2024), matching Pyro's implementation and Vehtari's reference code to ~1e-15.
1 parent d5598e7 commit 8464741

File tree

4 files changed

+612
-0
lines changed

4 files changed

+612
-0
lines changed

docs/source/utilities.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,10 @@ log_likelihood
6868
--------------
6969
.. autofunction:: numpyro.infer.util.log_likelihood
7070

71+
psis_diagnostic
72+
---------------
73+
.. autofunction:: numpyro.infer.importance.psis_diagnostic
74+
7175
find_valid_initial_params
7276
-------------------------
7377
.. autofunction:: numpyro.infer.util.find_valid_initial_params

numpyro/infer/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from numpyro.infer.ensemble import AIES, ESS
1515
from numpyro.infer.hmc import HMC, NUTS
1616
from numpyro.infer.hmc_gibbs import HMCECS, DiscreteHMCGibbs, HMCGibbs
17+
from numpyro.infer.importance import psis_diagnostic
1718
from numpyro.infer.initialization import (
1819
init_to_feasible,
1920
init_to_mean,
@@ -40,6 +41,7 @@
4041
"init_to_uniform",
4142
"init_to_value",
4243
"log_likelihood",
44+
"psis_diagnostic",
4345
"reparam",
4446
"BarkerMH",
4547
"DiscreteHMCGibbs",

numpyro/infer/importance.py

Lines changed: 259 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,259 @@
1+
# Copyright Contributors to the Pyro project.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
"""
5+
Pareto Smoothed Importance Sampling (PSIS) diagnostics for variational inference.
6+
7+
Implements the k-hat diagnostic from:
8+
Yao, Y., Vehtari, A., Simpson, D., and Gelman, A. (2018).
9+
Yes, but Did It Work?: Evaluating Variational Inference.
10+
International Conference on Machine Learning.
11+
12+
Vehtari, A., Simpson, D., Gelman, A., Yao, Y., and Gabry, J. (2024).
13+
Pareto smoothed importance sampling.
14+
Journal of Machine Learning Research, 25(72):1-58.
15+
"""
16+
17+
from __future__ import annotations
18+
19+
from collections.abc import Callable
20+
import math
21+
import warnings
22+
23+
import numpy as np
24+
25+
import jax
26+
from jax import device_get, random
27+
28+
from numpyro.handlers import seed
29+
from numpyro.infer.elbo import get_importance_log_probs
30+
31+
__all__ = ["psis_diagnostic"]
32+
33+
34+
def _fit_generalized_pareto(x: np.ndarray) -> tuple[float, float]:
    """Fit a two-parameter Generalized Pareto Distribution (GPD) to ``x``.

    Produces empirical Bayes estimates of the GPD shape (k) and scale (sigma)
    via the method of Zhang and Stephens (2009), regularized with the prior
    of Vehtari et al. (2024).

    References:
        Zhang, J. and Stephens, M.A. (2009). A new and efficient estimation
        method for the generalized Pareto distribution. Technometrics,
        51(3):316-325.

        Vehtari, A., Simpson, D., Gelman, A., Yao, Y., and Gabry, J. (2024).
        Pareto smoothed importance sampling. Journal of Machine Learning
        Research, 25(72):1-58.

    :param numpy.ndarray x: one-dimensional array of positive exceedances (tail samples).
    :return: tuple of (k, sigma) where k is the shape parameter and sigma
        is the scale parameter.
    """
    if not (x.ndim == 1 and x.size > 1):
        raise ValueError(
            f"Expected 1-D array with at least 2 elements, got shape {x.shape}."
        )

    # Degenerate inputs (all-zero or identical tail values) trigger benign
    # floating-point warnings at several spots in the algorithm:
    #   divide:  1/x[quartile], 1/x[-1], -k/b when tail values are zero
    #   over:    exp(L - L') when profile log-likelihood gaps are large
    #   invalid: downstream ops on the nan/inf produced above
    # The resulting nan/inf propagate through exactly as in the reference
    # implementation, so all three warning classes are silenced wholesale.
    with np.errstate(divide="ignore", over="ignore", invalid="ignore"):
        return _fit_generalized_pareto_impl(x)
69+
70+
def _fit_generalized_pareto_impl(x: np.ndarray) -> tuple[float, float]:
71+
x = np.sort(x)
72+
n = len(x)
73+
PRIOR = 3
74+
m = 30 + int(np.sqrt(n))
75+
76+
# Candidate shape parameters (Zhang & Stephens grid)
77+
bs = np.arange(1, m + 1, dtype=float)
78+
bs -= 0.5
79+
np.divide(m, bs, out=bs)
80+
np.sqrt(bs, out=bs)
81+
np.subtract(1, bs, out=bs)
82+
bs /= PRIOR * x[int(n / 4 + 0.5) - 1]
83+
bs += 1 / x[-1]
84+
85+
# Profile log-likelihood for each candidate
86+
ks = np.negative(bs)
87+
temp = ks[:, None] * x
88+
np.log1p(temp, out=temp)
89+
np.mean(temp, axis=1, out=ks)
90+
91+
L = bs / ks
92+
np.negative(L, out=L)
93+
np.log(L, out=L)
94+
L -= ks
95+
L -= 1
96+
L *= n
97+
98+
# Posterior weights (overflow in exp is expected and harmless;
99+
# overflowed values get negligible weight after normalization)
100+
temp = L - L[:, None]
101+
np.exp(temp, out=temp)
102+
w = np.sum(temp, axis=1)
103+
np.divide(1, w, out=w)
104+
105+
# Remove negligible weights
106+
dii = w >= 10 * np.finfo(float).eps
107+
if not np.all(dii):
108+
w = w[dii]
109+
bs = bs[dii]
110+
w /= w.sum()
111+
112+
# Posterior mean for b
113+
b = np.sum(bs * w)
114+
115+
# Estimate for k (note: negated relative to Zhang & Stephens)
116+
temp = (-b) * x
117+
np.log1p(temp, out=temp)
118+
k = np.mean(temp)
119+
120+
# Estimate for sigma
121+
sigma = -k / b
122+
123+
# Weakly informative prior for k (Vehtari et al. 2024, Appendix G)
124+
# Prior: mean=0.5, effective sample size a=10
125+
a = 10
126+
k = k * n / (n + a) + a * 0.5 / (n + a)
127+
128+
return float(k), float(sigma)
129+
130+
131+
def _compute_log_weights(
    rng_key: jax.Array,
    param_map: dict[str, jax.Array],
    model: Callable,
    guide: Callable,
    args: tuple,
    kwargs: dict,
) -> jax.Array:
    """Return the log importance weight log p(x,z) - log q(z) for one particle.

    The guide gets its own seed to draw the latent sample; the model receives
    an independent seed so that any extra randomness it may have beyond the
    latent sites replayed from the guide (e.g. stochastic control flow) stays
    decoupled from the guide's draw.
    """
    model_key, guide_key = random.split(rng_key)
    model_log_probs, guide_log_probs = get_importance_log_probs(
        seed(model, model_key),
        seed(guide, guide_key),
        args,
        kwargs,
        param_map,
    )
    log_joint = sum(site.sum() for site in model_log_probs.values())
    log_q = sum(site.sum() for site in guide_log_probs.values())
    return log_joint - log_q
153+
154+
def _psis_khat(log_weights: np.ndarray) -> float:
    """Estimate the PSIS k-hat statistic from raw log importance weights.

    Stabilizes the weights by subtracting their maximum, extracts the upper
    tail following Vehtari et al. (2024, Algorithm 1), and returns the GPD
    shape parameter fitted to the tail exceedances.  Emits a warning and
    returns ``inf`` when fewer than 5 tail samples are available.
    """
    # Subtracting the max guards exp() against overflow and leaves the
    # ordering (and hence the fitted shape) unchanged.
    lw = np.sort(log_weights - log_weights.max())

    # S matches notation in Vehtari et al. (2024), Algorithm 1.
    S = len(lw)

    # Tail size M = ceil(min(0.2 S, 3 sqrt(S))); the cutoff is the weight
    # just below the tail, floored at the smallest positive normal float.
    M = math.ceil(min(0.2 * S, 3 * math.sqrt(S)))
    lw_cutoff = max(np.log(np.finfo(float).tiny), lw[-(M + 1)])

    lw_tail = lw[lw > lw_cutoff]

    if len(lw_tail) < 5:
        warnings.warn(
            "Not enough tail samples for reliable PSIS diagnostic.",
            stacklevel=3,
        )
        return float("inf")

    # Convert tail weights to exceedances over the cutoff on the weight scale.
    exceedances = np.exp(lw_tail) - np.exp(lw_cutoff)

    # Only the GPD shape parameter is needed for the diagnostic.
    k, _ = _fit_generalized_pareto(exceedances)
    return float(k)
185+
186+
def psis_diagnostic(
    rng_key: jax.Array,
    param_map: dict[str, jax.Array],
    model: Callable,
    guide: Callable,
    *args,
    num_particles: int = 1000,
    chunk_size: int | None = None,
    **kwargs,
) -> float:
    r"""Compute the PSIS k-hat diagnostic for a model/guide pair.

    The k-hat statistic measures the reliability of importance sampling
    estimates. It is the shape parameter of a Generalized Pareto Distribution
    (GPD) fitted to the upper tail of the importance weights.

    Interpretation (Vehtari et al. 2024):

    - k < 0.5: finite variance, classical CLT applies
    - 0.5 <= k < 0.7: finite mean, generalized CLT may apply
    - k >= 0.7: unreliable importance sampling estimates

    **Example usage**::

        >>> from jax import random
        >>> from numpyro.infer import SVI, Trace_ELBO, psis_diagnostic
        >>> svi = SVI(model, guide, optimizer, Trace_ELBO())
        >>> svi_result = svi.run(random.PRNGKey(0), num_steps, *args)
        >>> khat = psis_diagnostic(
        ...     random.PRNGKey(1), svi_result.params, model, guide, *args
        ... )

    .. note::

        For reliable results, use at least several hundred particles
        (the default of 1000 is usually sufficient). Very few particles
        may not provide enough tail samples for GPD fitting.

    :param jax.random.PRNGKey rng_key: random number generator key.
    :param dict param_map: dictionary of current parameter values
        (e.g. ``svi_result.params``).
    :param Callable model: NumPyro model.
    :param Callable guide: NumPyro guide.
    :param args: positional arguments to model and guide.
    :param int num_particles: number of importance weight samples to draw.
    :param int chunk_size: maximum particles to evaluate at once (for memory
        control). If None, all particles are evaluated together.
    :param kwargs: keyword arguments to model and guide.
    :return: the estimated k-hat statistic.
    :rtype: float
    :raises ValueError: if ``num_particles < 2`` or ``chunk_size`` is not a
        positive integer.
    """
    if num_particles < 2:
        raise ValueError("num_particles must be at least 2.")

    if chunk_size is None:
        chunk_size = num_particles
    elif chunk_size < 1:
        # Guard against range(0, n, 0) raising an opaque ValueError (or a
        # negative step silently producing an empty loop and a confusing
        # np.concatenate failure below).
        raise ValueError("chunk_size must be a positive integer or None.")

    rng_keys = random.split(rng_key, num_particles)

    # One log weight per particle; vmapped per chunk to bound peak memory.
    def compute_fn(key):
        return _compute_log_weights(key, param_map, model, guide, args, kwargs)

    log_weights_list = []
    for batch_start in range(0, num_particles, chunk_size):
        batch_keys = rng_keys[batch_start : batch_start + chunk_size]
        batch_lw = jax.vmap(compute_fn)(batch_keys)
        log_weights_list.append(batch_lw)

    # Move results to host memory before the NumPy-based GPD fit.
    log_weights = np.concatenate(
        [np.asarray(device_get(lw)) for lw in log_weights_list]
    )

    return _psis_khat(log_weights)

0 commit comments

Comments
 (0)