 import warnings
 
 import arviz as az
-import numpy as np
 import pymc as pm
-import pytensor
-import pytensor.tensor as pt
 
-from better_optimize.constants import minimize_method
-from numpy.typing import ArrayLike
 from pymc.distributions.multivariate import MvNormal
 from pytensor.tensor import TensorVariable
 from pytensor.tensor.linalg import inv as matrix_inverse
-from pytensor.tensor.optimize import minimize
 
 from pymc_extras.model.marginal.marginal_model import marginalize
 
 
-def get_conditional_gaussian_approximation(
-    x: TensorVariable,
-    Q: TensorVariable | ArrayLike,
-    mu: TensorVariable | ArrayLike,
-    model: pm.Model | None = None,
-    method: minimize_method = "BFGS",
-    use_jac: bool = True,
-    use_hess: bool = False,
-    optimizer_kwargs: dict | None = None,
-) -> list[TensorVariable]:
-    """
-    Returns an estimate of the a posteriori probability of a latent Gaussian field x and its mode x0, using the Laplace approximation.
-
-    That is:
-    y | x, sigma ~ N(Ax, sigma^2 W)
-    x | params ~ N(mu, Q(params)^-1)
-
-    We seek to estimate p(x | y, params) with a Gaussian:
-
-    log(p(x | y, params)) = log(p(y | x, params)) + log(p(x | params)) + const
-
-    Let f(x) = log(p(y | x, params)). From the definition of the model above, log(p(x | params)) = -0.5*(x - mu).T Q (x - mu) + 0.5*logdet(Q) + const.
-
-    This gives log(p(x | y, params)) = f(x) - 0.5*(x - mu).T Q (x - mu) + 0.5*logdet(Q) + const, which we estimate with the Laplace approximation by Taylor-expanding f(x) about the mode.
-
-    Thus:
-
-    1. Maximize log(p(x | y, params)) = f(x) - 0.5*(x - mu).T Q (x - mu) wrt x (logdet(Q) does not depend on x) to find the mode x0.
-
-    2. Use the Laplace approximation expanded about the mode: p(x | y, params) ~= N(mu=x0, tau=Q - f''(x0)).
-
-    Parameters
-    ----------
-    x: TensorVariable
-        The variable to maximize with respect to (that is, the mode is found in x). In INLA, this is the latent Gaussian field x ~ N(mu, Q^-1).
-    Q: TensorVariable | ArrayLike
-        The precision matrix of the latent field x.
-    mu: TensorVariable | ArrayLike
-        The mean of the latent field x.
-    model: Model
-        PyMC model to use.
-    method: minimize_method
-        Which minimization algorithm to use.
-    use_jac: bool
-        If True, the minimizer will compute the gradient of log(p(x | y, params)).
-    use_hess: bool
-        If True, the minimizer will compute the Hessian of log(p(x | y, params)).
-    optimizer_kwargs: dict
-        Kwargs to pass to scipy.optimize.minimize.
-
-    Returns
-    -------
-    x0, logp: list[TensorVariable]
-        The mode x0 and the log-density of the Laplace approximation to the posterior, evaluated at x0.
-    """
-    raise DeprecationWarning("Legacy code. Please use fit_INLA instead.")
-
-    model = pm.modelcontext(model)
-
-    # f = log(p(y | x, params))
-    f_x = model.logp()
-
-    # log(p(x | y, params)), keeping only the terms that depend on x for the
-    # minimization step (logdet(Q) is constant wrt x and is ignored)
-    log_x_posterior = f_x - 0.5 * (x - mu).T @ Q @ (x - mu)
-
-    # Maximize log(p(x | y, params)) wrt x to find the mode x0
-    x0, _ = minimize(
-        objective=-log_x_posterior,
-        x=x,
-        method=method,
-        jac=use_jac,
-        hess=use_hess,
-        optimizer_kwargs=optimizer_kwargs,
-    )
-
-    # f''(x0) is required for the Laplace approximation
-    hess = pytensor.gradient.hessian(f_x, x)
-    hess = pytensor.graph.replace.graph_replace(hess, {x: x0})
-
-    # Could be made more efficient by updating the diagonal only
-    tau = Q - hess
-
-    # Currently x is passed both as the query point for logp(x | y, params) AND as an
-    # initial guess for x0. This may cause issues if the query point is far from the
-    # mode x0, or lies in a neighbourhood where convergence is poor.
-    _, logdetTau = pt.nlinalg.slogdet(tau)
-    return x0, 0.5 * logdetTau - 0.5 * x0.shape[0] * np.log(2 * np.pi)
-
-
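For intuition, the two-step procedure removed above reduces to a few lines of NumPy/SciPy. The sketch below uses a hypothetical Poisson observation model with toy data; none of these names exist in this module.

import numpy as np
from scipy.optimize import minimize as scipy_minimize

# Toy setup (hypothetical): Poisson counts y with log-rate x, and a
# Gaussian prior x ~ N(mu, Q^-1).
rng = np.random.default_rng(0)
n = 5
mu = np.zeros(n)
Q = np.eye(n)  # prior precision matrix
y = rng.poisson(lam=1.0, size=n)

def f(x):
    # f(x) = log p(y | x) for a Poisson likelihood with rate exp(x),
    # dropping the x-independent log(y!) term.
    return np.sum(y * x - np.exp(x))

def neg_log_posterior(x):
    # -(f(x) - 0.5 (x - mu)^T Q (x - mu)); logdet(Q) is constant wrt x.
    d = x - mu
    return -(f(x) - 0.5 * d @ Q @ d)

# Step 1: find the mode x0 of log p(x | y).
x0 = scipy_minimize(neg_log_posterior, x0=mu, method="BFGS").x

# Step 2: Laplace approximation p(x | y) ~= N(x0, tau^-1), where
# tau = Q - f''(x0); for this likelihood, f''(x) = -diag(exp(x)).
tau = Q + np.diag(np.exp(x0))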
-def get_log_marginal_likelihood(
-    x: TensorVariable,
-    Q: TensorVariable | ArrayLike,
-    mu: TensorVariable | ArrayLike,
-    model: pm.Model | None = None,
-    method: minimize_method = "BFGS",
-    use_jac: bool = True,
-    use_hess: bool = False,
-    optimizer_kwargs: dict | None = None,
-) -> TensorVariable:
-    raise DeprecationWarning("Legacy code. Please use fit_INLA instead.")
-
-    model = pm.modelcontext(model)
-
-    x0, log_laplace_approx = get_conditional_gaussian_approximation(
-        x, Q, mu, model, method, use_jac, use_hess, optimizer_kwargs
-    )
-
-    # log(p(x | params)) evaluated at the mode x0
-    _, logdetQ = pt.nlinalg.slogdet(Q)
-    log_x_likelihood = (
-        -0.5 * (x0 - mu).T @ Q @ (x0 - mu) + 0.5 * logdetQ - 0.5 * x0.shape[0] * np.log(2 * np.pi)
-    )
-
-    log_likelihood = (  # logp(y | params) =
-        model.logp()  # logp(y | x, params)
-        + log_x_likelihood  # * logp(x | params)
-        - log_laplace_approx  # / logp(x | y, params)
-    )
-
-    return x0, log_likelihood
-
-
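The identity used in get_log_marginal_likelihood, logp(y | params) = logp(y | x, params) + logp(x | params) - logp(x | y, params) with the right-hand side evaluated at the mode x0, can be continued from the same toy sketch (again, purely illustrative):

# Continuing the sketch: Laplace estimate of the log marginal likelihood,
# log p(y) ~= f(x0) + log p(x0) - log q(x0), where q = N(x0, tau^-1) is the
# Laplace approximation evaluated at its own mode.
_, logdet_Q = np.linalg.slogdet(Q)
_, logdet_tau = np.linalg.slogdet(tau)

d = x0 - mu
log_prior_at_mode = -0.5 * d @ Q @ d + 0.5 * logdet_Q - 0.5 * n * np.log(2 * np.pi)
log_q_at_mode = 0.5 * logdet_tau - 0.5 * n * np.log(2 * np.pi)

log_marginal = f(x0) + log_prior_at_mode - log_q_at_mode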
 def fit_INLA(
     x: TensorVariable,
     temp_kwargs=None,  # TODO REMOVE. DEBUGGING TOOL