diff --git a/pymc_extras/inference/laplace.py b/pymc_extras/inference/laplace.py
index d86f1fee5..d64d2adab 100644
--- a/pymc_extras/inference/laplace.py
+++ b/pymc_extras/inference/laplace.py
@@ -15,6 +15,7 @@
 import logging
 
+from collections.abc import Callable
 from functools import reduce
 from importlib.util import find_spec
 from itertools import product
@@ -29,6 +30,7 @@
 from arviz import dict_to_dataset
 from better_optimize.constants import minimize_method
+from numpy.typing import ArrayLike
 from pymc import DictToArrayBijection
 from pymc.backends.arviz import (
     coords_and_dims_for_inferencedata,
@@ -39,6 +41,8 @@
 from pymc.model.transform.conditioning import remove_value_transforms
 from pymc.model.transform.optimization import freeze_dims_and_data
 from pymc.util import get_default_varnames
+from pytensor.tensor import TensorVariable
+from pytensor.tensor.optimize import minimize
 from scipy import stats
 
 from pymc_extras.inference.find_map import (
@@ -52,6 +56,102 @@
 _log = logging.getLogger(__name__)
+
+
+def get_conditional_gaussian_approximation(
+    x: TensorVariable,
+    Q: TensorVariable | ArrayLike,
+    mu: TensorVariable | ArrayLike,
+    args: list[TensorVariable] | None = None,
+    model: pm.Model | None = None,
+    method: minimize_method = "BFGS",
+    use_jac: bool = True,
+    use_hess: bool = False,
+    optimizer_kwargs: dict | None = None,
+) -> Callable:
+    """
+    Returns a function that estimates the a posteriori log probability of a latent Gaussian field x,
+    together with its mode x0, using the Laplace approximation.
+
+    That is:
+
+    y | x, sigma ~ N(Ax, sigma^2 W)
+    x | params ~ N(mu, Q(params)^-1)
+
+    We seek to estimate log(p(x | y, params)):
+
+    log(p(x | y, params)) = log(p(y | x, params)) + log(p(x | params)) + const
+
+    Let f(x) = log(p(y | x, params)). From the model definition above,
+    log(p(x | params)) = -0.5*(x - mu).T Q (x - mu) + 0.5*logdet(Q), up to an additive constant.
+
+    This gives log(p(x | y, params)) = f(x) - 0.5*(x - mu).T Q (x - mu) + 0.5*logdet(Q). We estimate
+    this with the Laplace approximation by Taylor expanding f(x) about the mode. Thus:
+
+    1. Maximize log(p(x | y, params)) = f(x) - 0.5*(x - mu).T Q (x - mu) with respect to x
+       (logdet(Q) does not depend on x) to find the mode x0.
+
+    2. Substitute x0 into the second-order expansion about the mode:
+       log(p(x | y, params)) ~= -0.5*x.T (-f''(x0) + Q) x + x.T (Q mu + f'(x0) - f''(x0) x0) + 0.5*logdet(Q),
+       again up to additive constants that do not depend on x.
+
+    Parameters
+    ----------
+    x: TensorVariable
+        The variable to maximize with respect to (that is, the variable in which to find the mode).
+        In INLA, this is the latent field x ~ N(mu, Q^-1).
+    Q: TensorVariable | ArrayLike
+        The precision matrix of the latent field x.
+    mu: TensorVariable | ArrayLike
+        The mean of the latent field x.
+    args: list[TensorVariable] | None
+        Args to supply to the compiled function, i.e. (x0, logp) = f(x, *args). If None, the model's
+        value variables are used as args.
+    model: Model
+        PyMC model to use.
+    method: minimize_method
+        Which minimization algorithm to use.
+    use_jac: bool
+        If True, the minimizer will use the gradient of log(p(x | y, params)).
+    use_hess: bool
+        If True, the minimizer will use the Hessian of log(p(x | y, params)).
+    optimizer_kwargs: dict
+        Kwargs to pass to scipy.optimize.minimize.
+
+    Returns
+    -------
+    f: Callable
+        A function which accepts a value of x and args and returns [x0, log(p(x | y, params))],
+        where x0 is the mode. x is currently used both as the point at which to evaluate logp and
+        as the initial guess for the minimizer.
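+
+    Examples
+    --------
+    An illustrative sketch of typical usage, adapted from the accompanying test in
+    tests/test_laplace.py. The arrays ``mu_mu``, ``mu_cov``, ``Q_val``, ``cov_val``,
+    ``y_obs``, ``x_init`` and ``mu_val`` are placeholders to be supplied by the caller:
+
+    .. code-block:: python
+
+        with pm.Model() as model:
+            mu_param = pm.MvNormal("mu_param", mu=mu_mu, cov=mu_cov)
+            x = pm.MvNormal("x", mu=mu_param, cov=np.linalg.inv(Q_val))
+            y = pm.MvNormal("y", mu=x, cov=cov_val, observed=y_obs)
+
+            # Compile a function returning the mode x0 and the Laplace-approximated
+            # log(p(x | y, params)); args defaults to the model's value variables.
+            cga = get_conditional_gaussian_approximation(
+                x=model.rvs_to_values[x],
+                Q=Q_val,
+                mu=mu_param,
+            )
+
+        x0, log_x_posterior = cga(x=x_init, mu_param=mu_val)
+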
+ """ + model = pm.modelcontext(model) + + if args is None: + args = model.continuous_value_vars + model.discrete_value_vars + + # f = log(p(y | x, params)) + f_x = model.logp() + jac = pytensor.gradient.grad(f_x, x) + hess = pytensor.gradient.jacobian(jac.flatten(), x) + + # log(p(x | y, params)) only including terms that depend on x for the minimization step (logdet(Q) ignored as it is a constant wrt x) + log_x_posterior = f_x - 0.5 * (x - mu).T @ Q @ (x - mu) + + # Maximize log(p(x | y, params)) wrt x to find mode x0 + x0, _ = minimize( + objective=-log_x_posterior, + x=x, + method=method, + jac=use_jac, + hess=use_hess, + optimizer_kwargs=optimizer_kwargs, + ) + + # require f'(x0) and f''(x0) for Laplace approx + jac = pytensor.graph.replace.graph_replace(jac, {x: x0}) + hess = pytensor.graph.replace.graph_replace(hess, {x: x0}) + + # Full log(p(x | y, params)) using the Laplace approximation (up to a constant) + _, logdetQ = pt.nlinalg.slogdet(Q) + conditional_gaussian_approx = ( + -0.5 * x.T @ (-hess + Q) @ x + x.T @ (Q @ mu + jac - hess @ x0) + 0.5 * logdetQ + ) + + # Currently x is passed both as the query point for f(x, args) = logp(x | y, params) AND as an initial guess for x0. This may cause issues if the query point is + # far from the mode x0 or in a neighbourhood which results in poor convergence. + return pytensor.function(args, [x0, conditional_gaussian_approx]) + + def laplace_draws_to_inferencedata( posterior_draws: list[np.ndarray[float | int]], model: pm.Model | None = None ) -> az.InferenceData: @@ -308,6 +408,8 @@ def fit_mvn_at_MAP( ) H = -f_hess(mu.data) + if H.ndim == 1: + H = np.expand_dims(H, axis=1) H_inv = np.linalg.pinv(np.where(np.abs(H) < zero_tol, 0, -H)) def stabilize(x, jitter): diff --git a/tests/test_laplace.py b/tests/test_laplace.py index 8f7a4c017..72ff3e937 100644 --- a/tests/test_laplace.py +++ b/tests/test_laplace.py @@ -23,6 +23,7 @@ from pymc_extras.inference.laplace import ( fit_laplace, fit_mvn_at_MAP, + get_conditional_gaussian_approximation, sample_laplace_posterior, ) @@ -279,3 +280,85 @@ def test_laplace_scalar(): assert idata_laplace.fit.covariance_matrix.shape == (1, 1) np.testing.assert_allclose(idata_laplace.fit.mean_vector.values.item(), data.mean(), atol=0.1) + + +def test_get_conditional_gaussian_approximation(): + """ + Consider the trivial case of: + + y | x ~ N(x, cov_param) + x | param ~ N(mu_param, Q^-1) + + cov_param ~ N(cov_mu, cov_cov) + mu_param ~ N(mu_mu, mu_cov) + Q ~ N(Q_mu, Q_cov) + + This has an analytic solution at the mode which we can compare against. 
+ """ + rng = np.random.default_rng(12345) + n = 10000 + d = 10 + + # Initialise arrays + mu_true = rng.random(d) + cov_true = np.diag(rng.random(d)) + Q_val = np.diag(rng.random(d)) + cov_param_val = np.diag(rng.random(d)) + + x_val = rng.random(d) + mu_val = rng.random(d) + + mu_mu = rng.random(d) + mu_cov = np.diag(np.ones(d)) + cov_mu = rng.random(d**2) + cov_cov = np.diag(np.ones(d**2)) + Q_mu = rng.random(d**2) + Q_cov = np.diag(np.ones(d**2)) + + with pm.Model() as model: + y_obs = rng.multivariate_normal(mean=mu_true, cov=cov_true, size=n) + + mu_param = pm.MvNormal("mu_param", mu=mu_mu, cov=mu_cov) + cov_param = pm.MvNormal("cov_param", mu=cov_mu, cov=cov_cov) + Q = pm.MvNormal("Q", mu=Q_mu, cov=Q_cov) + + # Pytensor currently doesn't support autograd for pt inverses, so we use a numeric Q instead + x = pm.MvNormal("x", mu=mu_param, cov=np.linalg.inv(Q_val)) + + y = pm.MvNormal( + "y", + mu=x, + cov=cov_param.reshape((d, d)), + observed=y_obs, + ) + + # logp(x | y, params) + cga = get_conditional_gaussian_approximation( + x=model.rvs_to_values[x], + Q=Q.reshape((d, d)), + mu=mu_param, + optimizer_kwargs={"tol": 1e-25}, + ) + + x0, log_x_posterior = cga( + x=x_val, mu_param=mu_val, cov_param=cov_param_val.flatten(), Q=Q_val.flatten() + ) + + # Get analytic values of the mode and Laplace-approximated log posterior + cov_param_inv = np.linalg.inv(cov_param_val) + + x0_true = np.linalg.inv(n * cov_param_inv + 2 * Q_val) @ ( + cov_param_inv @ y_obs.sum(axis=0) + 2 * Q_val @ mu_val + ) + + jac_true = cov_param_inv @ (y_obs - x0_true).sum(axis=0) - Q_val @ (x0_true - mu_val) + hess_true = -n * cov_param_inv - Q_val + + log_x_posterior_laplace_true = ( + -0.5 * x_val.T @ (-hess_true + Q_val) @ x_val + + x_val.T @ (Q_val @ mu_val + jac_true - hess_true @ x0_true) + + 0.5 * np.log(np.linalg.det(Q_val)) + ) + + np.testing.assert_allclose(x0, x0_true, atol=0.1, rtol=0.1) + np.testing.assert_allclose(log_x_posterior, log_x_posterior_laplace_true, atol=0.1, rtol=0.1)