|
| 1 | +# Copyright Contributors to the Pyro project. |
| 2 | +# SPDX-License-Identifier: Apache-2.0 |
| 3 | + |
| 4 | +""" |
| 5 | +Pareto Smoothed Importance Sampling (PSIS) diagnostics for variational inference. |
| 6 | +
|
| 7 | +Implements the k-hat diagnostic from: |
| 8 | + Yao, Y., Vehtari, A., Simpson, D., and Gelman, A. (2018). |
| 9 | + Yes, but Did It Work?: Evaluating Variational Inference. |
| 10 | + International Conference on Machine Learning. |
| 11 | +
|
| 12 | + Vehtari, A., Simpson, D., Gelman, A., Yao, Y., and Gabry, J. (2024). |
| 13 | + Pareto smoothed importance sampling. |
| 14 | + Journal of Machine Learning Research, 25(72):1-58. |
| 15 | +""" |
| 16 | + |
| 17 | +from __future__ import annotations |
| 18 | + |
| 19 | +from collections.abc import Callable |
| 20 | +import math |
| 21 | +import warnings |
| 22 | + |
| 23 | +import numpy as np |
| 24 | + |
| 25 | +import jax |
| 26 | +from jax import device_get, random |
| 27 | + |
| 28 | +from numpyro.handlers import seed |
| 29 | +from numpyro.infer.elbo import get_importance_log_probs |
| 30 | + |
| 31 | +__all__ = ["psis_diagnostic"] |
| 32 | + |
| 33 | + |
| 34 | +def _fit_generalized_pareto(x: np.ndarray) -> tuple[float, float]: |
| 35 | + """Estimate parameters of the Generalized Pareto Distribution (GPD). |
| 36 | +
|
| 37 | + Returns empirical Bayes estimates for the shape (k) and scale (sigma) |
| 38 | + parameters of the two-parameter GPD, using the method of Zhang and |
| 39 | + Stephens (2009) with the prior regularization from Vehtari et al. (2024). |
| 40 | +
|
| 41 | + References: |
| 42 | + Zhang, J. and Stephens, M.A. (2009). A new and efficient estimation |
| 43 | + method for the generalized Pareto distribution. Technometrics, |
| 44 | + 51(3):316-325. |
| 45 | +
|
| 46 | + Vehtari, A., Simpson, D., Gelman, A., Yao, Y., and Gabry, J. (2024). |
| 47 | + Pareto smoothed importance sampling. Journal of Machine Learning |
| 48 | + Research, 25(72):1-58. |
| 49 | +
|
| 50 | + :param numpy.ndarray x: one-dimensional array of positive exceedances (tail samples). |
| 51 | + :return: tuple of (k, sigma) where k is the shape parameter and sigma |
| 52 | + is the scale parameter. |
| 53 | + """ |
| 54 | + if x.ndim != 1 or len(x) <= 1: |
| 55 | + raise ValueError( |
| 56 | + f"Expected 1-D array with at least 2 elements, got shape {x.shape}." |
| 57 | + ) |
| 58 | + |
| 59 | + # Broad errstate is needed because degenerate inputs (e.g. zeros or |
| 60 | + # identical values) cause cascading numerical issues at multiple points: |
| 61 | + # divide: 1/x[quartile], 1/x[-1], -k/b when tail values are zero |
| 62 | + # over: exp(L - L') when profile log-likelihood differences are large |
| 63 | + # invalid: downstream ops on nan/inf from earlier divide-by-zero |
| 64 | + # The resulting nan/inf propagate correctly through the algorithm, |
| 65 | + # matching the reference implementation behavior. |
| 66 | + with np.errstate(divide="ignore", over="ignore", invalid="ignore"): |
| 67 | + return _fit_generalized_pareto_impl(x) |
| 68 | + |
| 69 | + |
| 70 | +def _fit_generalized_pareto_impl(x: np.ndarray) -> tuple[float, float]: |
| 71 | + x = np.sort(x) |
| 72 | + n = len(x) |
| 73 | + PRIOR = 3 |
| 74 | + m = 30 + int(np.sqrt(n)) |
| 75 | + |
| 76 | + # Candidate shape parameters (Zhang & Stephens grid) |
| 77 | + bs = np.arange(1, m + 1, dtype=float) |
| 78 | + bs -= 0.5 |
| 79 | + np.divide(m, bs, out=bs) |
| 80 | + np.sqrt(bs, out=bs) |
| 81 | + np.subtract(1, bs, out=bs) |
| 82 | + bs /= PRIOR * x[int(n / 4 + 0.5) - 1] |
| 83 | + bs += 1 / x[-1] |
| 84 | + |
| 85 | + # Profile log-likelihood for each candidate |
| 86 | + ks = np.negative(bs) |
| 87 | + temp = ks[:, None] * x |
| 88 | + np.log1p(temp, out=temp) |
| 89 | + np.mean(temp, axis=1, out=ks) |
| 90 | + |
| 91 | + L = bs / ks |
| 92 | + np.negative(L, out=L) |
| 93 | + np.log(L, out=L) |
| 94 | + L -= ks |
| 95 | + L -= 1 |
| 96 | + L *= n |
| 97 | + |
| 98 | + # Posterior weights (overflow in exp is expected and harmless; |
| 99 | + # overflowed values get negligible weight after normalization) |
| 100 | + temp = L - L[:, None] |
| 101 | + np.exp(temp, out=temp) |
| 102 | + w = np.sum(temp, axis=1) |
| 103 | + np.divide(1, w, out=w) |
| 104 | + |
| 105 | + # Remove negligible weights |
| 106 | + dii = w >= 10 * np.finfo(float).eps |
| 107 | + if not np.all(dii): |
| 108 | + w = w[dii] |
| 109 | + bs = bs[dii] |
| 110 | + w /= w.sum() |
| 111 | + |
| 112 | + # Posterior mean for b |
| 113 | + b = np.sum(bs * w) |
| 114 | + |
| 115 | + # Estimate for k (note: negated relative to Zhang & Stephens) |
| 116 | + temp = (-b) * x |
| 117 | + np.log1p(temp, out=temp) |
| 118 | + k = np.mean(temp) |
| 119 | + |
| 120 | + # Estimate for sigma |
| 121 | + sigma = -k / b |
| 122 | + |
| 123 | + # Weakly informative prior for k (Vehtari et al. 2024, Appendix G) |
| 124 | + # Prior: mean=0.5, effective sample size a=10 |
| 125 | + a = 10 |
| 126 | + k = k * n / (n + a) + a * 0.5 / (n + a) |
| 127 | + |
| 128 | + return float(k), float(sigma) |
| 129 | + |
| 130 | + |
def _compute_log_weights(
    rng_key: jax.Array,
    param_map: dict[str, jax.Array],
    model: Callable,
    guide: Callable,
    args: tuple,
    kwargs: dict,
) -> jax.Array:
    """Return the log importance weight log p(x,z) - log q(z) for one particle."""
    # The guide consumes randomness to sample its latent sites; the model
    # receives an independent key in case it carries stochasticity beyond
    # the sites replayed from the guide (e.g. stochastic control flow).
    model_key, guide_key = random.split(rng_key)
    model_lps, guide_lps = get_importance_log_probs(
        seed(model, model_key), seed(guide, guide_key), args, kwargs, param_map
    )
    log_joint = sum(site.sum() for site in model_lps.values())
    log_q = sum(site.sum() for site in guide_lps.values())
    return log_joint - log_q
