
Commit 263f612

WIP: MarginalLaplaceRV
1 parent 43fb626 commit 263f612

3 files changed: +72 -8 lines changed

pymc_extras/inference/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -14,7 +14,8 @@
 from pymc_extras.inference.find_map import find_MAP
 from pymc_extras.inference.fit import fit
+from pymc_extras.inference.inla import fit_INLA
 from pymc_extras.inference.laplace import fit_laplace
 from pymc_extras.inference.pathfinder.pathfinder import fit_pathfinder

-__all__ = ["fit", "fit_pathfinder", "fit_laplace", "find_MAP"]
+__all__ = ["fit", "fit_pathfinder", "fit_laplace", "find_MAP", "fit_INLA"]

pymc_extras/inference/inla.py

Lines changed: 13 additions & 7 deletions
@@ -92,7 +92,8 @@ def get_conditional_gaussian_approximation(

     # Currently x is passed both as the query point for f(x, args) = logp(x | y, params) AND as an initial guess for x0. This may cause issues if the query point is
     # far from the mode x0 or in a neighbourhood which results in poor convergence.
-    return x0, pm.MvNormal(f"{x.name}_laplace_approx", mu=x0, tau=tau)
+    _, logdetTau = pt.nlinalg.slogdet(tau)
+    return x0, 0.5 * logdetTau - 0.5 * x0.shape[0] * np.log(2 * np.pi)


 def get_log_marginal_likelihood(
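
Note: the new return value is the log-density of the Gaussian approximation N(x0, tau^-1) evaluated at its own mode x0; the quadratic term vanishes there, so only the normalisation constant 0.5*log|tau| - 0.5*d*log(2*pi) remains. A minimal standalone numpy/scipy check of that closed form (not part of the diff; all values below are illustrative):

import numpy as np
from scipy.stats import multivariate_normal

# Assumed toy inputs: an arbitrary positive-definite precision matrix tau and a mode x0.
rng = np.random.default_rng(0)
d = 3
A = rng.standard_normal((d, d))
tau = A @ A.T + d * np.eye(d)
x0 = rng.standard_normal(d)

# Closed form used in the diff: the Gaussian log-density evaluated at its own mean.
_, logdet_tau = np.linalg.slogdet(tau)
closed_form = 0.5 * logdet_tau - 0.5 * d * np.log(2 * np.pi)

# Reference value from scipy.
reference = multivariate_normal(mean=x0, cov=np.linalg.inv(tau)).logpdf(x0)
assert np.isclose(closed_form, reference)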
@@ -107,14 +108,17 @@ def get_log_marginal_likelihood(
 ) -> TensorVariable:
     model = pm.modelcontext(model)

-    x0, laplace_approx = get_conditional_gaussian_approximation(
+    x0, log_laplace_approx = get_conditional_gaussian_approximation(
         x, Q, mu, model, method, use_jac, use_hess, optimizer_kwargs
     )
-    log_laplace_approx = pm.logp(laplace_approx, model.rvs_to_values[x])
+    # log_laplace_approx = pm.logp(laplace_approx, x)  # model.rvs_to_values[x])

     _, logdetQ = pt.nlinalg.slogdet(Q)
+    # log_x_likelihood = (
+    #     -0.5 * (x - mu).T @ Q @ (x - mu) + 0.5 * logdetQ - 0.5 * x.shape[0] * np.log(2 * np.pi)
+    # )
     log_x_likelihood = (
-        -0.5 * (x - mu).T @ Q @ (x - mu) + 0.5 * logdetQ - 0.5 * x.shape[0] * np.log(2 * np.pi)
+        -0.5 * (x0 - mu).T @ Q @ (x0 - mu) + 0.5 * logdetQ - 0.5 * x0.shape[0] * np.log(2 * np.pi)
     )

     log_likelihood = (  # logp(y | params) =
@@ -123,7 +127,7 @@ def get_log_marginal_likelihood(
         - log_laplace_approx  # / logp(x | y, params)
     )

-    return log_likelihood
+    return x0, log_likelihood


 def fit_INLA(
@@ -139,23 +143,25 @@ def fit_INLA(
     model = pm.modelcontext(model)

     # logp(y | params)
-    log_likelihood = get_log_marginal_likelihood(
+    x0, log_likelihood = get_log_marginal_likelihood(
         x, Q, mu, model, method, use_jac, use_hess, optimizer_kwargs
     )

     # TODO How to obtain prior? It can parametrise Q, mu, y, etc. Not sure if we could extract from model.logp somehow. Otherwise simply specify as a user input
+    # Perhaps obtain as RVs which y depends on which aren't x?
     prior = None
     params = None
     log_prior = pm.logp(prior, model.rvs_to_values[params])

     # logp(params | y) = logp(y | params) + logp(params) + const
     log_posterior = log_likelihood + log_prior
+    log_posterior = pytensor.graph.replace.graph_replace(log_posterior, {x: x0})

     # TODO log_marginal_x_likelihood is almost the same as log_likelihood, but need to do some sampling?
     log_marginal_x_likelihood = None
     log_marginal_x_posterior = log_marginal_x_likelihood + log_prior

     # TODO can we sample over log likelihoods?
     # Marginalize params
     idata_params = log_posterior.sample()  # TODO something like NUTS, QMC, etc.?
     idata_x = log_marginal_x_posterior.sample()
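
Note: per the comments above, get_log_marginal_likelihood uses the identity logp(y | params) = logp(y | x, params) + logp(x | params) - logp(x | y, params), with the last term replaced by its Laplace approximation and everything evaluated at the mode x0. A standalone sketch on a one-dimensional conjugate Gaussian model, where the approximation is exact (scipy assumed available; q, sigma, mu and y are illustrative values, not quantities from the diff):

import numpy as np
from scipy.stats import norm

# Model: y ~ N(x, sigma^2), x ~ N(mu, 1/q). Assumed values for illustration only.
q, sigma, mu = 2.0, 0.5, 1.0
y = 1.7

# The posterior of x is Gaussian, so its mode x0 and precision tau have closed forms.
tau = q + 1.0 / sigma**2
x0 = (q * mu + y / sigma**2) / tau

# logp(y | x0) + logp(x0), minus the Laplace (here exact) log-density at the mode...
log_joint_at_mode = norm(x0, sigma).logpdf(y) + norm(mu, 1.0 / np.sqrt(q)).logpdf(x0)
log_laplace_at_mode = 0.5 * np.log(tau) - 0.5 * np.log(2 * np.pi)

# ...recovers the exact marginal likelihood logp(y).
exact = norm(mu, np.sqrt(sigma**2 + 1.0 / q)).logpdf(y)
assert np.isclose(log_joint_at_mode - log_laplace_at_mode, exact)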

pymc_extras/model/marginal/distributions.py

Lines changed: 57 additions & 0 deletions
@@ -132,6 +132,10 @@ class MarginalDiscreteMarkovChainRV(MarginalRV):
     """Base class for Marginalized Discrete Markov Chain RVs"""


+class MarginalLaplaceRV(MarginalRV):
+    """Base class for Marginalized Laplace-Approximated RVs"""
+
+
 def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> tuple[int, ...]:
     op = rv.owner.op
     dist_params = rv.owner.op.dist_params(rv.owner)
@@ -371,3 +375,56 @@ def step_alpha(logp_emission, log_alpha, log_P):
     warn_non_separable_logp(values)
     dummy_logps = (DUMMY_ZERO,) * (len(values) - 1)
     return joint_logp, *dummy_logps
+
+
+@_logprob.register(MarginalLaplaceRV)
+def laplace_marginal_rv_logp(op: MarginalLaplaceRV, values, *inputs, **kwargs):
+    # Clone the inner RV graph of the Marginalized RV
+    x, *inner_rvs = inline_ofg_outputs(op, inputs)
+
+    # Obtain the joint_logp graph of the inner RV graph
+    inner_rv_values = dict(zip(inner_rvs, values))
+    marginalized_vv = x.clone()
+    rv_values = inner_rv_values | {x: marginalized_vv}
+    logps_dict = conditional_logp(rv_values=rv_values, **kwargs)
+
+    logp = pt.sum(
+        [pt.sum(logps_dict[k]) for k in logps_dict]
+    )  # TODO check this gives the proper p(y | x, params)
+
+    import pytensor
+
+    from pytensor.tensor.optimize import minimize
+
+    # Maximize log(p(x | y, params)) wrt x to find the mode x0
+    x0, _ = minimize(
+        objective=-logp,
+        x=marginalized_vv,
+        method="BFGS",
+        # jac=use_jac,
+        # hess=use_hess,
+        optimizer_kwargs={"tol": 1e-8},
+    )
+
+    # Require f''(x0) for the Laplace approximation
+    hess = pytensor.gradient.hessian(logp, marginalized_vv)
+    # hess = pytensor.graph.replace.graph_replace(hess, {marginalized_vv: x0})
+
+    # Could be made more efficient by adding diagonals only
+    rng = np.random.default_rng(12345)
+    d = 3
+    Q = np.diag(rng.random(d))
+    tau = Q - hess
+
+    # Currently x is passed both as the query point for f(x, args) = logp(x | y, params) AND as an initial guess for x0. This may cause issues if the query point is
+    # far from the mode x0 or in a neighbourhood which results in poor convergence.
+    _, logdetTau = pt.nlinalg.slogdet(tau)
+    log_laplace_approx = 0.5 * logdetTau - 0.5 * x0.shape[0] * np.log(2 * np.pi)
+
+    # Reduce logp dimensions corresponding to broadcasted variables
+    # marginalized_logp = logps_dict.pop(marginalized_vv)
+    joint_logp = logp - log_laplace_approx
+
+    joint_logp = pytensor.graph.replace.graph_replace(joint_logp, {marginalized_vv: x0})
+
+    return joint_logp  # TODO check if pm.sample adds on p(params). Otherwise this is p(y|params) not p(params|y)
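
Note: stripped of the pytensor graph machinery, laplace_marginal_rv_logp amounts to maximizing the joint log-density over the latent x, taking the negative Hessian at the mode as the precision tau, and subtracting the Laplace log-density at the mode from the joint. A hedged numpy/scipy sketch of those steps on a Poisson count model with independent Gaussian priors on the log-rates (everything below is illustrative and not part of the commit):

import numpy as np
from scipy.optimize import minimize
from scipy.special import gammaln

# Assumed toy data and prior: counts y, Gaussian prior N(mu, 1/q) on each latent log-rate x_i.
y = np.array([3.0, 0.0, 5.0])
mu, q = 0.0, 1.0

def joint_logp(x):
    # logp(y | x) + logp(x): Poisson likelihood plus independent Gaussian priors.
    loglik = np.sum(y * x - np.exp(x) - gammaln(y + 1))
    logprior = np.sum(-0.5 * q * (x - mu) ** 2 + 0.5 * np.log(q) - 0.5 * np.log(2 * np.pi))
    return loglik + logprior

# Step 1: find the mode x0 of the joint log-density (the diff does this with pytensor's minimize).
res = minimize(lambda x: -joint_logp(x), x0=np.zeros_like(y), method="BFGS")
x_mode = res.x

# Step 2: negative Hessian of the joint at the mode; diagonal here (exp(x) from the likelihood, q from the prior).
tau = np.diag(np.exp(x_mode) + q)

# Step 3: subtract the Laplace log-density at the mode to approximate logp(y).
_, logdet_tau = np.linalg.slogdet(tau)
log_laplace_at_mode = 0.5 * logdet_tau - 0.5 * len(y) * np.log(2 * np.pi)
log_marginal = joint_logp(x_mode) - log_laplace_at_mode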
