
Commit 92f6a0f (parent 12b109f)

removed temp_kwargs, made Q amenable to RVs, removed dependency on MvNormal

4 files changed: +61 −38 lines

pymc_extras/inference/INLA/inla.py

Lines changed: 24 additions & 19 deletions
@@ -1,49 +1,54 @@
-import warnings
-
 import arviz as az
 import pymc as pm
 
-from pymc.distributions.multivariate import MvNormal
 from pytensor.tensor import TensorVariable
-from pytensor.tensor.linalg import inv as matrix_inverse
 
 from pymc_extras.model.marginal.marginal_model import marginalize
 
 
 def fit_INLA(
     x: TensorVariable,
-    temp_kwargs=None,  # TODO REMOVE. DEBUGGING TOOL
+    Q: TensorVariable,
+    minimizer_seed: int = 42,
     model: pm.Model | None = None,
     minimizer_kwargs: dict | None = None,
     return_latent_posteriors: bool = True,
     **sampler_kwargs,
 ) -> az.InferenceData:
-    warnings.warn("Currently only valid for a nested normal model. WIP.", UserWarning)
-
     model = pm.modelcontext(model)
 
     # Check if latent field is Gaussian
-    if not isinstance(x.owner.op, MvNormal):
-        raise ValueError(
-            f"Latent field {x} is not instance of MvNormal. Has distribution {x.owner.op}."
-        )
+    # if not isinstance(x.owner.op, MvNormal):
+    #     raise ValueError(
+    #         f"Latent field {x} is not instance of MvNormal. Has distribution {x.owner.op}."
+    #     )
+
+    # _, _, _, tau = x.owner.inputs
 
-    _, _, _, tau = x.owner.inputs
+    # # Latent field should use precison rather than covariance
+    # if not (tau.owner and tau.owner.op == matrix_inverse):
+    #     raise ValueError(
+    #         f"Latent field {x} is not in precision matrix form. Use MvNormal(tau=Q) instead."
+    #     )
 
-    # Latent field should use precison rather than covariance
-    if not (tau.owner and tau.owner.op == matrix_inverse):
-        raise ValueError(
-            f"Latent field {x} is not in precision matrix form. Use MvNormal(tau=Q) instead."
-        )
+    # Q = tau.owner.inputs[0]
 
-    Q = tau.owner.inputs[0]
+    # TODO is there a better way to check if it's a RV?
+    # print(vars(Q.owner))
+    # if isinstance(Q, TensorVariable) and "module" in vars(Q.owner):
+    Q = model.rvs_to_values[Q]
 
     # Marginalize out the latent field
     minimizer_kwargs = {"method": "L-BFGS-B", "optimizer_kwargs": {"tol": 1e-8}}
-    marginalize_kwargs = {"Q": Q, "temp_kwargs": temp_kwargs, "minimizer_kwargs": minimizer_kwargs}
+    marginalize_kwargs = {
+        "Q": Q,
+        "minimizer_seed": minimizer_seed,
+        "minimizer_kwargs": minimizer_kwargs,
+    }
     marginal_model = marginalize(model, x, use_laplace=True, **marginalize_kwargs)
 
     # Sample over the hyperparameters
+    # marginal_model.logp().dprint()
     idata = pm.sample(model=marginal_model, **sampler_kwargs)
 
     if not return_latent_posteriors:
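
For context, callers now hand the precision object to fit_INLA directly instead of it being recovered from the MvNormal graph, and since the function maps Q through model.rvs_to_values, Q is expected to be a random variable of the model here. A minimal, hypothetical sketch of the new call site under those assumptions (the toy model and its names are illustrative, not part of this commit):

import numpy as np
import pymc as pm
import pymc_extras as pmx

d = 5

with pm.Model() as model:
    # Hypothetical scalar precision hyperparameter; fit_INLA looks Q up in
    # model.rvs_to_values, so it is passed as a model random variable.
    Q = pm.Gamma("Q", alpha=2.0, beta=1.0)
    # Latent Gaussian field parameterised by its precision matrix
    x = pm.MvNormal("x", mu=np.zeros(d), tau=Q * np.eye(d))
    y = pm.Normal("y", mu=x, sigma=1.0, observed=np.zeros(d))

    idata = pmx.fit(
        method="INLA",
        x=x,
        Q=Q,
        minimizer_seed=42,
        return_latent_posteriors=False,
    )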

pymc_extras/model/marginal/distributions.py

Lines changed: 32 additions & 10 deletions
@@ -9,6 +9,7 @@
 from pymc.distributions import Bernoulli, Categorical, DiscreteUniform
 from pymc.distributions.distribution import _support_point, support_point
 from pymc.distributions.multivariate import _logdet_from_cholesky, nan_lower_cholesky
+from pymc.logprob import ValuedRV
 from pymc.logprob.abstract import MeasurableOp, _logprob
 from pymc.logprob.basic import conditional_logp, logp
 from pymc.pytensorf import constant_fold
@@ -142,12 +143,12 @@ def __init__(
         self,
         *args,
         Q: TensorVariable,
-        temp_kwargs: list,
+        minimizer_seed: int,
         minimizer_kwargs: dict | None = None,
         **kwargs,
     ) -> None:
-        self.temp_kwargs = temp_kwargs  # TODO REMOVE
         self.Q = Q
+        self.minimizer_seed = minimizer_seed
         self.minimizer_kwargs = minimizer_kwargs
         super().__init__(*args, **kwargs)
 
@@ -440,21 +441,42 @@ def laplace_marginal_rv_logp(op: MarginalLaplaceRV, values, *inputs, **kwargs):
     # Set minimizer initialisation to be random
     # TODO Assumes that the observed variable y is the first/only element of values, and that d is shape[-1]
     d = values[0].data.shape[-1]
-    rng = np.random.default_rng(12345)
+    rng = np.random.default_rng(op.minimizer_seed)
     x0_init = rng.random(d)
     x0 = pytensor.graph.replace.graph_replace(x0, {marginalized_vv: x0_init})
 
-    # TODO USE CLOSED FORM SOLUTION FOR NOW
-    n, y_obs = op.temp_kwargs
-    mu_param = pytensor.graph.basic.get_var_by_name(x, "mu")[0]
-    x0 = (y_obs.sum(axis=0) - mu_param) / (n - 1)
-
     # logp(x | y, params) using laplace approx evaluated at x0
     hess = pytensor.gradient.hessian(
         log_likelihood, marginalized_vv
     )  # TODO check how stan makes this quicker
-    tau = op.Q - hess
-    mu = x0  # TODO double check with Theo
+
+    # Get Q from the list of inputs
+    Q = None
+    if isinstance(op.Q, TensorVariable):
+        for var in inputs:
+            if var.owner is not None and isinstance(var.owner.op, ValuedRV):
+                for inp in var.owner.inputs:
+                    if (
+                        inp.name is not None
+                        and inp.name == op.Q.name
+                        or inp.name == op.Q.name + "_log"
+                    ):
+                        Q = var
+                        break
+
+            if var.name is not None and var.name == op.Q.name or var.name == op.Q.name + "_log":
+                Q = var
+                break
+
+        if Q is None:
+            raise ValueError(f"No inputs could be matched to precision matrix {op.Q}: {inputs}.")
+
+    # Q is an array
+    else:
+        Q = op.Q
+
+    tau = Q - hess
+    mu = x0
     log_laplace_approx, _ = _precision_mv_normal_logp(x0, mu, tau)
 
     # logp(y | params) = logp(y | x, params) + logp(x | params) - logp(x | y, params)
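
The quantity assembled here is the Laplace approximation logp(x | y, params) ≈ log N(x | x0, (Q − H)⁻¹), with H the Hessian of the log-likelihood at the mode x0, which then enters logp(y | params) = logp(y | x0, params) + logp(x0 | params) − logp_Laplace(x0 | y, params). A plain-NumPy sketch of the precision-form Gaussian log-density, for illustration only; the actual _precision_mv_normal_logp helper operates on PyTensor variables:

import numpy as np

def precision_mvn_logp(x, mu, tau):
    """log N(x | mu, tau^{-1}) parameterised by the precision matrix tau."""
    d = x.shape[-1]
    delta = x - mu
    # The log-determinant of the precision enters with a +1/2 factor.
    _, logdet = np.linalg.slogdet(tau)
    quad = delta @ tau @ delta
    return 0.5 * (logdet - d * np.log(2 * np.pi) - quad)

Evaluated at x = mu = x0 the quadratic term vanishes, so the subtracted Laplace term reduces to (log|Q − H| − d·log 2π) / 2.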

pymc_extras/model/marginal/marginal_model.py

Lines changed: 5 additions & 8 deletions
@@ -9,7 +9,6 @@
 from arviz import InferenceData, dict_to_dataset
 from pymc.backends.arviz import coords_and_dims_for_inferencedata, dataset_to_point_list
 from pymc.distributions.discrete import Bernoulli, Categorical, DiscreteUniform
-from pymc.distributions.multivariate import MvNormal
 from pymc.distributions.transforms import Chain
 from pymc.logprob.transforms import IntervalTransform
 from pymc.model import Model
@@ -192,10 +191,10 @@ def marginalize(
         raise NotImplementedError(
             "Marginalization for DiscreteMarkovChain with non-matrix transition probability is not supported"
         )
-    elif use_laplace and not isinstance(rv_op, MvNormal):
-        raise ValueError(
-            f"Marginalisation method set to Laplace but RV {rv_to_marginalize} is not instance of MvNormal. Has distribution {rv_to_marginalize.owner.op}"
-        )
+    # elif use_laplace and not isinstance(rv_op, MvNormal):
+    #     raise ValueError(
+    #         f"Marginalisation method set to Laplace but RV {rv_to_marginalize} is not instance of MvNormal. Has distribution {rv_to_marginalize.owner.op}"
+    #     )
 
     elif not use_laplace and not isinstance(rv_op, Bernoulli | Categorical | DiscreteUniform):
         raise NotImplementedError(
@@ -587,9 +586,7 @@ def replace_marginal_subgraph(
             "You can try splitting the marginalized RV into separate components and marginalizing them separately."
         ) from e
     else:
-        dependent_rvs_dim_connections = [
-            (None,),
-        ]
+        dependent_rvs_dim_connections = None
 
     output_rvs = [rv_to_marginalize, *dependent_rvs]
     rng_updates = collect_default_updates(output_rvs, inputs=input_rvs, must_be_shared=False)
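
Putting the pieces together, the Laplace path through marginalize is driven by the keyword arguments that fit_INLA assembles in the first file of this commit; a sketch of the equivalent direct call, using the same names as there:

# Marginalize the latent field x out of the model via the Laplace approximation.
minimizer_kwargs = {"method": "L-BFGS-B", "optimizer_kwargs": {"tol": 1e-8}}
marginal_model = marginalize(
    model,
    x,
    use_laplace=True,
    Q=Q,
    minimizer_seed=42,
    minimizer_kwargs=minimizer_kwargs,
)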

tests/inference/INLA/test_inla.py

Lines changed: 0 additions & 1 deletion
@@ -97,7 +97,6 @@ def test_3_layer_normal(rng):
     idata = pmx.fit(
         method="INLA",
         x=x,
-        temp_kwargs=[n, y_obs],  # TODO REMOVE LATER - DEBUGGING TOOL
         return_latent_posteriors=False,
     )
