Skip to content

Commit d0aaae5

Browse files
added comments explaining logp bottleneck
1 parent dccd9a6 commit d0aaae5

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

pymc_extras/model/marginal/distributions.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -432,23 +432,25 @@ def laplace_marginal_rv_logp(op: MarginalLaplaceRV, values, *inputs, **kwargs):
432432
else {"method": "L-BFGS-B", "optimizer_kwargs": {"tol": 1e-8}}
433433
)
434434

435+
# This step is currently bottlenecking the logp calculation.
435436
x0, _ = minimize(
436437
objective=-logp, # logp(x | y, params) = logp(y | x, params) + logp(x | params) + const (const omitted during minimization)
437438
x=marginalized_vv,
438439
**minimizer_kwargs,
439440
)
440441

441442
# Set minimizer initialisation to be random
442-
# TODO Assumes that the observed variable y is the first/only element of values, and that d is shape[-1]
443+
# Assumes that the observed variable y is the first/only element of values, and that d is shape[-1]
443444
d = values[0].data.shape[-1]
444445
rng = np.random.default_rng(op.minimizer_seed)
445446
x0_init = rng.random(d)
446447
x0 = pytensor.graph.replace.graph_replace(x0, {marginalized_vv: x0_init})
447448

448449
# logp(x | y, params) using laplace approx evaluated at x0
449-
hess = pytensor.gradient.hessian(
450-
log_likelihood, marginalized_vv
451-
) # TODO check how stan makes this quicker
450+
# This step is also expensive (but not as much as minimize). Could be made more efficient by recycling hessian from the minimizer step, however that requires a bespoke algorithm described in Rasmussen & Williams
451+
# since the general optimisation scheme maximises logp(x | y, params) rather than logp(y | x, params), and thus the hessian that comes out of methods
452+
# like L-BFGS-B is in fact not the hessian of logp(y | x, params)
453+
hess = pytensor.gradient.hessian(log_likelihood, marginalized_vv)
452454

453455
# Get Q from the list of inputs
454456
Q = None

0 commit comments

Comments
 (0)