import arviz as az
import numpy as np
import pymc as pm
import pytensor
import pytensor.tensor as pt

from better_optimize.constants import minimize_method
from numpy.typing import ArrayLike
from pytensor.tensor import TensorVariable
from pytensor.tensor.optimize import minimize


def get_conditional_gaussian_approximation(
    x: TensorVariable,
    Q: TensorVariable | ArrayLike,
    mu: TensorVariable | ArrayLike,
    model: pm.Model | None = None,
    method: minimize_method = "BFGS",
    use_jac: bool = True,
    use_hess: bool = False,
    optimizer_kwargs: dict | None = None,
) -> tuple[TensorVariable, TensorVariable]:
    """
    Returns an estimate of the a posteriori probability of a latent Gaussian field x and its mode x0, using the Laplace approximation.

    That is:
    y | x, sigma ~ N(Ax, sigma^2 W)
    x | params ~ N(mu, Q(params)^-1)

    We seek to estimate p(x | y, params) with a Gaussian:

    log(p(x | y, params)) = log(p(y | x, params)) + log(p(x | params)) + const

    Let f(x) = log(p(y | x, params)). From the definition of our model above, we have log(p(x | params)) = -0.5*(x - mu).T Q (x - mu) + 0.5*logdet(Q).

    This gives log(p(x | y, params)) = f(x) - 0.5*(x - mu).T Q (x - mu) + 0.5*logdet(Q). We estimate this using the Laplace approximation, Taylor expanding f(x) about the mode.

    Thus:

    1. Maximize log(p(x | y, params)) = f(x) - 0.5*(x - mu).T Q (x - mu) wrt x (note that logdet(Q) does not depend on x) to find the mode x0.

    2. Use the Laplace approximation expanded about the mode: p(x | y, params) ~= N(mu=x0, tau=Q - f''(x0)).

    A self-contained numerical sketch of these two steps is given after this function definition.

    Parameters
    ----------
    x: TensorVariable
        The parameter to maximize with respect to (that is, the variable in which the mode is found). In INLA, this is the latent Gaussian field x ~ N(mu, Q^-1).
    Q: TensorVariable | ArrayLike
        The precision matrix of the latent field x.
    mu: TensorVariable | ArrayLike
        The mean of the latent field x.
    model: Model
        PyMC model to use.
    method: minimize_method
        Which minimization algorithm to use.
    use_jac: bool
        If true, the minimizer will compute the gradient of log(p(x | y, params)).
    use_hess: bool
        If true, the minimizer will compute the Hessian of log(p(x | y, params)).
    optimizer_kwargs: dict
        Kwargs to pass to scipy.optimize.minimize.

    Returns
    -------
    x0, p(x | y, params): tuple[TensorVariable, TensorVariable]
        Mode and Laplace approximation of the posterior.
    """
    model = pm.modelcontext(model)

    # f = log(p(y | x, params))
    f_x = model.logp()

    # log(p(x | y, params)), keeping only the terms that depend on x for the minimization step
    # (logdet(Q) is ignored, as it is constant wrt x)
    log_x_posterior = f_x - 0.5 * (x - mu).T @ Q @ (x - mu)

    # Maximize log(p(x | y, params)) wrt x to find the mode x0
    x0, _ = minimize(
        objective=-log_x_posterior,
        x=x,
        method=method,
        jac=use_jac,
        hess=use_hess,
        optimizer_kwargs=optimizer_kwargs,
    )

    # f''(x0) is required for the Laplace approximation
    hess = pytensor.gradient.hessian(f_x, x)
    hess = pytensor.graph.replace.graph_replace(hess, {x: x0})

    # Could be made more efficient by adding only the diagonal elements
    tau = Q - hess

    # Currently x is passed both as the query point for f(x, args) = logp(x | y, params) AND as an initial
    # guess for x0. This may cause issues if the query point is far from the mode x0, or lies in a
    # neighbourhood which results in poor convergence.
    return x0, pm.MvNormal(f"{x.name}_laplace_approx", mu=x0, tau=tau)


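# The following is a minimal, self-contained numerical sketch of the two steps described in the
# docstring above, written against plain numpy/scipy rather than the PyTensor graph machinery, for a
# toy Poisson-count model y_i ~ Poisson(exp(x_i)) with latent x ~ N(mu, Q^-1). The helper name and
# all values below are purely illustrative and not part of this module's API.
def _laplace_sketch_poisson_example():
    from scipy.optimize import minimize as scipy_minimize

    rng = np.random.default_rng(0)
    n = 5
    mu_np = np.zeros(n)
    Q_np = np.eye(n)
    y_np = rng.poisson(lam=1.0, size=n)

    # Step 1: maximize f(x) - 0.5*(x - mu).T Q (x - mu) wrt x to find the mode x0,
    # where f(x) = log(p(y | x)) = sum(y*x - exp(x)) up to an additive constant.
    def neg_log_posterior(x_):
        return -(y_np @ x_ - np.exp(x_).sum() - 0.5 * (x_ - mu_np) @ Q_np @ (x_ - mu_np))

    def neg_log_posterior_grad(x_):
        return -(y_np - np.exp(x_) - Q_np @ (x_ - mu_np))

    x0 = scipy_minimize(neg_log_posterior, x0=mu_np, jac=neg_log_posterior_grad, method="BFGS").x

    # Step 2: Laplace approximation p(x | y) ~= N(mu=x0, tau=Q - f''(x0)).
    # For the Poisson likelihood, f''(x) = -diag(exp(x)), so tau = Q + diag(exp(x0)).
    tau = Q_np + np.diag(np.exp(x0))
    return x0, tau

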
def get_log_marginal_likelihood(
    x: TensorVariable,
    Q: TensorVariable | ArrayLike,
    mu: TensorVariable | ArrayLike,
    model: pm.Model | None = None,
    method: minimize_method = "BFGS",
    use_jac: bool = True,
    use_hess: bool = False,
    optimizer_kwargs: dict | None = None,
) -> TensorVariable:
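    """
    Approximate the log marginal likelihood logp(y | params) by marginalizing the latent field x out of
    the joint density, using the identity logp(y | params) = logp(y | x, params) + logp(x | params)
    - logp(x | y, params), with logp(x | y, params) replaced by its Laplace (Gaussian) approximation.

    Parameters are as in get_conditional_gaussian_approximation. A numerical sketch of the identity is
    given after this function definition.
    """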
    model = pm.modelcontext(model)

    x0, laplace_approx = get_conditional_gaussian_approximation(
        x, Q, mu, model, method, use_jac, use_hess, optimizer_kwargs
    )
    log_laplace_approx = pm.logp(laplace_approx, model.rvs_to_values[x])

    _, logdetQ = pt.nlinalg.slogdet(Q)
    log_x_likelihood = (
        -0.5 * (x - mu).T @ Q @ (x - mu) + 0.5 * logdetQ - 0.5 * x.shape[0] * np.log(2 * np.pi)
    )

    log_likelihood = (  # logp(y | params) =
        model.logp()  # logp(y | x, params)
        + log_x_likelihood  # * logp(x | params)
        - log_laplace_approx  # / logp(x | y, params)
    )

    return log_likelihood


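# The following is a hypothetical numerical check of the identity assembled above,
# logp(y | params) = logp(y | x, params) + logp(x | params) - logp(x | y, params),
# for a linear-Gaussian model where the Laplace approximation is exact and the marginal likelihood is
# available in closed form. The helper name and all values below are illustrative only.
def _marginal_likelihood_identity_sketch():
    from scipy.stats import multivariate_normal

    rng = np.random.default_rng(0)
    n = 4
    sigma = 0.5
    mu_np = np.zeros(n)
    Q_np = np.eye(n)
    y_np = rng.normal(size=n)

    # Exact posterior x | y ~ N(m, P^-1) for the model y | x ~ N(x, sigma^2 I), x ~ N(mu, Q^-1)
    P = Q_np + np.eye(n) / sigma**2
    m = np.linalg.solve(P, Q_np @ mu_np + y_np / sigma**2)

    # Closed-form marginal: y ~ N(mu, Q^-1 + sigma^2 I)
    lhs = multivariate_normal.logpdf(y_np, mean=mu_np, cov=np.linalg.inv(Q_np) + sigma**2 * np.eye(n))

    # The same quantity assembled from the three terms, evaluated at the posterior mode m
    rhs = (
        multivariate_normal.logpdf(y_np, mean=m, cov=sigma**2 * np.eye(n))  # logp(y | x)
        + multivariate_normal.logpdf(m, mean=mu_np, cov=np.linalg.inv(Q_np))  # logp(x)
        - multivariate_normal.logpdf(m, mean=m, cov=np.linalg.inv(P))  # logp(x | y)
    )
    assert np.isclose(lhs, rhs)
    return lhs

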
def fit_INLA(
    x: TensorVariable,
    Q: TensorVariable | ArrayLike,
    mu: TensorVariable | ArrayLike,
    model: pm.Model | None = None,
    method: minimize_method = "BFGS",
    use_jac: bool = True,
    use_hess: bool = False,
    optimizer_kwargs: dict | None = None,
) -> az.InferenceData:
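    """
    Work-in-progress sketch of a full INLA routine: combine the Laplace-approximated log marginal
    likelihood logp(y | params) with a prior over the hyperparameters, then marginalize the
    hyperparameters (see the TODOs below).
    """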
    model = pm.modelcontext(model)

    # logp(y | params)
    log_likelihood = get_log_marginal_likelihood(
        x, Q, mu, model, method, use_jac, use_hess, optimizer_kwargs
    )

    # TODO: How to obtain the prior? It can parametrise Q, mu, y, etc. Not sure if we could extract it
    # from model.logp somehow; otherwise, simply specify it as a user input.
    prior = None
    params = None
    log_prior = pm.logp(prior, model.rvs_to_values[params])

    # logp(params | y) = logp(y | params) + logp(params) + const
    log_posterior = log_likelihood + log_prior

    # TODO: log_marginal_x_likelihood is almost the same as log_likelihood, but may need some sampling?
    log_marginal_x_likelihood = None
    log_marginal_x_posterior = log_marginal_x_likelihood + log_prior

    # TODO: Can we sample over log likelihoods?
    # Marginalize params
    idata_params = log_posterior.sample()  # TODO: something like NUTS, QMC, etc.?
    idata_x = log_marginal_x_posterior.sample()

    # Bundle up the idatas somehow
    return idata_params, idata_x