| 152 | + |
| 153 | + |
| 154 | +def _psis_khat(log_weights: np.ndarray) -> float: |
| 155 | + """Compute PSIS k-hat from an array of raw log importance weights.""" |
| 156 | + log_weights = log_weights.copy() |
| 157 | + log_weights -= log_weights.max() |
| 158 | + log_weights = np.sort(log_weights) |
| 159 | + |
| 160 | + # S matches notation in Vehtari et al. (2024), Algorithm 1 |
| 161 | + S = len(log_weights) |
| 162 | + |
| 163 | + # Tail extraction (Vehtari et al. 2024, Algorithm 1) |
| 164 | + M = math.ceil(min(0.2 * S, 3 * math.sqrt(S))) |
| 165 | + cutoff_ind = -(M + 1) |
| 166 | + lw_cutoff = max(np.log(np.finfo(float).tiny), log_weights[cutoff_ind]) |
| 167 | + |
| 168 | + lw_tail = log_weights[log_weights > lw_cutoff] |
| 169 | + |
| 170 | + if len(lw_tail) < 5: |
| 171 | + warnings.warn( |
| 172 | + "Not enough tail samples for reliable PSIS diagnostic.", |
| 173 | + stacklevel=3, |
| 174 | + ) |
| 175 | + return float("inf") |
| 176 | + |
| 177 | + # Shift to exceedances |
| 178 | + tail = np.exp(lw_tail) - np.exp(lw_cutoff) |
| 179 | + |
| 180 | + # Fit GPD to the tail |
| 181 | + k, sigma = _fit_generalized_pareto(tail) |
| 182 | + |
| 183 | + return float(k) |
| 184 | + |
| 185 | + |
| 186 | +def psis_diagnostic( |
| 187 | + rng_key: jax.Array, |
| 188 | + param_map: dict[str, jax.Array], |
| 189 | + model: Callable, |
| 190 | + guide: Callable, |
| 191 | + *args, |
| 192 | + num_particles: int = 1000, |
| 193 | + chunk_size: int | None = None, |
| 194 | + **kwargs, |
| 195 | +) -> float: |
| 196 | + r"""Compute the PSIS k-hat diagnostic for a model/guide pair. |
| 197 | +
|
| 198 | + The k-hat statistic measures the reliability of importance sampling |
| 199 | + estimates. It is the shape parameter of a Generalized Pareto Distribution |
| 200 | + (GPD) fitted to the upper tail of the importance weights. |
| 201 | +
|
| 202 | + Interpretation (Vehtari et al. 2024): |
| 203 | +
|
| 204 | + - k < 0.5: finite variance, classical CLT applies |
| 205 | + - 0.5 <= k < 0.7: finite mean, generalized CLT may apply |
| 206 | + - k >= 0.7: unreliable importance sampling estimates |
| 207 | +
|
| 208 | + **Example usage**:: |
| 209 | +
|
| 210 | + >>> from jax import random |
| 211 | + >>> from numpyro.infer import SVI, Trace_ELBO, psis_diagnostic |
| 212 | + >>> svi = SVI(model, guide, optimizer, Trace_ELBO()) |
| 213 | + >>> svi_result = svi.run(random.PRNGKey(0), num_steps, *args) |
| 214 | + >>> khat = psis_diagnostic( |
| 215 | + ... random.PRNGKey(1), svi_result.params, model, guide, *args |
| 216 | + ... ) |
| 217 | +
|
| 218 | + .. note:: |
| 219 | +
|
| 220 | + For reliable results, use at least several hundred particles |
| 221 | + (the default of 1000 is usually sufficient). Very few particles |
| 222 | + may not provide enough tail samples for GPD fitting. |
| 223 | +
|
| 224 | + :param jax.random.PRNGKey rng_key: random number generator key. |
| 225 | + :param dict param_map: dictionary of current parameter values |
| 226 | + (e.g. ``svi_result.params``). |
| 227 | + :param Callable model: NumPyro model. |
| 228 | + :param Callable guide: NumPyro guide. |
| 229 | + :param args: positional arguments to model and guide. |
| 230 | + :param int num_particles: number of importance weight samples to draw. |
| 231 | + :param int chunk_size: maximum particles to evaluate at once (for memory |
| 232 | + control). If None, all particles are evaluated together. |
| 233 | + :param kwargs: keyword arguments to model and guide. |
| 234 | + :return: the estimated k-hat statistic. |
| 235 | + :rtype: float |
| 236 | + """ |
| 237 | + if num_particles < 2: |
| 238 | + raise ValueError("num_particles must be at least 2.") |
| 239 | + |
| 240 | + if chunk_size is None: |
| 241 | + chunk_size = num_particles |
| 242 | + |
| 243 | + rng_keys = random.split(rng_key, num_particles) |
| 244 | + |
| 245 | + # Compute log weights in batches |
| 246 | + def compute_fn(key): |
| 247 | + return _compute_log_weights(key, param_map, model, guide, args, kwargs) |
| 248 | + |
| 249 | + log_weights_list = [] |
| 250 | + for batch_start in range(0, num_particles, chunk_size): |
| 251 | + batch_keys = rng_keys[batch_start : batch_start + chunk_size] |
| 252 | + batch_lw = jax.vmap(compute_fn)(batch_keys) |
| 253 | + log_weights_list.append(batch_lw) |
| 254 | + |
| 255 | + log_weights = np.concatenate( |
| 256 | + [np.asarray(device_get(lw)) for lw in log_weights_list] |
| 257 | + ) |
| 258 | + |
| 259 | + return _psis_khat(log_weights) |
0 commit comments