
Commit 0ee1ec9

bugfix: laplace import/file location
1 parent 8cb19d9

2 files changed: +1 −102 lines

2 files changed

+1
-102
lines changed

pymc_extras/inference/fit.py

Lines changed: 1 addition & 1 deletion

@@ -37,7 +37,7 @@ def fit(method: str, **kwargs) -> az.InferenceData:
         return fit_pathfinder(**kwargs)
 
     elif method == "laplace":
-        from pymc_extras.inference.laplace import fit_laplace
+        from pymc_extras.inference.laplace_approx.laplace import fit_laplace
 
         return fit_laplace(**kwargs)
 
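For context, this one-line change is exercised whenever fit dispatches on method="laplace". A minimal usage sketch, assuming a toy model and assuming fit_laplace picks up the model from the context manager as PyMC samplers typically do (the model, its variable names, and its data are illustrative, not part of this commit):

import pymc as pm

from pymc_extras.inference.fit import fit

# Toy model; names and data are illustrative assumptions.
with pm.Model():
    mu = pm.Normal("mu", 0.0, 1.0)
    pm.Normal("obs", mu=mu, sigma=1.0, observed=[0.1, -0.2, 0.3])

    # Dispatches to fit_laplace(**kwargs), now imported from
    # pymc_extras.inference.laplace_approx.laplace rather than the
    # old pymc_extras.inference.laplace location.
    idata = fit(method="laplace")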

pymc_extras/inference/laplace_approx/laplace.py

Lines changed: 0 additions & 101 deletions

@@ -15,27 +15,22 @@
 
 import logging
 
-from collections.abc import Callable
 from functools import partial
 from typing import Literal
 from typing import cast as type_cast
 
 import arviz as az
 import numpy as np
 import pymc as pm
-import pytensor
 import pytensor.tensor as pt
 import xarray as xr
 
 from better_optimize.constants import minimize_method
-from numpy.typing import ArrayLike
 from pymc.blocking import DictToArrayBijection
 from pymc.model.transform.optimization import freeze_dims_and_data
 from pymc.pytensorf import join_nonshared_inputs
 from pymc.util import get_default_varnames
 from pytensor.graph import vectorize_graph
-from pytensor.tensor import TensorVariable
-from pytensor.tensor.optimize import minimize
 from pytensor.tensor.type import Variable
 
 from pymc_extras.inference.laplace_approx.find_map import (
@@ -51,102 +46,6 @@
 _log = logging.getLogger(__name__)
 
 
-def get_conditional_gaussian_approximation(
-    x: TensorVariable,
-    Q: TensorVariable | ArrayLike,
-    mu: TensorVariable | ArrayLike,
-    args: list[TensorVariable] | None = None,
-    model: pm.Model | None = None,
-    method: minimize_method = "BFGS",
-    use_jac: bool = True,
-    use_hess: bool = False,
-    optimizer_kwargs: dict | None = None,
-) -> Callable:
-    """
-    Returns a function to estimate the a posteriori log probability of a latent Gaussian field x and its mode x0 using the Laplace approximation.
-
-    That is:
-    y | x, sigma ~ N(Ax, sigma^2 W)
-    x | params ~ N(mu, Q(params)^-1)
-
-    We seek to estimate log(p(x | y, params)):
-
-    log(p(x | y, params)) = log(p(y | x, params)) + log(p(x | params)) + const
-
-    Let f(x) = log(p(y | x, params)). From the definition of our model above, we have log(p(x | params)) = -0.5*(x - mu).T Q (x - mu) + 0.5*logdet(Q).
-
-    This gives log(p(x | y, params)) = f(x) - 0.5*(x - mu).T Q (x - mu) + 0.5*logdet(Q). We will estimate this using the Laplace approximation by Taylor expanding f(x) about the mode.
-
-    Thus:
-
-    1. Maximize log(p(x | y, params)) = f(x) - 0.5*(x - mu).T Q (x - mu) wrt x (note that logdet(Q) does not depend on x) to find the mode x0.
-
-    2. Substitute x0 into the Laplace approximation expanded about the mode: log(p(x | y, params)) ~= -0.5*x.T (-f''(x0) + Q) x + x.T (Q.mu + f'(x0) - f''(x0).x0) + 0.5*logdet(Q).
-
-    Parameters
-    ----------
-    x: TensorVariable
-        The parameter with which to maximize wrt (that is, find the mode in x). In INLA, this is the latent field x~N(mu,Q^-1).
-    Q: TensorVariable | ArrayLike
-        The precision matrix of the latent field x.
-    mu: TensorVariable | ArrayLike
-        The mean of the latent field x.
-    args: list[TensorVariable]
-        Args to supply to the compiled function. That is, (x0, logp) = f(x, *args). If set to None, assumes the model RVs are args.
-    model: Model
-        PyMC model to use.
-    method: minimize_method
-        Which minimization algorithm to use.
-    use_jac: bool
-        If true, the minimizer will compute the gradient of log(p(x | y, params)).
-    use_hess: bool
-        If true, the minimizer will compute the Hessian of log(p(x | y, params)).
-    optimizer_kwargs: dict
-        Kwargs to pass to scipy.optimize.minimize.
-
-    Returns
-    -------
-    f: Callable
-        A function which accepts a value of x and args and returns [x0, log(p(x | y, params))], where x0 is the mode. x is currently both the point at which to evaluate logp and the initial guess for the minimizer.
-    """
-    model = pm.modelcontext(model)
-
-    if args is None:
-        args = model.continuous_value_vars + model.discrete_value_vars
-
-    # f = log(p(y | x, params))
-    f_x = model.logp()
-    jac = pytensor.gradient.grad(f_x, x)
-    hess = pytensor.gradient.jacobian(jac.flatten(), x)
-
-    # log(p(x | y, params)) only including terms that depend on x for the minimization step (logdet(Q) ignored as it is a constant wrt x)
-    log_x_posterior = f_x - 0.5 * (x - mu).T @ Q @ (x - mu)
-
-    # Maximize log(p(x | y, params)) wrt x to find mode x0
-    x0, _ = minimize(
-        objective=-log_x_posterior,
-        x=x,
-        method=method,
-        jac=use_jac,
-        hess=use_hess,
-        optimizer_kwargs=optimizer_kwargs,
-    )
-
-    # require f'(x0) and f''(x0) for Laplace approx
-    jac = pytensor.graph.replace.graph_replace(jac, {x: x0})
-    hess = pytensor.graph.replace.graph_replace(hess, {x: x0})
-
-    # Full log(p(x | y, params)) using the Laplace approximation (up to a constant)
-    _, logdetQ = pt.nlinalg.slogdet(Q)
-    conditional_gaussian_approx = (
-        -0.5 * x.T @ (-hess + Q) @ x + x.T @ (Q @ mu + jac - hess @ x0) + 0.5 * logdetQ
-    )
-
-    # Currently x is passed both as the query point for f(x, args) = logp(x | y, params) AND as an initial guess for x0. This may cause issues if the query point is
-    # far from the mode x0 or in a neighbourhood which results in poor convergence.
-    return pytensor.function(args, [x0, conditional_gaussian_approx])
-
-
 def _unconstrained_vector_to_constrained_rvs(model):
     outputs = get_default_varnames(model.unobserved_value_vars, include_transformed=True)
     constrained_names = [
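For reference, the deleted helper implemented the two-step scheme spelled out in its docstring: maximize f(x) - 0.5*(x - mu).T Q (x - mu) with respect to x to find the mode x0, then evaluate the quadratic expansion about x0. Below is a self-contained numpy/scipy sketch of the same math using a toy Gaussian likelihood; every name and value here is an illustrative assumption, not part of the pymc-extras API or this commit:

import numpy as np
from scipy.optimize import minimize

rng = np.random.default_rng(0)
n = 3
Q = 2.0 * np.eye(n)        # prior precision of the latent field x
mu = np.zeros(n)           # prior mean of x
y = rng.normal(size=n)     # toy observations; y | x ~ N(x, I)

def f(x):
    # f(x) = log p(y | x) up to a constant, for the toy Gaussian likelihood
    return -0.5 * np.sum((y - x) ** 2)

def neg_log_x_posterior(x):
    # -(f(x) - 0.5 (x - mu)^T Q (x - mu)); logdet(Q) dropped, constant in x
    return -(f(x) - 0.5 * (x - mu) @ Q @ (x - mu))

# Step 1: maximize log(p(x | y, params)) wrt x to find the mode x0
x0 = minimize(neg_log_x_posterior, x0=mu, method="BFGS").x

# Step 2: Taylor-expand f about x0; for this toy likelihood the
# derivatives are analytic: f'(x0) = y - x0 and f''(x0) = -I
jac = y - x0
hess = -np.eye(n)
_, logdetQ = np.linalg.slogdet(Q)

def laplace_logp(x):
    # same quadratic form as the removed expression, up to a constant:
    # -0.5 x^T (-f''(x0) + Q) x + x^T (Q mu + f'(x0) - f''(x0) x0) + 0.5 logdet(Q)
    return (
        -0.5 * x @ (-hess + Q) @ x
        + x @ (Q @ mu + jac - hess @ x0)
        + 0.5 * logdetQ
    )

print(x0, laplace_logp(x0))

Because the toy likelihood is itself Gaussian, f is exactly quadratic and the expansion is exact, which makes the sketch easy to sanity-check against the closed-form conditional posterior N((I + Q)^-1 (Q mu + y), (I + Q)^-1).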
