Commit 134d5c9

end-to-end implementation
1 parent f4c2416 commit 134d5c9

File tree

5 files changed (+383, -605 lines)


notebooks/INLA_testing.ipynb

Lines changed: 236 additions & 487 deletions
Large diffs are not rendered by default.

pymc_extras/inference/inla.py

Lines changed: 30 additions & 27 deletions
@@ -6,7 +6,9 @@
 
 from better_optimize.constants import minimize_method
 from numpy.typing import ArrayLike
+from pymc.distributions.multivariate import MvNormal
 from pytensor.tensor import TensorVariable
+from pytensor.tensor.linalg import inv as matrix_inverse
 from pytensor.tensor.optimize import minimize
 
 from pymc_extras.model.marginal.marginal_model import marginalize
@@ -67,6 +69,8 @@ def get_conditional_gaussian_approximation(
     x0, p(x | y, params): list[TensorVariable]
         Mode and Laplace approximation for posterior.
     """
+    raise DeprecationWarning("Legacy code. Please use fit_INLA instead.")
+
     model = pm.modelcontext(model)
 
     # f = log(p(y | x, params))
@@ -108,6 +112,8 @@ def get_log_marginal_likelihood(
     use_hess: bool = False,
     optimizer_kwargs: dict | None = None,
 ) -> TensorVariable:
+    raise DeprecationWarning("Legacy code. Please use fit_INLA instead.")
+
     model = pm.modelcontext(model)
 
     x0, log_laplace_approx = get_conditional_gaussian_approximation(
@@ -134,43 +140,40 @@
 
 def fit_INLA(
     x: TensorVariable,
-    Q: TensorVariable | ArrayLike,
-    # mu: TensorVariable | ArrayLike,
+    temp_kwargs=None,  # TODO REMOVE. DEBUGGING TOOL
     model: pm.Model | None = None,
     minimizer_kwargs: dict | None = None,
+    return_latent_posteriors: bool = True,
     **sampler_kwargs,
 ) -> az.InferenceData:
     model = pm.modelcontext(model)
 
-    # Marginalize out the latent field
-    marginalize(model, [x], Q, minimizer_kwargs, method="INLA")
+    # Check if latent field is Gaussian
+    if not isinstance(x.owner.op, MvNormal):
+        raise ValueError(
+            f"Latent field {x} is not an instance of MvNormal. Has distribution {x.owner.op}."
+        )
 
-    # Sample over the hyperparameters
-    pm.sample(model=model, **sampler_kwargs)
+    _, _, _, tau = x.owner.inputs
 
-    # # logp(y | params)
-    # x0, log_likelihood = get_log_marginal_likelihood(
-    #     x, Q, mu, model, method, use_jac, use_hess, optimizer_kwargs
-    # )
+    # Latent field should use precision rather than covariance
+    if not tau.owner or tau.owner.op != matrix_inverse:
+        raise ValueError(
+            f"Latent field {x} is not in precision matrix form. Use MvNormal(tau=Q) instead."
+        )
 
-    # # TODO How to obtain prior? It can parametrise Q, mu, y, etc. Not sure if we could extract from model.logp somehow. Otherwise simply specify as a user input
-    # # Perhaps obtain as RVs which y depends on which aren't x?
-    # prior = None
-    # params = None
-    # log_prior = pm.logp(prior, model.rvs_to_values[params])
+    Q = tau.owner.inputs[0]
 
-    # # logp(params | y) = logp(y | params) + logp(params) + const
-    # log_posterior = log_likelihood + log_prior
-    # log_posterior = pytensor.graph.replace.graph_replace(log_posterior, {x: x0})
+    # Marginalize out the latent field
+    minimizer_kwargs = {"method": "L-BFGS-B", "optimizer_kwargs": {"tol": 1e-8}}
+    marginalize_kwargs = {"Q": Q, "temp_kwargs": temp_kwargs, "minimizer_kwargs": minimizer_kwargs}
+    marginal_model = marginalize(model, x, use_laplace=True, **marginalize_kwargs)
 
-    # # TODO log_marginal_x_likelihood is almost the same as log_likelihood, but need to do some sampling?
-    # log_marginal_x_likelihood = None
-    # log_marginal_x_posterior = log_marginal_x_likelihood + log_prior
+    # Sample over the hyperparameters
+    idata = pm.sample(model=marginal_model, **sampler_kwargs)
 
-    # # TODO can we sample over log likelihoods?
-    # # Marginalize params
-    # idata_params = log_posterior.sample()  # TODO something like NUTS, QMC, etc.?
-    # idata_x = log_marginal_x_posterior.sample()
+    if not return_latent_posteriors:
+        return idata
 
-    # Bundle up idatas somehow
-    # return idata_params, idata_x
+    # TODO Unmarginalize stuff
+    raise NotImplementedError("Latent posteriors not supported yet, WIP.")
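
For orientation, here is a minimal sketch of how the new `fit_INLA` entry point might be driven, inferred only from the signature and checks in this diff. The model, data, and the `temp_kwargs` tuple are hypothetical placeholders; `temp_kwargs` mirrors the `n, y_obs` unpacking in `laplace_marginal_rv_logp` below and is itself flagged for removal:

```python
import numpy as np
import pymc as pm
import pytensor.tensor as pt

from pymc_extras.inference.inla import fit_INLA

# Hypothetical toy data: 100 noisy observations of a 3-dimensional latent field.
rng = np.random.default_rng(0)
y_data = rng.normal(size=(100, 3))

with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 1.0, shape=3)
    Q = pt.eye(3)  # fixed precision matrix for the latent field
    # fit_INLA requires the latent field to be an MvNormal parameterized by a
    # precision matrix, i.e. MvNormal(tau=Q), per the checks above.
    x = pm.MvNormal("x", mu=mu, tau=Q)
    y = pm.Normal("y", mu=x, sigma=1.0, observed=y_data)

# x is marginalized out via the Laplace approximation and NUTS samples only
# the hyperparameters; latent posteriors still raise NotImplementedError.
idata = fit_INLA(
    x,
    temp_kwargs=(y_data.shape[0], y_data),  # debugging hook, slated for removal
    model=model,
    return_latent_posteriors=False,
)
```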

pymc_extras/model/marginal/distributions.py

Lines changed: 11 additions & 10 deletions
@@ -8,6 +8,7 @@
 
 from pymc.distributions import Bernoulli, Categorical, DiscreteUniform
 from pymc.distributions.distribution import _support_point, support_point
+from pymc.distributions.multivariate import _precision_mv_normal_logp
 from pymc.logprob.abstract import MeasurableOp, _logprob
 from pymc.logprob.basic import conditional_logp, logp
 from pymc.pytensorf import constant_fold
@@ -414,7 +415,7 @@ def laplace_marginal_rv_logp(op: MarginalLaplaceRV, values, *inputs, **kwargs):
     minimizer_kwargs = (
         op.minimizer_kwargs
         if op.minimizer_kwargs is not None
-        else {"method": "BFGS", "optimizer_kwargs": {"tol": 1e-8}}
+        else {"method": "L-BFGS-B", "optimizer_kwargs": {"tol": 1e-8}}
     )
 
     x0, _ = minimize(
@@ -423,23 +424,23 @@ def laplace_marginal_rv_logp(op: MarginalLaplaceRV, values, *inputs, **kwargs):
         **minimizer_kwargs,
     )
 
-    # # Set minimizer initialisation to be random
+    # Set minimizer initialisation to be random
     d = 3  # 10000 # TODO pull this from x.shape (or similar) somehow
     rng = np.random.default_rng(12345)
     x0 = pytensor.graph.replace.graph_replace(x0, {marginalized_vv: rng.random(d)})
 
     # TODO USE CLOSED FORM SOLUTION FOR NOW
-    # n, y_obs = op.temp_kwargs
-    # mu_param = pytensor.graph.basic.get_var_by_name(x, "mu_param")[0]
-    # x0 = (y_obs.sum(axis=0) - mu_param) / (n - 1)
+    n, y_obs = op.temp_kwargs
+    mu_param = pytensor.graph.basic.get_var_by_name(x, "mu")[0]
+    x0 = (y_obs.sum(axis=0) - mu_param) / (n - 1)
 
     # logp(x | y, params) using laplace approx evaluated at x0
-    hess = pytensor.gradient.hessian(log_likelihood, marginalized_vv)
+    hess = pytensor.gradient.hessian(
+        log_likelihood, marginalized_vv
+    )  # TODO check how stan makes this quicker
     tau = op.Q - hess
-    _, logdetTau = pt.nlinalg.slogdet(tau)
-    log_laplace_approx = 0.5 * logdetTau - 0.5 * marginalized_vv.shape[0] * np.log(
-        2 * np.pi
-    )  # At x = x0, the quadratic term becomes 0
+    mu = x0  # TODO double check with Theo
+    log_laplace_approx, _ = _precision_mv_normal_logp(x0, mu, tau)
 
     # logp(y | params) = logp(y | x, params) + logp(x | params) - logp(x | y, params)
     marginal_likelihood = logp - log_laplace_approx
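
The final added lines lean on the identity in the comment, logp(y | params) = logp(y | x, params) + logp(x | params) - logp(x | y, params), with the last term replaced by a Gaussian of precision `Q - hess` evaluated at the mode `x0`. Below is a small self-contained numerical check of that identity (NumPy/SciPy, illustrative values only) in the one setting where the Laplace approximation is exact, a conjugate 1-D Gaussian pair:

```python
import numpy as np
from scipy import stats

# x ~ N(0, sigma_x^2), y | x ~ N(x, sigma_y^2); the exact marginal is
# y ~ N(0, sigma_x^2 + sigma_y^2), and the Laplace precision Q - hess
# equals the true posterior precision, so the identity holds exactly.
sigma_x, sigma_y, y = 1.3, 0.7, 0.42

Q = 1 / sigma_x**2                 # prior precision of the latent x
tau = Q + 1 / sigma_y**2           # Q - hessian of logp(y | x) w.r.t. x
x0 = (y / sigma_y**2) / tau        # posterior mode (= posterior mean)

exact = stats.norm(0, np.sqrt(sigma_x**2 + sigma_y**2)).logpdf(y)
laplace = (
    stats.norm(x0, sigma_y).logpdf(y)              # logp(y | x0, params)
    + stats.norm(0, sigma_x).logpdf(x0)            # logp(x0 | params)
    - stats.norm(x0, np.sqrt(1 / tau)).logpdf(x0)  # Laplace logp(x0 | y, params)
)
assert np.isclose(exact, laplace)
```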

pymc_extras/model/marginal/marginal_model.py

Lines changed: 41 additions & 15 deletions
@@ -9,6 +9,7 @@
 from arviz import InferenceData, dict_to_dataset
 from pymc.backends.arviz import coords_and_dims_for_inferencedata, dataset_to_point_list
 from pymc.distributions.discrete import Bernoulli, Categorical, DiscreteUniform
+from pymc.distributions.multivariate import MvNormal
 from pymc.distributions.transforms import Chain
 from pymc.logprob.transforms import IntervalTransform
 from pymc.model import Model
@@ -45,6 +46,7 @@
 from pymc_extras.model.marginal.distributions import (
     MarginalDiscreteMarkovChainRV,
     MarginalFiniteDiscreteRV,
+    MarginalLaplaceRV,
     MarginalRV,
     NonSeparableLogpWarning,
     get_domain_of_finite_discrete_rv,
@@ -144,7 +146,9 @@ def _unique(seq: Sequence) -> list:
     return [x for x in seq if not (x in seen or seen_add(x))]
 
 
-def marginalize(model: Model, rvs_to_marginalize: ModelRVs) -> MarginalModel:
+def marginalize(
+    model: Model, rvs_to_marginalize: ModelRVs, use_laplace: bool = False, **marginalize_kwargs
+) -> MarginalModel:
     """Marginalize a subset of variables in a PyMC model.
 
     This creates a class of `MarginalModel` from an existing `Model`, with the specified
@@ -158,6 +162,8 @@ def marginalize(model: Model, rvs_to_marginalize: ModelRVs) -> MarginalModel:
         PyMC model to marginalize. Original variables will be cloned.
     rvs_to_marginalize : Sequence[TensorVariable]
         Variables to marginalize in the returned model.
+    use_laplace : bool
+        Whether to use Laplace approximations to marginalize out rvs_to_marginalize.
 
     Returns
     -------
@@ -186,7 +192,12 @@ def marginalize(model: Model, rvs_to_marginalize: ModelRVs) -> MarginalModel:
                raise NotImplementedError(
                    "Marginalization for DiscreteMarkovChain with non-matrix transition probability is not supported"
                )
-        elif not isinstance(rv_op, Bernoulli | Categorical | DiscreteUniform):
+        elif use_laplace and not isinstance(rv_op, MvNormal):
+            raise ValueError(
+                f"Marginalisation method set to Laplace but RV {rv_to_marginalize} is not an instance of MvNormal. Has distribution {rv_to_marginalize.owner.op}"
+            )
+
+        elif not use_laplace and not isinstance(rv_op, Bernoulli | Categorical | DiscreteUniform):
            raise NotImplementedError(
                f"Marginalization of RV with distribution {rv_to_marginalize.owner.op} is not supported"
            )
@@ -241,7 +252,9 @@ def marginalize(model: Model, rvs_to_marginalize: ModelRVs) -> MarginalModel:
        ]
        input_rvs = _unique((*marginalized_rv_input_rvs, *other_direct_rv_ancestors))
 
-        replace_finite_discrete_marginal_subgraph(fg, rv_to_marginalize, dependent_rvs, input_rvs)
+        replace_marginal_subgraph(
+            fg, rv_to_marginalize, dependent_rvs, input_rvs, use_laplace, **marginalize_kwargs
+        )
 
    return model_from_fgraph(fg, mutate_fgraph=True)
 
@@ -551,22 +564,32 @@ def remove_model_vars(vars):
    return fgraph.outputs
 
 
-def replace_finite_discrete_marginal_subgraph(
-    fgraph, rv_to_marginalize, dependent_rvs, input_rvs
+def replace_marginal_subgraph(
+    fgraph,
+    rv_to_marginalize,
+    dependent_rvs,
+    input_rvs,
+    use_laplace=False,
+    **marginalize_kwargs,
 ) -> None:
    # If the marginalized RV has multiple dimensions, check that graph between
    # marginalized RV and dependent RVs does not mix information from batch dimensions
    # (otherwise logp would require enumerating over all combinations of batch dimension values)
-    try:
-        dependent_rvs_dim_connections = subgraph_batch_dim_connection(
-            rv_to_marginalize, dependent_rvs
-        )
-    except (ValueError, NotImplementedError) as e:
-        # For the perspective of the user this is a NotImplementedError
-        raise NotImplementedError(
-            "The graph between the marginalized and dependent RVs cannot be marginalized efficiently. "
-            "You can try splitting the marginalized RV into separate components and marginalizing them separately."
-        ) from e
+    if not use_laplace:
+        try:
+            dependent_rvs_dim_connections = subgraph_batch_dim_connection(
+                rv_to_marginalize, dependent_rvs
+            )
+        except (ValueError, NotImplementedError) as e:
+            # From the perspective of the user this is a NotImplementedError
+            raise NotImplementedError(
+                "The graph between the marginalized and dependent RVs cannot be marginalized efficiently. "
+                "You can try splitting the marginalized RV into separate components and marginalizing them separately."
+            ) from e
+    else:
+        dependent_rvs_dim_connections = [
+            (None,),
+        ]
 
    output_rvs = [rv_to_marginalize, *dependent_rvs]
    rng_updates = collect_default_updates(output_rvs, inputs=input_rvs, must_be_shared=False)
@@ -581,6 +604,8 @@
 
    if isinstance(inner_outputs[0].owner.op, DiscreteMarkovChain):
        marginalize_constructor = MarginalDiscreteMarkovChainRV
+    elif use_laplace:
+        marginalize_constructor = MarginalLaplaceRV
    else:
        marginalize_constructor = MarginalFiniteDiscreteRV
 
@@ -590,6 +615,7 @@
        outputs=inner_outputs,
        dims_connections=dependent_rvs_dim_connections,
        dims=dims,
+        **marginalize_kwargs,
    )
 
    new_outputs = marginalization_op(*inputs)
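
With the dispatch above, `marginalize` can also be driven directly rather than through `fit_INLA`. Below is a sketch of the call that `fit_INLA` assembles, assuming `model`, `x`, and `Q` stand for a user's PyMC model, its MvNormal latent field, and the associated precision matrix, and `n`, `y_obs` for the observation count and data consumed by the temporary closed-form initialisation:

```python
from pymc_extras.model.marginal.marginal_model import marginalize

marginal_model = marginalize(
    model,
    x,                 # MvNormal RV to integrate out
    use_laplace=True,  # dispatch to MarginalLaplaceRV instead of the
                       # finite-discrete machinery
    # Everything below is forwarded via **marginalize_kwargs to the op:
    Q=Q,
    temp_kwargs=(n, y_obs),  # temporary closed-form/debugging hook
    minimizer_kwargs={"method": "L-BFGS-B", "optimizer_kwargs": {"tol": 1e-8}},
)
```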
