Skip to content

Commit 6f1ec37

Browse files
Michal-Novomestskyandreacate
authored andcommitted
Implement a minimizer for INLA (pymc-devs#513)
1 parent c088004 commit 6f1ec37

File tree

2 files changed

+466
-0
lines changed

2 files changed

+466
-0
lines changed

pymc_extras/inference/laplace.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import logging
1717

18+
from collections.abc import Callable
1819
from functools import reduce
1920
from importlib.util import find_spec
2021
from itertools import product
@@ -29,6 +30,7 @@
2930

3031
from arviz import dict_to_dataset
3132
from better_optimize.constants import minimize_method
33+
from numpy.typing import ArrayLike
3234
from pymc import DictToArrayBijection
3335
from pymc.backends.arviz import (
3436
coords_and_dims_for_inferencedata,
@@ -39,6 +41,8 @@
3941
from pymc.model.transform.conditioning import remove_value_transforms
4042
from pymc.model.transform.optimization import freeze_dims_and_data
4143
from pymc.util import get_default_varnames
44+
from pytensor.tensor import TensorVariable
45+
from pytensor.tensor.optimize import minimize
4246
from scipy import stats
4347

4448
from pymc_extras.inference.find_map import (
@@ -52,6 +56,102 @@
5256
_log = logging.getLogger(__name__)
5357

5458

59+
def get_conditional_gaussian_approximation(
60+
x: TensorVariable,
61+
Q: TensorVariable | ArrayLike,
62+
mu: TensorVariable | ArrayLike,
63+
args: list[TensorVariable] | None = None,
64+
model: pm.Model | None = None,
65+
method: minimize_method = "BFGS",
66+
use_jac: bool = True,
67+
use_hess: bool = False,
68+
optimizer_kwargs: dict | None = None,
69+
) -> Callable:
70+
"""
71+
Returns a function to estimate the a posteriori log probability of a latent Gaussian field x and its mode x0 using the Laplace approximation.
72+
73+
That is:
74+
y | x, sigma ~ N(Ax, sigma^2 W)
75+
x | params ~ N(mu, Q(params)^-1)
76+
77+
We seek to estimate log(p(x | y, params)):
78+
79+
log(p(x | y, params)) = log(p(y | x, params)) + log(p(x | params)) + const
80+
81+
Let f(x) = log(p(y | x, params)). From the definition of our model above, we have log(p(x | params)) = -0.5*(x - mu).T Q (x - mu) + 0.5*logdet(Q).
82+
83+
This gives log(p(x | y, params)) = f(x) - 0.5*(x - mu).T Q (x - mu) + 0.5*logdet(Q). We will estimate this using the Laplace approximation by Taylor expanding f(x) about the mode.
84+
85+
Thus:
86+
87+
1. Maximize log(p(x | y, params)) = f(x) - 0.5*(x - mu).T Q (x - mu) wrt x (note that logdet(Q) does not depend on x) to find the mode x0.
88+
89+
2. Substitute x0 into the Laplace approximation expanded about the mode: log(p(x | y, params)) ~= -0.5*x.T (-f''(x0) + Q) x + x.T (Q.mu + f'(x0) - f''(x0).x0) + 0.5*logdet(Q).
90+
91+
Parameters
92+
----------
93+
x: TensorVariable
94+
The parameter with which to maximize wrt (that is, find the mode in x). In INLA, this is the latent field x~N(mu,Q^-1).
95+
Q: TensorVariable | ArrayLike
96+
The precision matrix of the latent field x.
97+
mu: TensorVariable | ArrayLike
98+
The mean of the latent field x.
99+
args: list[TensorVariable]
100+
Args to supply to the compiled function. That is, (x0, logp) = f(x, *args). If set to None, assumes the model RVs are args.
101+
model: Model
102+
PyMC model to use.
103+
method: minimize_method
104+
Which minimization algorithm to use.
105+
use_jac: bool
106+
If true, the minimizer will compute the gradient of log(p(x | y, params)).
107+
use_hess: bool
108+
If true, the minimizer will compute the Hessian log(p(x | y, params)).
109+
optimizer_kwargs: dict
110+
Kwargs to pass to scipy.optimize.minimize.
111+
112+
Returns
113+
-------
114+
f: Callable
115+
A function which accepts a value of x and args and returns [x0, log(p(x | y, params))], where x0 is the mode. x is currently both the point at which to evaluate logp and the initial guess for the minimizer.
116+
"""
117+
model = pm.modelcontext(model)
118+
119+
if args is None:
120+
args = model.continuous_value_vars + model.discrete_value_vars
121+
122+
# f = log(p(y | x, params))
123+
f_x = model.logp()
124+
jac = pytensor.gradient.grad(f_x, x)
125+
hess = pytensor.gradient.jacobian(jac.flatten(), x)
126+
127+
# log(p(x | y, params)) only including terms that depend on x for the minimization step (logdet(Q) ignored as it is a constant wrt x)
128+
log_x_posterior = f_x - 0.5 * (x - mu).T @ Q @ (x - mu)
129+
130+
# Maximize log(p(x | y, params)) wrt x to find mode x0
131+
x0, _ = minimize(
132+
objective=-log_x_posterior,
133+
x=x,
134+
method=method,
135+
jac=use_jac,
136+
hess=use_hess,
137+
optimizer_kwargs=optimizer_kwargs,
138+
)
139+
140+
# require f'(x0) and f''(x0) for Laplace approx
141+
jac = pytensor.graph.replace.graph_replace(jac, {x: x0})
142+
hess = pytensor.graph.replace.graph_replace(hess, {x: x0})
143+
144+
# Full log(p(x | y, params)) using the Laplace approximation (up to a constant)
145+
_, logdetQ = pt.nlinalg.slogdet(Q)
146+
conditional_gaussian_approx = (
147+
-0.5 * x.T @ (-hess + Q) @ x + x.T @ (Q @ mu + jac - hess @ x0) + 0.5 * logdetQ
148+
)
149+
150+
# Currently x is passed both as the query point for f(x, args) = logp(x | y, params) AND as an initial guess for x0. This may cause issues if the query point is
151+
# far from the mode x0 or in a neighbourhood which results in poor convergence.
152+
return pytensor.function(args, [x0, conditional_gaussian_approx])
153+
154+
55155
def laplace_draws_to_inferencedata(
56156
posterior_draws: list[np.ndarray[float | int]], model: pm.Model | None = None
57157
) -> az.InferenceData:
@@ -308,6 +408,8 @@ def fit_mvn_at_MAP(
308408
)
309409

310410
H = -f_hess(mu.data)
411+
if H.ndim == 1:
412+
H = np.expand_dims(H, axis=1)
311413
H_inv = np.linalg.pinv(np.where(np.abs(H) < zero_tol, 0, -H))
312414

313415
def stabilize(x, jitter):

0 commit comments

Comments
 (0)