|
15 | 15 |
|
16 | 16 | import logging
|
17 | 17 |
|
18 |
| -from collections.abc import Callable |
19 | 18 | from functools import partial
|
20 | 19 | from typing import Literal
|
21 | 20 | from typing import cast as type_cast
|
22 | 21 |
|
23 | 22 | import arviz as az
|
24 | 23 | import numpy as np
|
25 | 24 | import pymc as pm
|
26 |
| -import pytensor |
27 | 25 | import pytensor.tensor as pt
|
28 | 26 | import xarray as xr
|
29 | 27 |
|
30 | 28 | from better_optimize.constants import minimize_method
|
31 |
| -from numpy.typing import ArrayLike |
32 | 29 | from pymc.blocking import DictToArrayBijection
|
33 | 30 | from pymc.model.transform.optimization import freeze_dims_and_data
|
34 | 31 | from pymc.pytensorf import join_nonshared_inputs
|
35 | 32 | from pymc.util import get_default_varnames
|
36 | 33 | from pytensor.graph import vectorize_graph
|
37 |
| -from pytensor.tensor import TensorVariable |
38 |
| -from pytensor.tensor.optimize import minimize |
39 | 34 | from pytensor.tensor.type import Variable
|
40 | 35 |
|
41 | 36 | from pymc_extras.inference.laplace_approx.find_map import (
|
|
51 | 46 | _log = logging.getLogger(__name__)
|
52 | 47 |
|
53 | 48 |
|
54 |
| -def get_conditional_gaussian_approximation( |
55 |
| - x: TensorVariable, |
56 |
| - Q: TensorVariable | ArrayLike, |
57 |
| - mu: TensorVariable | ArrayLike, |
58 |
| - args: list[TensorVariable] | None = None, |
59 |
| - model: pm.Model | None = None, |
60 |
| - method: minimize_method = "BFGS", |
61 |
| - use_jac: bool = True, |
62 |
| - use_hess: bool = False, |
63 |
| - optimizer_kwargs: dict | None = None, |
64 |
| -) -> Callable: |
65 |
| - """ |
66 |
| - Returns a function to estimate the a posteriori log probability of a latent Gaussian field x and its mode x0 using the Laplace approximation. |
67 |
| -
|
68 |
| - That is: |
69 |
| - y | x, sigma ~ N(Ax, sigma^2 W) |
70 |
| - x | params ~ N(mu, Q(params)^-1) |
71 |
| -
|
72 |
| - We seek to estimate log(p(x | y, params)): |
73 |
| -
|
74 |
| - log(p(x | y, params)) = log(p(y | x, params)) + log(p(x | params)) + const |
75 |
| -
|
76 |
| - Let f(x) = log(p(y | x, params)). From the definition of our model above, we have log(p(x | params)) = -0.5*(x - mu).T Q (x - mu) + 0.5*logdet(Q). |
77 |
| -
|
78 |
| - This gives log(p(x | y, params)) = f(x) - 0.5*(x - mu).T Q (x - mu) + 0.5*logdet(Q). We will estimate this using the Laplace approximation by Taylor expanding f(x) about the mode. |
79 |
| -
|
80 |
| - Thus: |
81 |
| -
|
82 |
| - 1. Maximize log(p(x | y, params)) = f(x) - 0.5*(x - mu).T Q (x - mu) wrt x (note that logdet(Q) does not depend on x) to find the mode x0. |
83 |
| -
|
84 |
| - 2. Substitute x0 into the Laplace approximation expanded about the mode: log(p(x | y, params)) ~= -0.5*x.T (-f''(x0) + Q) x + x.T (Q.mu + f'(x0) - f''(x0).x0) + 0.5*logdet(Q). |
85 |
| -
|
86 |
| - Parameters |
87 |
| - ---------- |
88 |
| - x: TensorVariable |
89 |
| - The parameter with which to maximize wrt (that is, find the mode in x). In INLA, this is the latent field x~N(mu,Q^-1). |
90 |
| - Q: TensorVariable | ArrayLike |
91 |
| - The precision matrix of the latent field x. |
92 |
| - mu: TensorVariable | ArrayLike |
93 |
| - The mean of the latent field x. |
94 |
| - args: list[TensorVariable] |
95 |
| - Args to supply to the compiled function. That is, (x0, logp) = f(x, *args). If set to None, assumes the model RVs are args. |
96 |
| - model: Model |
97 |
| - PyMC model to use. |
98 |
| - method: minimize_method |
99 |
| - Which minimization algorithm to use. |
100 |
| - use_jac: bool |
101 |
| - If true, the minimizer will compute the gradient of log(p(x | y, params)). |
102 |
| - use_hess: bool |
103 |
| - If true, the minimizer will compute the Hessian log(p(x | y, params)). |
104 |
| - optimizer_kwargs: dict |
105 |
| - Kwargs to pass to scipy.optimize.minimize. |
106 |
| -
|
107 |
| - Returns |
108 |
| - ------- |
109 |
| - f: Callable |
110 |
| - A function which accepts a value of x and args and returns [x0, log(p(x | y, params))], where x0 is the mode. x is currently both the point at which to evaluate logp and the initial guess for the minimizer. |
111 |
| - """ |
112 |
| - model = pm.modelcontext(model) |
113 |
| - |
114 |
| - if args is None: |
115 |
| - args = model.continuous_value_vars + model.discrete_value_vars |
116 |
| - |
117 |
| - # f = log(p(y | x, params)) |
118 |
| - f_x = model.logp() |
119 |
| - jac = pytensor.gradient.grad(f_x, x) |
120 |
| - hess = pytensor.gradient.jacobian(jac.flatten(), x) |
121 |
| - |
122 |
| - # log(p(x | y, params)) only including terms that depend on x for the minimization step (logdet(Q) ignored as it is a constant wrt x) |
123 |
| - log_x_posterior = f_x - 0.5 * (x - mu).T @ Q @ (x - mu) |
124 |
| - |
125 |
| - # Maximize log(p(x | y, params)) wrt x to find mode x0 |
126 |
| - x0, _ = minimize( |
127 |
| - objective=-log_x_posterior, |
128 |
| - x=x, |
129 |
| - method=method, |
130 |
| - jac=use_jac, |
131 |
| - hess=use_hess, |
132 |
| - optimizer_kwargs=optimizer_kwargs, |
133 |
| - ) |
134 |
| - |
135 |
| - # require f'(x0) and f''(x0) for Laplace approx |
136 |
| - jac = pytensor.graph.replace.graph_replace(jac, {x: x0}) |
137 |
| - hess = pytensor.graph.replace.graph_replace(hess, {x: x0}) |
138 |
| - |
139 |
| - # Full log(p(x | y, params)) using the Laplace approximation (up to a constant) |
140 |
| - _, logdetQ = pt.nlinalg.slogdet(Q) |
141 |
| - conditional_gaussian_approx = ( |
142 |
| - -0.5 * x.T @ (-hess + Q) @ x + x.T @ (Q @ mu + jac - hess @ x0) + 0.5 * logdetQ |
143 |
| - ) |
144 |
| - |
145 |
| - # Currently x is passed both as the query point for f(x, args) = logp(x | y, params) AND as an initial guess for x0. This may cause issues if the query point is |
146 |
| - # far from the mode x0 or in a neighbourhood which results in poor convergence. |
147 |
| - return pytensor.function(args, [x0, conditional_gaussian_approx]) |
148 |
| - |
149 |
| - |
150 | 49 | def _unconstrained_vector_to_constrained_rvs(model):
|
151 | 50 | outputs = get_default_varnames(model.unobserved_value_vars, include_transformed=True)
|
152 | 51 | constrained_names = [
|
|
0 commit comments