Commit 47b8dae

refactor: moved laplace approx into separate function + more docstrings
1 parent 3636e98 commit 47b8dae

File tree

1 file changed: +74 −23 lines

pymc_extras/model/marginal/distributions.py

Lines changed: 74 additions & 23 deletions
@@ -21,7 +21,7 @@
 from pytensor.graph.replace import clone_replace, graph_replace
 from pytensor.scan import map as scan_map
 from pytensor.scan import scan
-from pytensor.tensor import TensorVariable
+from pytensor.tensor import TensorLike, TensorVariable
 from pytensor.tensor.optimize import minimize
 from pytensor.tensor.random.type import RandomType

@@ -397,18 +397,25 @@ def step_alpha(logp_emission, log_alpha, log_P):
     return joint_logp, *dummy_logps


-def _precision_mv_normal_logp(value, mean, tau):
+def _precision_mv_normal_logp(value: TensorLike, mean: TensorLike, tau: TensorLike):
     """
     Compute the log likelihood of a multivariate normal distribution in precision form. May be phased out - see https://github.com/pymc-devs/pymc/pull/7895

     Parameters
     ----------
-    value: TODO
+    value: TensorLike
         Query point to compute the log prob at.
-    mean: TODO
+    mean: TensorLike
         Mean vector of the Gaussian,
-    tau: TODO
+    tau: TensorLike
         Precision matrix of the Gaussian (i.e. cov = inv(tau))
+
+    Returns
+    -------
+    logp: TensorLike
+        Log likelihood at value.
+    posdef: TensorLike
+        Boolean indicating whether the precision matrix is positive definite.
     """
     k = value.shape[-1].astype("floatX")

@@ -420,6 +427,64 @@ def _precision_mv_normal_logp(value, mean, tau):
     return logp, posdef


+def get_laplace_approx(
+    log_likelihood: TensorVariable,
+    logp_objective: TensorVariable,
+    x: TensorVariable,
+    x0_init: TensorLike,
+    Q: TensorLike,
+    minimizer_kwargs: dict = {"method": "L-BFGS-B", "optimizer_kwargs": {"tol": 1e-8}},
+):
+    """
+    Compute the laplace approximation of some variable x.
+
+    Parameters
+    ----------
+    log_likelihood: TensorVariable
+        Model likelihood logp(y | x, params).
+    logp_objective: TensorVariable
+        Objective log likelihood to maximize, logp(x | y, params) (up to some constant in x).
+    x: TensorVariable
+        Variable to be laplace approximated.
+    x0_init: TensorLike
+        Initial guess for minimization.
+    Q: TensorLike
+        Precision matrix of x.
+    minimizer_kwargs:
+        Kwargs to pass to pytensor.optimize.minimize.
+
+    Returns
+    -------
+    x0: TensorVariable
+        x*, the maximizer of logp(x | y, params) in x.
+    log_laplace_approx: TensorVariable
+        Laplace approximation evaluated at x.
+    """
+    # Maximize log(p(x | y, params)) wrt x to find mode x0
+    # This step is currently bottlenecking the logp calculation.
+    x0, _ = minimize(
+        objective=-logp_objective,  # logp(x | y, params) = logp(y | x, params) + logp(x | params) + const (const omitted during minimization)
+        x=x,
+        **minimizer_kwargs,
+    )
+
+    # Set minimizer initialisation to be random
+    x0 = pytensor.graph.replace.graph_replace(x0, {x: x0_init})
+
+    # logp(x | y, params) using laplace approx evaluated at x0
+    # This step is also expensive (but not as much as minimize). Could be made more efficient by recycling hessian from the minimizer step, however that requires a bespoke algorithm described in Rasmussen & Williams
+    # since the general optimisation scheme maximises logp(x | y, params) rather than logp(y | x, params), and thus the hessian that comes out of methods
+    # like L-BFGS-B is in fact not the hessian of logp(y | x, params)
+    hess = pytensor.gradient.hessian(log_likelihood, x)
+
+    # Evaluate logp of Laplace approx N(x*, Q - f"(x*)) at some point x
+    tau = Q - hess
+    mu = x0
+    log_laplace_approx, _ = _precision_mv_normal_logp(x, mu, tau)
+
+    return x0, log_laplace_approx
+
+
 @_logprob.register(MarginalLaplaceRV)
 def laplace_marginal_rv_logp(op: MarginalLaplaceRV, values, *inputs, **kwargs):
     # Clone the inner RV graph of the Marginalized RV
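
For intuition, get_laplace_approx builds two steps into the graph: find the mode x* of logp(x | y, params), then approximate that posterior by a Gaussian with precision Q - H, where H is the Hessian of logp(y | x) at the mode. A minimal NumPy/SciPy sketch of the same recipe on a hypothetical Poisson-count toy model (names and model are illustrative, not part of this commit):

import numpy as np
from scipy.optimize import minimize
from scipy.special import gammaln

rng = np.random.default_rng(0)
d = 5
Q = 2.0 * np.eye(d)                      # prior precision: x ~ N(0, inv(Q))
y = rng.poisson(np.exp(rng.normal(size=d)))

def log_lik(x):                          # logp(y | x) for y_i ~ Poisson(exp(x_i))
    return np.sum(y * x - np.exp(x) - gammaln(y + 1))

def log_joint(x):                        # logp(y | x) + logp(x), up to a constant in x
    return log_lik(x) - 0.5 * x @ Q @ x

# Step 1: minimize the negative objective to find the posterior mode x*
x_star = minimize(lambda x: -log_joint(x), rng.random(d), method="L-BFGS-B").x

# Step 2: Laplace approximation q(x) = N(x*, inv(Q - H)), with H the Hessian
# of logp(y | x) at x* (analytic here: diag(-exp(x*)))
H = np.diag(-np.exp(x_star))
tau = Q - H                              # precision of the Gaussian approximation
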
@@ -440,26 +505,11 @@ def laplace_marginal_rv_logp(op: MarginalLaplaceRV, values, *inputs, **kwargs):
     # logp = logp(y | x, params) + logp(x | params)
     logp = pt.sum([pt.sum(logps_dict[k]) for k in logps_dict])

-    # Maximize log(p(x | y, params)) wrt x to find mode x0
-    # This step is currently bottlenecking the logp calculation.
-    x0, _ = minimize(
-        objective=-logp,  # logp(x | y, params) = logp(y | x, params) + logp(x | params) + const (const omitted during minimization)
-        x=marginalized_vv,
-        **op.minimizer_kwargs,
-    )
-
     # Set minimizer initialisation to be random
     # Assumes that the observed variable y is the first/only element of values, and that d is shape[-1]
     d = values[0].data.shape[-1]
     rng = np.random.default_rng(op.minimizer_seed)
     x0_init = rng.random(d)
-    x0 = pytensor.graph.replace.graph_replace(x0, {marginalized_vv: x0_init})
-
-    # logp(x | y, params) using laplace approx evaluated at x0
-    # This step is also expensive (but not as much as minimize). Could be made more efficient by recycling hessian from the minimizer step, however that requires a bespoke algorithm described in Rasmussen & Williams
-    # since the general optimisation scheme maximises logp(x | y, params) rather than logp(y | x, params), and thus the hessian that comes out of methods
-    # like L-BFGS-B is in fact not the hessian of logp(y | x, params)
-    hess = pytensor.gradient.hessian(log_likelihood, marginalized_vv)

     # Get Q from the list of inputs
     Q = None
@@ -486,9 +536,10 @@
     else:
         Q = op.Q

-    tau = Q - hess
-    mu = x0
-    log_laplace_approx, _ = _precision_mv_normal_logp(x0, mu, tau)
+    # Obtain laplace approx
+    x0, log_laplace_approx = get_laplace_approx(
+        log_likelihood, logp, marginalized_vv, x0_init, Q, op.minimizer_kwargs
+    )

     # logp(y | params) = logp(y | x, params) + logp(x | params) - logp(x | y, params)
     marginal_likelihood = logp - log_laplace_approx
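
That identity can be checked numerically on a conjugate Gaussian model, where the Laplace approximation is exact. A self-contained sketch (illustrative; not part of this commit):

import numpy as np
from scipy.stats import multivariate_normal as mvn

rng = np.random.default_rng(0)
d, sigma2 = 4, 0.5
Q = 2.0 * np.eye(d)                              # prior precision: x ~ N(0, inv(Q))
y = rng.normal(size=d)                           # observations, y | x ~ N(x, sigma2 * I)

tau = Q + np.eye(d) / sigma2                     # exact posterior precision (here equal to Q - hess)
x_star = np.linalg.solve(tau, y / sigma2)        # exact posterior mean = mode

lhs = mvn(np.zeros(d), np.linalg.inv(Q) + sigma2 * np.eye(d)).logpdf(y)   # logp(y | params)
rhs = (
    mvn(x_star, sigma2 * np.eye(d)).logpdf(y)               # logp(y | x*, params)
    + mvn(np.zeros(d), np.linalg.inv(Q)).logpdf(x_star)     # logp(x* | params)
    - mvn(x_star, np.linalg.inv(tau)).logpdf(x_star)        # logp(x* | y, params)
)
assert np.isclose(lhs, rhs)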
