Skip to content

Commit 487dfe2

Browse files
Implement the rng function
1 parent f9057f4 commit 487dfe2

File tree

1 file changed

+53
-2
lines changed

1 file changed

+53
-2
lines changed

pymc/distributions/multivariate.py

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2354,8 +2354,59 @@ def __call__(self, W, sigma, zero_sum_stdev, size=None, **kwargs):
23542354
return super().__call__(W, sigma, zero_sum_stdev, size=size, **kwargs)
23552355

23562356
@classmethod
2357-
def rng_fn(cls, rng, size, W, sigma, zero_sum_stdev):
2358-
raise NotImplementedError("Cannot sample from ICAR prior")
2357+
def rng_fn(cls, rng, W, sigma, zero_sum_stdev, size=None):
2358+
"""Sample from the ICAR distribution.
2359+
2360+
The ICAR distribution is a special case of the CAR distribution with alpha=1.
2361+
It generates spatial random effects where neighboring areas tend to have
2362+
similar values. The precision matrix is the graph Laplacian of W.
2363+
2364+
Parameters
2365+
----------
2366+
rng : numpy.random.Generator
2367+
Random number generator
2368+
W : ndarray
2369+
Symmetric adjacency matrix
2370+
sigma : float
2371+
Standard deviation parameter
2372+
zero_sum_stdev : float
2373+
Controls how strongly to enforce the zero-sum constraint
2374+
size : tuple, optional
2375+
Size of the samples to generate
2376+
2377+
Returns
2378+
-------
2379+
ndarray
2380+
Samples from the ICAR distribution
2381+
"""
2382+
W = np.asarray(W)
2383+
N = W.shape[0]
2384+
2385+
# Construct the precision matrix (graph Laplacian)
2386+
D = np.diag(W.sum(axis=1))
2387+
Q = D - W
2388+
2389+
# Add regularization for the zero eigenvalue based on zero_sum_stdev
2390+
zero_sum_precision = 1.0 / (zero_sum_stdev * N)**2
2391+
Q_reg = Q + zero_sum_precision * np.ones((N, N)) / N
2392+
2393+
# Use eigendecomposition to handle the degenerate covariance
2394+
eigvals, eigvecs = np.linalg.eigh(Q_reg)
2395+
2396+
# Construct the covariance matrix
2397+
cov = eigvecs @ np.diag(1.0 / eigvals) @ eigvecs.T
2398+
2399+
# Scale by sigma^2
2400+
cov = sigma**2 * cov
2401+
2402+
# Generate samples
2403+
mean = np.zeros(N)
2404+
2405+
# Handle different size specifications
2406+
if size is None:
2407+
return rng.multivariate_normal(mean, cov)
2408+
else:
2409+
return rng.multivariate_normal(mean, cov, size=size)
23592410

23602411

23612412
icar = ICARRV()

0 commit comments

Comments
 (0)