Commit 83bef75

WIP: minor refactor

Parent: 37fa1ed

2 files changed: 15 additions & 7 deletions

pymc_extras/inference/laplace.py

Lines changed: 14 additions & 6 deletions
@@ -121,11 +121,13 @@ def get_conditional_gaussian_approximation(
 
     # f = log(p(y | x, params))
     f_x = model.logp()
-    jac = pytensor.gradient.grad(f_x, x)
-    hess = pytensor.gradient.jacobian(jac.flatten(), x)
+    # jac = pytensor.gradient.grad(f_x, x)
+    # hess = pytensor.gradient.jacobian(jac.flatten(), x)
 
     # log(p(x | y, params)) only including terms that depend on x for the minimization step (logdet(Q) ignored as it is a constant wrt x)
-    log_x_posterior = f_x - 0.5 * (x - mu).T @ Q @ (x - mu)
+    log_x_posterior = f_x - 0.5 * (x - mu).T @ Q @ (
+        x - mu
+    )  # TODO could be f + x.logp - IS X.LOGP DUPLICATED IN F?
 
     # Maximize log(p(x | y, params)) wrt x to find mode x0
     x0, _ = minimize(
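For context, the objective assembled in this hunk is the x-dependent part of the conditional log posterior. As a sketch, assuming the Gaussian prior x ~ N(mu, Q^{-1}) implied by the surrounding code, and writing f(x) = log p(y | x, params):

$$\log p(x \mid y, \theta) = f(x) - \tfrac{1}{2}(x - \mu)^\top Q\,(x - \mu) + \tfrac{1}{2}\log\det Q + \text{const}$$

Dropping the logdet(Q) term, as the comment notes, shifts the objective only by a constant in x, so the minimizer x0 is unchanged.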
@@ -138,11 +140,13 @@ def get_conditional_gaussian_approximation(
     )
 
     # require f'(x0) and f''(x0) for Laplace approx
-    jac = pytensor.graph.replace.graph_replace(jac, {x: x0})
+    # jac = pytensor.graph.replace.graph_replace(jac, {x: x0})
+    jac = pytensor.gradient.grad(f_x, x)
+    hess = pytensor.gradient.jacobian(jac.flatten(), x)
     hess = pytensor.graph.replace.graph_replace(hess, {x: x0})
 
     # Full log(p(x | y, params)) using the Laplace approximation (up to a constant)
-    _, logdetQ = pt.nlinalg.slogdet(Q)
+    # _, logdetQ = pt.nlinalg.slogdet(Q)
     # conditional_gaussian_approx = (
     #     -0.5 * x.T @ (-hess + Q) @ x + x.T @ (Q @ mu + jac - hess @ x0) + 0.5 * logdetQ
     # )
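The commented-out conditional_gaussian_approx follows from a second-order Taylor expansion of f around the mode; a sketch, writing g = f'(x_0) (jac) and H = f''(x_0) (hess) and keeping only terms that depend on x:

$$f(x) \approx f(x_0) + g^\top (x - x_0) + \tfrac{1}{2}(x - x_0)^\top H\,(x - x_0)$$

$$\log p(x \mid y, \theta) \approx -\tfrac{1}{2}\, x^\top (Q - H)\, x + x^\top \left(Q\mu + g - H x_0\right) + \tfrac{1}{2}\log\det Q + \text{const}$$

which matches the commented expression. The quadratic coefficient Q - H is the natural precision of the approximating Gaussian; presumably this is the tau passed to pm.MvNormal below, though its construction sits outside this hunk.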
@@ -152,7 +156,11 @@ def get_conditional_gaussian_approximation(
 
     # Currently x is passed both as the query point for f(x, args) = logp(x | y, params) AND as an initial guess for x0. This may cause issues if the query point is
     # far from the mode x0 or in a neighbourhood which results in poor convergence.
-    return pytensor.function(args, [x0, pm.MvNormal(mu=x0, tau=tau)])
+    return (
+        x0,
+        pm.MvNormal(f"{x.name}_laplace_approx", mu=x0, tau=tau),
+        tau,
+    )  # pytensor.function(args, [x0, pm.MvNormal(mu=x0, tau=tau)])
 
 
 def laplace_draws_to_inferencedata(
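If I'm reading the new return correctly, the function now hands back symbolic graphs (x0, an MvNormal approximation, and its precision tau) instead of a compiled pytensor.function, leaving compilation to the caller. A minimal sketch of how a caller might adapt; the keyword names and the args inputs are assumptions based only on identifiers visible in this diff, not a documented signature:

    import pytensor

    # Hypothetical caller: x, Q, mu, model, args are the names used in the diff.
    x0, x_approx, tau = get_conditional_gaussian_approximation(
        x=x, Q=Q, mu=mu, model=model
    )

    # The old code compiled internally via pytensor.function(args, [x0, ...]);
    # with the new return value the caller compiles whichever outputs it needs.
    f_mode = pytensor.function(args, [x0, tau])
    mode_val, tau_val = f_mode(*arg_values)  # arg_values: concrete inputs, illustrative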

tests/test_laplace.py

Lines changed: 1 addition & 1 deletion
@@ -323,7 +323,7 @@ def test_get_conditional_gaussian_approximation():
        Q = pm.MvNormal("Q", mu=Q_mu, cov=Q_cov)
 
        # Pytensor currently doesn't support autograd for pt inverses, so we use a numeric Q instead
-       x = pm.MvNormal("x", mu=mu_param, cov=np.linalg.inv(Q_val))
+       x = pm.MvNormal("x", mu=mu_param, tau=Q)  # cov=np.linalg.inv(Q_val))
 
        y = pm.MvNormal(
            "y",
