@@ -2354,8 +2354,59 @@ def __call__(self, W, sigma, zero_sum_stdev, size=None, **kwargs):
23542354 return super ().__call__ (W , sigma , zero_sum_stdev , size = size , ** kwargs )
23552355
23562356 @classmethod
2357- def rng_fn (cls , rng , size , W , sigma , zero_sum_stdev ):
2358- raise NotImplementedError ("Cannot sample from ICAR prior" )
2357+ def rng_fn (cls , rng , W , sigma , zero_sum_stdev , size = None ):
2358+ """Sample from the ICAR distribution.
2359+
2360+ The ICAR distribution is a special case of the CAR distribution with alpha=1.
2361+ It generates spatial random effects where neighboring areas tend to have
2362+ similar values. The precision matrix is the graph Laplacian of W.
2363+
2364+ Parameters
2365+ ----------
2366+ rng : numpy.random.Generator
2367+ Random number generator
2368+ W : ndarray
2369+ Symmetric adjacency matrix
2370+ sigma : float
2371+ Standard deviation parameter
2372+ zero_sum_stdev : float
2373+ Controls how strongly to enforce the zero-sum constraint
2374+ size : tuple, optional
2375+ Size of the samples to generate
2376+
2377+ Returns
2378+ -------
2379+ ndarray
2380+ Samples from the ICAR distribution
2381+ """
2382+ W = np .asarray (W )
2383+ N = W .shape [0 ]
2384+
2385+ # Construct the precision matrix (graph Laplacian)
2386+ D = np .diag (W .sum (axis = 1 ))
2387+ Q = D - W
2388+
2389+ # Add regularization for the zero eigenvalue based on zero_sum_stdev
2390+ zero_sum_precision = 1.0 / (zero_sum_stdev * N )** 2
2391+ Q_reg = Q + zero_sum_precision * np .ones ((N , N )) / N
2392+
2393+ # Use eigendecomposition to handle the degenerate covariance
2394+ eigvals , eigvecs = np .linalg .eigh (Q_reg )
2395+
2396+ # Construct the covariance matrix
2397+ cov = eigvecs @ np .diag (1.0 / eigvals ) @ eigvecs .T
2398+
2399+ # Scale by sigma^2
2400+ cov = sigma ** 2 * cov
2401+
2402+ # Generate samples
2403+ mean = np .zeros (N )
2404+
2405+ # Handle different size specifications
2406+ if size is None :
2407+ return rng .multivariate_normal (mean , cov )
2408+ else :
2409+ return rng .multivariate_normal (mean , cov , size = size )
23592410
23602411
23612412icar = ICARRV ()
0 commit comments