Commit 065c6b2

refactor: labelling of p(x|y,params)
1 parent fb39764 commit 065c6b2

File tree

1 file changed (+4, -5 lines)

pymc_extras/model/marginal/distributions.py

Lines changed: 4 additions & 5 deletions
@@ -436,7 +436,7 @@ def get_laplace_approx(
     minimizer_kwargs: dict = {"method": "L-BFGS-B", "optimizer_kwargs": {"tol": 1e-8}},
 ):
     """
-    Compute the laplace approximation of some variable x.
+    Compute the Laplace approximation logp_G(x | y, params) of some variable x.
 
     Parameters
     ----------
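For orientation, here is a minimal 1-D sketch (not from the commit) of the quantity the docstring now names logp_G(x | y, params): the Laplace approximation replaces the conditional log-density with the logp of a Gaussian centred at the mode x*, with precision equal to the negative curvature there. The toy density logp_toy and all numbers below are invented for illustration; the real code builds this symbolically with PyTensor.

import numpy as np
from scipy import optimize, stats

# Hypothetical stand-in for logp(x | y, params): smooth and unimodal
def logp_toy(x):
    return -0.25 * (x - 1.0) ** 4 - 0.5 * x**2

# Find the mode x* with the optimiser family the diff mentions (L-BFGS-B)
res = optimize.minimize(lambda v: -logp_toy(v[0]), x0=[0.0], method="L-BFGS-B", tol=1e-8)
x_star = res.x[0]

# Negative curvature at the mode is the precision of the approximating Gaussian
h = 1e-4
tau = -(logp_toy(x_star + h) - 2.0 * logp_toy(x_star) + logp_toy(x_star - h)) / h**2

# logp_G(x | y, params): Gaussian N(x*, 1 / tau) evaluated at an arbitrary x
x = 0.5
log_laplace = stats.norm.logpdf(x, loc=x_star, scale=np.sqrt(1.0 / tau))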
@@ -458,7 +458,7 @@ def get_laplace_approx(
     x0: TensorVariable
         x*, the maximizer of logp(x | y, params) in x.
     log_laplace_approx: TensorVariable
-        Laplace approximation evaluated at x.
+        Laplace approximation of logp(x | y, params) evaluated at x.
     """
     # Maximize log(p(x | y, params)) wrt x to find mode x0
     # This step is currently bottlenecking the logp calculation.
@@ -471,13 +471,12 @@ def get_laplace_approx(
     # Set minimizer initialisation to be random
     x0 = pytensor.graph.replace.graph_replace(x0, {x: x0_init})
 
-    # logp(x | y, params) using laplace approx evaluated at x0
     # This step is also expensive (but not as much as minimize). It could be made more efficient by recycling the Hessian from the minimizer step; however, that requires a bespoke algorithm described in Rasmussen & Williams,
     # since the general optimisation scheme maximises logp(x | y, params) rather than logp(y | x, params), and thus the Hessian that comes out of methods
     # like L-BFGS-B is in fact not the Hessian of logp(y | x, params).
     hess = pytensor.gradient.hessian(log_likelihood, x)
 
-    # Evaluate logp of Laplace approx N(x*, Q - f"(x*)) at some point x
+    # Evaluate logp of Laplace approx of logp(x | y, params) at some point x
     tau = Q - hess
     mu = x0
     log_laplace_approx, _ = _precision_mv_normal_logp(x, mu, tau)
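The hunk above ends on a call to _precision_mv_normal_logp, whose body is not part of this diff. As a hedged sketch only, assuming that helper evaluates a multivariate normal logp parameterised directly by the precision matrix (which is what tau = Q - hess suggests), the computation would look like this in NumPy:

import numpy as np

def precision_mv_normal_logp(x, mu, tau):
    # logp of N(mu, tau^{-1}): 0.5 * (log|tau| - k*log(2*pi) - (x - mu)' tau (x - mu))
    d = x - mu
    L = np.linalg.cholesky(tau)       # tau = L @ L.T
    quad = np.sum((L.T @ d) ** 2)     # quadratic form (x - mu)' tau (x - mu)
    logdet_tau = 2.0 * np.sum(np.log(np.diag(L)))
    k = x.shape[0]
    return 0.5 * (logdet_tau - k * np.log(2.0 * np.pi) - quad)

Parameterising by the precision is the natural choice here: the Laplace construction yields tau = Q - hess directly, so no covariance matrix ever needs to be inverted.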
@@ -502,7 +501,7 @@ def laplace_marginal_rv_logp(op: MarginalLaplaceRV, values, *inputs, **kwargs):
         [logp_term.sum() for value, logp_term in logps_dict.items() if value is not marginalized_vv]
     )
 
-    # logp = logp(y | x, params) + logp(x | params)
+    # logp = logp(y | x, params) + logp(x | params) (i.e. logp(x | y, params) up to a constant in x)
     logp = pt.sum([pt.sum(logps_dict[k]) for k in logps_dict])
 
     # Set minimizer initialisation to be random
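The relabelled comment above says that logp(y | x, params) + logp(x | params) equals logp(x | y, params) up to a term constant in x (namely logp(y | params)). A quick numeric check with a conjugate normal-normal toy model (invented for illustration, not from the diff):

import numpy as np
from scipy import stats

y = 2.0
xs = np.linspace(-3.0, 3.0, 7)

# Joint: logp(y | x) + logp(x), with prior x ~ N(0, 1) and likelihood y | x ~ N(x, 1)
joint = stats.norm.logpdf(y, loc=xs, scale=1.0) + stats.norm.logpdf(xs, loc=0.0, scale=1.0)

# Exact posterior for this conjugate pair: x | y ~ N(y / 2, 1 / 2)
posterior = stats.norm.logpdf(xs, loc=y / 2.0, scale=np.sqrt(0.5))

print(np.ptp(joint - posterior))  # ~0 up to float error: the gap does not vary with x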
