
Commit 13961fa

started writing fit_INLA routine
1 parent 83bef75 commit 13961fa

5 files changed: +280 −195 lines

pymc_extras/inference/fit.py

Lines changed: 11 additions & 1 deletion
@@ -36,7 +36,17 @@ def fit(method: str, **kwargs) -> az.InferenceData:

         return fit_pathfinder(**kwargs)

-    if method == "laplace":
+    elif method == "laplace":
         from pymc_extras.inference.laplace import fit_laplace

         return fit_laplace(**kwargs)
+
+    elif method == "INLA":
+        from pymc_extras.inference.laplace import fit_INLA
+
+        return fit_INLA(**kwargs)
+
+    else:
+        raise ValueError(
+            f"method '{method}' not supported. Use one of 'pathfinder', 'laplace' or 'INLA'."
+        )
pymc_extras/inference/inla.py

Lines changed: 164 additions & 0 deletions
@@ -0,0 +1,164 @@
import arviz as az
import numpy as np
import pymc as pm
import pytensor
import pytensor.tensor as pt

from better_optimize.constants import minimize_method
from numpy.typing import ArrayLike
from pytensor.tensor import TensorVariable
from pytensor.tensor.optimize import minimize


def get_conditional_gaussian_approximation(
    x: TensorVariable,
    Q: TensorVariable | ArrayLike,
    mu: TensorVariable | ArrayLike,
    model: pm.Model | None = None,
    method: minimize_method = "BFGS",
    use_jac: bool = True,
    use_hess: bool = False,
    optimizer_kwargs: dict | None = None,
) -> list[TensorVariable]:
    """
    Returns an estimate of the a posteriori probability of a latent Gaussian field x and its mode x0 using the Laplace approximation.

    That is:
    y | x, sigma ~ N(Ax, sigma^2 W)
    x | params ~ N(mu, Q(params)^-1)

    We seek to estimate p(x | y, params) with a Gaussian:

    log(p(x | y, params)) = log(p(y | x, params)) + log(p(x | params)) + const

    Let f(x) = log(p(y | x, params)). From the definition of our model above, we have log(p(x | params)) = -0.5*(x - mu).T Q (x - mu) + 0.5*logdet(Q).

    This gives log(p(x | y, params)) = f(x) - 0.5*(x - mu).T Q (x - mu) + 0.5*logdet(Q). We estimate this using the Laplace approximation by Taylor expanding f(x) about the mode.

    Thus:

    1. Maximize log(p(x | y, params)) = f(x) - 0.5*(x - mu).T Q (x - mu) wrt x (note that logdet(Q) does not depend on x) to find the mode x0.

    2. Use the Laplace approximation expanded about the mode: p(x | y, params) ~= N(mu=x0, tau=Q - f''(x0)).

    Parameters
    ----------
    x: TensorVariable
        The parameter to maximize wrt (that is, find the mode in x). In INLA, this is the latent Gaussian field x ~ N(mu, Q^-1).
    Q: TensorVariable | ArrayLike
        The precision matrix of the latent field x.
    mu: TensorVariable | ArrayLike
        The mean of the latent field x.
    model: Model
        PyMC model to use.
    method: minimize_method
        Which minimization algorithm to use.
    use_jac: bool
        If true, the minimizer will compute the gradient of log(p(x | y, params)).
    use_hess: bool
        If true, the minimizer will compute the Hessian of log(p(x | y, params)).
    optimizer_kwargs: dict
        Kwargs to pass to scipy.optimize.minimize.

    Returns
    -------
    x0, p(x | y, params): list[TensorVariable]
        Mode and Laplace approximation for the posterior.
    """
    model = pm.modelcontext(model)

    # f = log(p(y | x, params))
    f_x = model.logp()

    # log(p(x | y, params)), keeping only the terms that depend on x for the minimization step
    # (logdet(Q) is ignored as it is constant wrt x)
    log_x_posterior = f_x - 0.5 * (x - mu).T @ Q @ (x - mu)

    # Maximize log(p(x | y, params)) wrt x to find the mode x0
    x0, _ = minimize(
        objective=-log_x_posterior,
        x=x,
        method=method,
        jac=use_jac,
        hess=use_hess,
        optimizer_kwargs=optimizer_kwargs,
    )

    # Require f''(x0) for the Laplace approximation
    hess = pytensor.gradient.hessian(f_x, x)
    hess = pytensor.graph.replace.graph_replace(hess, {x: x0})

    # Could be made more efficient by adding only the diagonal of -hess
    tau = Q - hess

    # Currently x is passed both as the query point for f(x, args) = logp(x | y, params) AND as an initial
    # guess for x0. This may cause issues if the query point is far from the mode x0 or in a neighbourhood
    # which results in poor convergence.
    return x0, pm.MvNormal(f"{x.name}_laplace_approx", mu=x0, tau=tau)


def get_log_marginal_likelihood(
    x: TensorVariable,
    Q: TensorVariable | ArrayLike,
    mu: TensorVariable | ArrayLike,
    model: pm.Model | None = None,
    method: minimize_method = "BFGS",
    use_jac: bool = True,
    use_hess: bool = False,
    optimizer_kwargs: dict | None = None,
) -> TensorVariable:
    model = pm.modelcontext(model)

    x0, laplace_approx = get_conditional_gaussian_approximation(
        x, Q, mu, model, method, use_jac, use_hess, optimizer_kwargs
    )
    log_laplace_approx = pm.logp(laplace_approx, model.rvs_to_values[x])

    _, logdetQ = pt.nlinalg.slogdet(Q)
    log_x_likelihood = (
        -0.5 * (x - mu).T @ Q @ (x - mu) + 0.5 * logdetQ - 0.5 * x.shape[0] * np.log(2 * np.pi)
    )

    log_likelihood = (  # logp(y | params) =
        model.logp()  # logp(y | x, params)
        + log_x_likelihood  # * logp(x | params)
        - log_laplace_approx  # / logp(x | y, params)
    )

    return log_likelihood


def fit_INLA(
    x: TensorVariable,
    Q: TensorVariable | ArrayLike,
    mu: TensorVariable | ArrayLike,
    model: pm.Model | None = None,
    method: minimize_method = "BFGS",
    use_jac: bool = True,
    use_hess: bool = False,
    optimizer_kwargs: dict | None = None,
) -> az.InferenceData:
    model = pm.modelcontext(model)

    # logp(y | params)
    log_likelihood = get_log_marginal_likelihood(
        x, Q, mu, model, method, use_jac, use_hess, optimizer_kwargs
    )

    # TODO How to obtain the prior? It can parametrise Q, mu, y, etc. Not sure if we could extract it
    # from model.logp somehow. Otherwise simply specify it as a user input.
    prior = None
    params = None
    log_prior = pm.logp(prior, model.rvs_to_values[params])

    # logp(params | y) = logp(y | params) + logp(params) + const
    log_posterior = log_likelihood + log_prior

    # TODO log_marginal_x_likelihood is almost the same as log_likelihood, but need to do some sampling?
    log_marginal_x_likelihood = None
    log_marginal_x_posterior = log_marginal_x_likelihood + log_prior

    # TODO Can we sample over log likelihoods?
    # Marginalize params
    idata_params = log_posterior.sample()  # TODO something like NUTS, QMC, etc.?
    idata_x = log_marginal_x_posterior.sample()

    # Bundle up idatas somehow
    return idata_params, idata_x
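
A minimal usage sketch for get_conditional_gaussian_approximation above, assuming a toy model in which the latent field, precision matrix and observations (mu, Q, y_obs) are all hypothetical, and assuming x is passed as the model's value variable for the latent field (mirroring the use of model.rvs_to_values in get_log_marginal_likelihood):

import numpy as np
import pymc as pm

from pymc_extras.inference.inla import get_conditional_gaussian_approximation

n = 3
mu = np.zeros(n)                    # mean of the latent field x
Q = np.eye(n)                       # precision of the latent field x
y_obs = np.array([0.5, -0.2, 1.1])  # hypothetical observations

with pm.Model() as model:
    x = pm.MvNormal("x", mu=mu, tau=Q)
    y = pm.Normal("y", mu=x, sigma=1.0, observed=y_obs)

    # The mode is found wrt the value variable backing x
    x_val = model.rvs_to_values[x]
    x0, x_laplace = get_conditional_gaussian_approximation(x_val, Q, mu, model=model)

# x0 and x_laplace are symbolic; evaluate the mode starting from an initial guess for x
print(x0.eval({x_val: np.zeros(n)}))

Note that, as the comment in the function warns, the value supplied for x serves both as the evaluation point of the log-density and as the initial guess for the optimizer.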

pymc_extras/inference/laplace.py

Lines changed: 0 additions & 111 deletions
@@ -15,7 +15,6 @@

 import logging

-from collections.abc import Callable
 from functools import reduce
 from importlib.util import find_spec
 from itertools import product
@@ -30,7 +29,6 @@

 from arviz import dict_to_dataset
 from better_optimize.constants import minimize_method
-from numpy.typing import ArrayLike
 from pymc import DictToArrayBijection
 from pymc.backends.arviz import (
     coords_and_dims_for_inferencedata,
@@ -41,8 +39,6 @@
 from pymc.model.transform.conditioning import remove_value_transforms
 from pymc.model.transform.optimization import freeze_dims_and_data
 from pymc.util import get_default_varnames
-from pytensor.tensor import TensorVariable
-from pytensor.tensor.optimize import minimize
 from scipy import stats

 from pymc_extras.inference.find_map import (
@@ -56,113 +52,6 @@
 _log = logging.getLogger(__name__)


-def get_conditional_gaussian_approximation(
-    x: TensorVariable,
-    Q: TensorVariable | ArrayLike,
-    mu: TensorVariable | ArrayLike,
-    args: list[TensorVariable] | None = None,
-    model: pm.Model | None = None,
-    method: minimize_method = "BFGS",
-    use_jac: bool = True,
-    use_hess: bool = False,
-    optimizer_kwargs: dict | None = None,
-) -> Callable:
-    """
-    Returns a function to estimate the a posteriori log probability of a latent Gaussian field x and its mode x0 using the Laplace approximation.
-
-    That is:
-    y | x, sigma ~ N(Ax, sigma^2 W)
-    x | params ~ N(mu, Q(params)^-1)
-
-    We seek to estimate log(p(x | y, params)):
-
-    log(p(x | y, params)) = log(p(y | x, params)) + log(p(x | params)) + const
-
-    Let f(x) = log(p(y | x, params)). From the definition of our model above, we have log(p(x | params)) = -0.5*(x - mu).T Q (x - mu) + 0.5*logdet(Q).
-
-    This gives log(p(x | y, params)) = f(x) - 0.5*(x - mu).T Q (x - mu) + 0.5*logdet(Q). We will estimate this using the Laplace approximation by Taylor expanding f(x) about the mode.
-
-    Thus:
-
-    1. Maximize log(p(x | y, params)) = f(x) - 0.5*(x - mu).T Q (x - mu) wrt x (note that logdet(Q) does not depend on x) to find the mode x0.
-
-    2. Substitute x0 into the Laplace approximation expanded about the mode: log(p(x | y, params)) ~= -0.5*x.T (-f''(x0) + Q) x + x.T (Q.mu + f'(x0) - f''(x0).x0) + 0.5*logdet(Q).
-
-    Parameters
-    ----------
-    x: TensorVariable
-        The parameter with which to maximize wrt (that is, find the mode in x). In INLA, this is the latent field x~N(mu,Q^-1).
-    Q: TensorVariable | ArrayLike
-        The precision matrix of the latent field x.
-    mu: TensorVariable | ArrayLike
-        The mean of the latent field x.
-    args: list[TensorVariable]
-        Args to supply to the compiled function. That is, (x0, logp) = f(x, *args). If set to None, assumes the model RVs are args.
-    model: Model
-        PyMC model to use.
-    method: minimize_method
-        Which minimization algorithm to use.
-    use_jac: bool
-        If true, the minimizer will compute the gradient of log(p(x | y, params)).
-    use_hess: bool
-        If true, the minimizer will compute the Hessian log(p(x | y, params)).
-    optimizer_kwargs: dict
-        Kwargs to pass to scipy.optimize.minimize.
-
-    Returns
-    -------
-    f: Callable
-        A function which accepts a value of x and args and returns [x0, log(p(x | y, params))], where x0 is the mode. x is currently both the point at which to evaluate logp and the initial guess for the minimizer.
-    """
-    model = pm.modelcontext(model)
-
-    if args is None:
-        args = model.continuous_value_vars + model.discrete_value_vars
-
-    # f = log(p(y | x, params))
-    f_x = model.logp()
-    # jac = pytensor.gradient.grad(f_x, x)
-    # hess = pytensor.gradient.jacobian(jac.flatten(), x)
-
-    # log(p(x | y, params)) only including terms that depend on x for the minimization step (logdet(Q) ignored as it is a constant wrt x)
-    log_x_posterior = f_x - 0.5 * (x - mu).T @ Q @ (
-        x - mu
-    )  # TODO could be f + x.logp - IS X.LOGP DUPLICATED IN F?
-
-    # Maximize log(p(x | y, params)) wrt x to find mode x0
-    x0, _ = minimize(
-        objective=-log_x_posterior,
-        x=x,
-        method=method,
-        jac=use_jac,
-        hess=use_hess,
-        optimizer_kwargs=optimizer_kwargs,
-    )
-
-    # require f'(x0) and f''(x0) for Laplace approx
-    # jac = pytensor.graph.replace.graph_replace(jac, {x: x0})
-    jac = pytensor.gradient.grad(f_x, x)
-    hess = pytensor.gradient.jacobian(jac.flatten(), x)
-    hess = pytensor.graph.replace.graph_replace(hess, {x: x0})
-
-    # Full log(p(x | y, params)) using the Laplace approximation (up to a constant)
-    # _, logdetQ = pt.nlinalg.slogdet(Q)
-    # conditional_gaussian_approx = (
-    #     -0.5 * x.T @ (-hess + Q) @ x + x.T @ (Q @ mu + jac - hess @ x0) + 0.5 * logdetQ
-    # )
-
-    # In the future, this could be made more efficient with only adding the diagonal of -hess
-    tau = Q - hess
-
-    # Currently x is passed both as the query point for f(x, args) = logp(x | y, params) AND as an initial guess for x0. This may cause issues if the query point is
-    # far from the mode x0 or in a neighbourhood which results in poor convergence.
-    return (
-        x0,
-        pm.MvNormal(f"{x.name}_laplace_approx", mu=x0, tau=tau),
-        tau,
-    )  # pytensor.function(args, [x0, pm.MvNormal(mu=x0, tau=tau)])
-
-
 def laplace_draws_to_inferencedata(
     posterior_draws: list[np.ndarray[float | int]], model: pm.Model | None = None
 ) -> az.InferenceData: