@@ -2156,56 +2156,56 @@ def make_node(self, rng, size, mu, W, alpha, tau, W_is_valid):
21562156 return super ().make_node (rng , size , mu , W , alpha , tau , W_is_valid )
21572157
21582158 @classmethod
2159- def rng_fn (cls , rng : np .random .RandomState , mu , W , alpha , tau , W_is_valid , size ):
2160- """Sample a numeric random variate.
2161-
2162- Implementation of algorithm from paper
2163- Havard Rue, 2001. "Fast sampling of Gaussian Markov random fields,"
2164- Journal of the Royal Statistical Society Series B, Royal Statistical Society,
2165- vol. 63(2), pages 325-338. DOI: 10.1111/1467-9868.00288.
2159+ def rng_fn (cls , rng , mu , W , alpha , tau , W_is_valid , size = None ):
2160+ """Sample from the CAR distribution.
2161+
2162+ Parameters
2163+ ----------
2164+ rng : numpy.random.Generator
2165+ Random number generator
2166+ mu : ndarray
2167+ Mean vector
2168+ W : ndarray
2169+ Symmetric adjacency matrix
2170+ alpha : float
2171+ Autoregression parameter (-1 < alpha < 1)
2172+ tau : float
2173+ Precision parameter (tau > 0)
2174+ W_is_valid : bool
2175+ Flag indicating whether W is a valid adjacency matrix
2176+ size : tuple, optional
2177+ Size of the samples to generate
2178+
2179+ Returns
2180+ -------
2181+ ndarray
2182+ Samples from the CAR distribution
21662183 """
2167- if not W_is_valid . all () :
2168- raise ValueError ("W must be a valid adjacency matrix" )
2184+ if not W_is_valid :
2185+ raise ValueError ("W must be a symmetric adjacency matrix" )
21692186
21702187 if np .any (alpha >= 1 ) or np .any (alpha <= - 1 ):
21712188 raise ValueError ("the domain of alpha is: -1 < alpha < 1" )
2172-
2173- # TODO: If there are batch dims, even if W was already sparse,
2174- # we will have some expensive dense_from_sparse and sparse_from_dense
2175- # operations that we should avoid. See https://github.com/pymc-devs/pytensor/issues/839
2176- W = _squeeze_to_ndim (W , 2 )
2177- if not scipy .sparse .issparse (W ):
2178- W = scipy .sparse .csr_matrix (W )
2179- tau = scipy .sparse .csr_matrix (_squeeze_to_ndim (tau , 0 ))
2180- alpha = scipy .sparse .csr_matrix (_squeeze_to_ndim (alpha , 0 ))
2181-
2182- s = np .asarray (W .sum (axis = 0 ))[0 ]
2183- D = scipy .sparse .diags (s )
2184-
2185- Q = tau .multiply (D - alpha .multiply (W ))
2186-
2187- perm_array = scipy .sparse .csgraph .reverse_cuthill_mckee (Q , symmetric_mode = True )
2188- inv_perm = np .argsort (perm_array )
2189-
2190- Q = Q [perm_array , :][:, perm_array ]
2191-
2192- Qb = Q .diagonal ()
2193- u = 1
2194- while np .count_nonzero (Q .diagonal (u )) > 0 :
2195- Qb = np .vstack ((np .pad (Q .diagonal (u ), (u , 0 ), constant_values = (0 , 0 )), Qb ))
2196- u += 1
2197-
2198- L = scipy .linalg .cholesky_banded (Qb , lower = False )
2199-
2200- size = tuple (size or ())
2201- if size :
2202- mu = np .broadcast_to (mu , (* size , mu .shape [- 1 ]))
2203- z = rng .normal (size = mu .shape )
2204- samples = np .empty (z .shape )
2205- for idx in np .ndindex (mu .shape [:- 1 ]):
2206- samples [idx ] = scipy .linalg .cho_solve_banded ((L , False ), z [idx ]) + mu [idx ][perm_array ]
2207- samples = samples [..., inv_perm ]
2208- return samples
2189+
2190+ W = np .asarray (W )
2191+ N = W .shape [0 ]
2192+
2193+ # Construct the precision matrix
2194+ D = np .diag (W .sum (axis = 1 ))
2195+ Q = tau * (D - alpha * W )
2196+
2197+ # Convert precision to covariance matrix
2198+ cov = np .linalg .inv (Q )
2199+
2200+ # Generate samples using multivariate_normal with covariance matrix
2201+ mean = np .zeros (N ) if mu is None else np .asarray (mu )
2202+
2203+ return stats .multivariate_normal .rvs (
2204+ mean = mean ,
2205+ cov = cov ,
2206+ size = size ,
2207+ random_state = rng
2208+ )
22092209
22102210
22112211car = CARRV ()
@@ -2342,6 +2342,8 @@ def logp(value, mu, W, alpha, tau, W_is_valid):
23422342 W_is_valid ,
23432343 msg = "-1 < alpha < 1, tau > 0, W is a symmetric adjacency matrix." ,
23442344 )
2345+
2346+
23452347
23462348
23472349class ICARRV (RandomVariable ):
@@ -2355,58 +2357,49 @@ def __call__(self, W, sigma, zero_sum_stdev, size=None, **kwargs):
23552357
23562358 @classmethod
23572359 def rng_fn (cls , rng , W , sigma , zero_sum_stdev , size = None ):
2358- """Sample from the ICAR distribution.
2359-
2360- The ICAR distribution is a special case of the CAR distribution with alpha=1.
2361- It generates spatial random effects where neighboring areas tend to have
2362- similar values. The precision matrix is the graph Laplacian of W.
2363-
2364- Parameters
2365- ----------
2366- rng : numpy.random.Generator
2367- Random number generator
2368- W : ndarray
2369- Symmetric adjacency matrix
2370- sigma : float
2371- Standard deviation parameter
2372- zero_sum_stdev : float
2373- Controls how strongly to enforce the zero-sum constraint
2374- size : tuple, optional
2375- Size of the samples to generate
2376-
2377- Returns
2378- -------
2379- ndarray
2380- Samples from the ICAR distribution
2381- """
2382- W = np .asarray (W )
2383- N = W .shape [0 ]
2384-
2385- # Construct the precision matrix (graph Laplacian)
2386- D = np .diag (W .sum (axis = 1 ))
2387- Q = D - W
2388-
2389- # Add regularization for the zero eigenvalue based on zero_sum_stdev
2390- zero_sum_precision = 1.0 / (zero_sum_stdev * N )** 2
2391- Q_reg = Q + zero_sum_precision * np .ones ((N , N )) / N
2392-
2393- # Use eigendecomposition to handle the degenerate covariance
2394- eigvals , eigvecs = np .linalg .eigh (Q_reg )
2395-
2396- # Construct the covariance matrix
2397- cov = eigvecs @ np .diag (1.0 / eigvals ) @ eigvecs .T
2398-
2399- # Scale by sigma^2
2400- cov = sigma ** 2 * cov
2401-
2402- # Generate samples
2403- mean = np .zeros (N )
2404-
2405- # Handle different size specifications
2406- if size is None :
2407- return rng .multivariate_normal (mean , cov )
2408- else :
2409- return rng .multivariate_normal (mean , cov , size = size )
2360+ """Sample from the ICAR distribution.
2361+
2362+ Parameters
2363+ ----------
2364+ rng : numpy.random.Generator
2365+ Random number generator
2366+ W : ndarray
2367+ Symmetric adjacency matrix
2368+ sigma : float
2369+ Standard deviation parameter
2370+ zero_sum_stdev : float
2371+ Controls how strongly to enforce the zero-sum constraint
2372+ size : tuple, optional
2373+ Size of the samples to generate
2374+
2375+ Returns
2376+ -------
2377+ ndarray
2378+ Samples from the ICAR distribution
2379+ """
2380+ W = np .asarray (W )
2381+ N = W .shape [0 ]
2382+
2383+ # Construct the precision matrix (graph Laplacian)
2384+ D = np .diag (W .sum (axis = 1 ))
2385+ Q = D - W
2386+
2387+ # Add regularization for the zero eigenvalue based on zero_sum_stdev
2388+ zero_sum_precision = 1.0 / (zero_sum_stdev * N )** 2
2389+ Q_reg = Q + zero_sum_precision * np .ones ((N , N )) / N
2390+
2391+ # Convert precision to covariance matrix
2392+ cov = np .linalg .inv (Q_reg )
2393+
2394+ # Generate samples using multivariate_normal with covariance matrix
2395+ mean = np .zeros (N )
2396+
2397+ return sigma * stats .multivariate_normal .rvs (
2398+ mean = mean ,
2399+ cov = cov ,
2400+ size = size ,
2401+ random_state = rng
2402+ )
24102403
24112404
24122405icar = ICARRV ()
0 commit comments