Skip to content

Commit a266a2e

Browse files
detailed docstring
1 parent bfa1a1a commit a266a2e

File tree

1 file changed

+21
-3
lines changed

1 file changed

+21
-3
lines changed

pymc_extras/inference/laplace.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,25 @@ def get_conditional_gaussian_approximation(
6868
optimizer_kwargs: dict | None = None,
6969
) -> Callable:
7070
"""
71-
Returns a function to estimate log(p(x | y, params)) and its mode x0 using the Laplace approximation.
71+
Returns a function to estimate the a posteriori log probability of a latent Gaussian field x and its mode x0 using the Laplace approximation.
72+
73+
That is:
74+
y | x, sigma ~ N(Ax, sigma^2 W)
75+
x | params ~ N(mu, Q(params)^-1)
76+
77+
We seek to estimate log(p(x | y, params)):
78+
79+
log(p(x | y, params)) = log(p(y | x, params)) + log(p(x | params)) + const
80+
81+
Let f(x) = log(p(y | x, params)). From the definition of our model above, we have log(p(x | params)) = -0.5*(x - mu).T Q (x - mu) + 0.5*logdet(Q).
82+
83+
This gives log(p(x | y, params)) = f(x) - 0.5*(x - mu).T Q (x - mu) + 0.5*logdet(Q). We will estimate this using the Laplace approximation by Taylor expanding f(x) about the mode.
84+
85+
Thus:
86+
87+
1. Maximize log(p(x | y, params)) = f(x) - 0.5*(x - mu).T Q (x - mu) wrt x (note that logdet(Q) does not depend on x) to find the mode x0.
88+
89+
2. Substitute x0 into the Laplace approximation expanded about the mode: log(p(x | y, params)) ~= -0.5*x.T (-f''(x0) + Q) x + x.T (Q.mu + f'(x0) - f''(x0).x0) + 0.5*logdet(Q).
7290
7391
Parameters
7492
----------
@@ -109,7 +127,7 @@ def get_conditional_gaussian_approximation(
109127
# log(p(x | y, params)) only including terms that depend on x for the minimization step (logdet(Q) ignored as it is a constant wrt x)
110128
log_x_posterior = f_x - 0.5 * (x - mu).T @ Q @ (x - mu)
111129

112-
# Maximize log(p(x | y, params)) wrt x
130+
# Maximize log(p(x | y, params)) wrt x to find mode x0
113131
x0, _ = minimize(
114132
objective=-log_x_posterior,
115133
x=x,
@@ -123,7 +141,7 @@ def get_conditional_gaussian_approximation(
123141
jac = pytensor.graph.replace.graph_replace(jac, {x: x0})
124142
hess = pytensor.graph.replace.graph_replace(hess, {x: x0})
125143

126-
# Full log(p(x | y, params)) using Laplace approximation (up to a constant)
144+
# Full log(p(x | y, params)) using the Laplace approximation (up to a constant)
127145
_, logdetQ = pt.nlinalg.slogdet(Q)
128146
conditional_gaussian_approx = (
129147
-0.5 * x.T @ (-hess + Q) @ x + x.T @ (Q @ mu + jac - hess @ x0) + 0.5 * logdetQ

0 commit comments

Comments
 (0)