Skip to content

Commit 37fa1ed

Browse files
added seperate line for evaluating Q-hess
1 parent c05c94e commit 37fa1ed

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

pymc_extras/inference/laplace.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -143,13 +143,16 @@ def get_conditional_gaussian_approximation(
143143

144144
# Full log(p(x | y, params)) using the Laplace approximation (up to a constant)
145145
_, logdetQ = pt.nlinalg.slogdet(Q)
146-
conditional_gaussian_approx = (
147-
-0.5 * x.T @ (-hess + Q) @ x + x.T @ (Q @ mu + jac - hess @ x0) + 0.5 * logdetQ
148-
)
146+
# conditional_gaussian_approx = (
147+
# -0.5 * x.T @ (-hess + Q) @ x + x.T @ (Q @ mu + jac - hess @ x0) + 0.5 * logdetQ
148+
# )
149+
150+
# In the future, this could be made more efficient with only adding the diagonal of -hess
151+
tau = Q - hess
149152

150153
# Currently x is passed both as the query point for f(x, args) = logp(x | y, params) AND as an initial guess for x0. This may cause issues if the query point is
151154
# far from the mode x0 or in a neighbourhood which results in poor convergence.
152-
return pytensor.function(args, pm.MvNormal(mu=x0, tau=Q-hess))
155+
return pytensor.function(args, [x0, pm.MvNormal(mu=x0, tau=tau)])
153156

154157

155158
def laplace_draws_to_inferencedata(

0 commit comments

Comments
 (0)