Skip to content

Commit e367958

Browse files
updated inla docstring
1 parent e032c25 commit e367958

File tree

1 file changed

+42
-6
lines changed

1 file changed

+42
-6
lines changed

pymc_extras/inference/INLA/inla.py

Lines changed: 42 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,7 @@ def fit_INLA(
2424
2525
Where the prior on the hyperparameters :math:`\pi(\theta)` is arbitrary, the prior on the latent field is Gaussian (and in precision form): :math:`\pi(x) = N(\mu, Q^{-1})` and the latent field is linked to the observables $y$ through some linear map.
2626
27-
As it stands, INLA in PyMC Extras has three main limitations:
28-
29-
- Does not support inference over the latent field, only the hyperparameters.
30-
- Optimisation for :math:`\mu^*` is bottlenecked by calling `minimize`, and to a lesser extent, computing the hessian :math:`f^"(x)`.
31-
- Does not offer sparse support which can provide significant speedups.
27+
As it stands, INLA in PyMC Extras is currently experimental.
3228
3329
Parameters
3430
----------
@@ -46,11 +42,51 @@ def fit_INLA(
4642
If True, also return posteriors for the latent Gaussian field (currently unsupported).
4743
sampler_kwargs:
4844
Kwargs to pass to pm.sample.
45+
46+
Returns
47+
-------
48+
idata: az.InferenceData
49+
Standard PyMC InferenceData instance.
50+
51+
Examples
52+
--------
53+
.. code:: ipython
54+
55+
In [1]: rng = np.random.default_rng(123)
56+
...: n = 10000
57+
...: d = 3
58+
...: mu_mu = 10 * rng.random(d)
59+
...: mu_true = rng.random(d)
60+
...: tau = np.identity(d)
61+
...: cov = np.linalg.inv(tau)
62+
...: y_obs = rng.multivariate_normal(mean=mu_true, cov=cov, size=n)
63+
64+
In [2]: with pm.Model() as model:
65+
...: mu = pm.MvNormal("mu", mu=mu_mu, tau=tau)
66+
...: x = pm.MvNormal("x", mu=mu, tau=tau)
67+
...: y = pm.MvNormal("y", mu=x, tau=tau, observed=y_obs)
68+
69+
...: idata = pmx.fit(
70+
...: method="INLA",
71+
...: x=x,
72+
...: Q=tau,
73+
...: return_latent_posteriors=False,
74+
...: )
75+
76+
In[3]: posterior_mean_true = (mu_mu + mu_true) / 2
77+
...: posterior_mean_inla = idata.posterior.mu.mean(axis=(0, 1)).values
78+
...: print(posterior_mean_true)
79+
...: print(posterior_mean_inla)
80+
81+
Out[3]:
82+
[3.50394522 0.35705804 1.50784662]
83+
[3.48732847 0.35738072 1.46851421]
84+
4985
"""
5086
model = pm.modelcontext(model)
5187

5288
# Get the TensorVariable if Q is provided as an RV
53-
if Q in model.rvs_to_values.keys():
89+
if isinstance(Q, TensorVariable) and Q in model.rvs_to_values.keys():
5490
Q = model.rvs_to_values[Q]
5591

5692
# Marginalize out the latent field

0 commit comments

Comments
 (0)