Skip to content

Commit 3fe3550

Browse files
Modified the rng_fn and added the test
1 parent 487dfe2 commit 3fe3550

File tree

2 files changed

+157
-104
lines changed

2 files changed

+157
-104
lines changed

pymc/distributions/multivariate.py

Lines changed: 91 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -2156,56 +2156,56 @@ def make_node(self, rng, size, mu, W, alpha, tau, W_is_valid):
21562156
return super().make_node(rng, size, mu, W, alpha, tau, W_is_valid)
21572157

21582158
@classmethod
2159-
def rng_fn(cls, rng: np.random.RandomState, mu, W, alpha, tau, W_is_valid, size):
2160-
"""Sample a numeric random variate.
2161-
2162-
Implementation of algorithm from paper
2163-
Havard Rue, 2001. "Fast sampling of Gaussian Markov random fields,"
2164-
Journal of the Royal Statistical Society Series B, Royal Statistical Society,
2165-
vol. 63(2), pages 325-338. DOI: 10.1111/1467-9868.00288.
2159+
def rng_fn(cls, rng, mu, W, alpha, tau, W_is_valid, size=None):
2160+
"""Sample from the CAR distribution.
2161+
2162+
Parameters
2163+
----------
2164+
rng : numpy.random.Generator
2165+
Random number generator
2166+
mu : ndarray
2167+
Mean vector
2168+
W : ndarray
2169+
Symmetric adjacency matrix
2170+
alpha : float
2171+
Autoregression parameter (-1 < alpha < 1)
2172+
tau : float
2173+
Precision parameter (tau > 0)
2174+
W_is_valid : bool
2175+
Flag indicating whether W is a valid adjacency matrix
2176+
size : tuple, optional
2177+
Size of the samples to generate
2178+
2179+
Returns
2180+
-------
2181+
ndarray
2182+
Samples from the CAR distribution
21662183
"""
2167-
if not W_is_valid.all():
2168-
raise ValueError("W must be a valid adjacency matrix")
2184+
if not W_is_valid:
2185+
raise ValueError("W must be a symmetric adjacency matrix")
21692186

21702187
if np.any(alpha >= 1) or np.any(alpha <= -1):
21712188
raise ValueError("the domain of alpha is: -1 < alpha < 1")
2172-
2173-
# TODO: If there are batch dims, even if W was already sparse,
2174-
# we will have some expensive dense_from_sparse and sparse_from_dense
2175-
# operations that we should avoid. See https://github.com/pymc-devs/pytensor/issues/839
2176-
W = _squeeze_to_ndim(W, 2)
2177-
if not scipy.sparse.issparse(W):
2178-
W = scipy.sparse.csr_matrix(W)
2179-
tau = scipy.sparse.csr_matrix(_squeeze_to_ndim(tau, 0))
2180-
alpha = scipy.sparse.csr_matrix(_squeeze_to_ndim(alpha, 0))
2181-
2182-
s = np.asarray(W.sum(axis=0))[0]
2183-
D = scipy.sparse.diags(s)
2184-
2185-
Q = tau.multiply(D - alpha.multiply(W))
2186-
2187-
perm_array = scipy.sparse.csgraph.reverse_cuthill_mckee(Q, symmetric_mode=True)
2188-
inv_perm = np.argsort(perm_array)
2189-
2190-
Q = Q[perm_array, :][:, perm_array]
2191-
2192-
Qb = Q.diagonal()
2193-
u = 1
2194-
while np.count_nonzero(Q.diagonal(u)) > 0:
2195-
Qb = np.vstack((np.pad(Q.diagonal(u), (u, 0), constant_values=(0, 0)), Qb))
2196-
u += 1
2197-
2198-
L = scipy.linalg.cholesky_banded(Qb, lower=False)
2199-
2200-
size = tuple(size or ())
2201-
if size:
2202-
mu = np.broadcast_to(mu, (*size, mu.shape[-1]))
2203-
z = rng.normal(size=mu.shape)
2204-
samples = np.empty(z.shape)
2205-
for idx in np.ndindex(mu.shape[:-1]):
2206-
samples[idx] = scipy.linalg.cho_solve_banded((L, False), z[idx]) + mu[idx][perm_array]
2207-
samples = samples[..., inv_perm]
2208-
return samples
2189+
2190+
W = np.asarray(W)
2191+
N = W.shape[0]
2192+
2193+
# Construct the precision matrix
2194+
D = np.diag(W.sum(axis=1))
2195+
Q = tau * (D - alpha * W)
2196+
2197+
# Convert precision to covariance matrix
2198+
cov = np.linalg.inv(Q)
2199+
2200+
# Generate samples using multivariate_normal with covariance matrix
2201+
mean = np.zeros(N) if mu is None else np.asarray(mu)
2202+
2203+
return stats.multivariate_normal.rvs(
2204+
mean=mean,
2205+
cov=cov,
2206+
size=size,
2207+
random_state=rng
2208+
)
22092209

22102210

22112211
car = CARRV()
@@ -2342,6 +2342,8 @@ def logp(value, mu, W, alpha, tau, W_is_valid):
23422342
W_is_valid,
23432343
msg="-1 < alpha < 1, tau > 0, W is a symmetric adjacency matrix.",
23442344
)
2345+
2346+
23452347

23462348

23472349
class ICARRV(RandomVariable):
@@ -2355,58 +2357,49 @@ def __call__(self, W, sigma, zero_sum_stdev, size=None, **kwargs):
23552357

23562358
@classmethod
23572359
def rng_fn(cls, rng, W, sigma, zero_sum_stdev, size=None):
2358-
"""Sample from the ICAR distribution.
2359-
2360-
The ICAR distribution is a special case of the CAR distribution with alpha=1.
2361-
It generates spatial random effects where neighboring areas tend to have
2362-
similar values. The precision matrix is the graph Laplacian of W.
2363-
2364-
Parameters
2365-
----------
2366-
rng : numpy.random.Generator
2367-
Random number generator
2368-
W : ndarray
2369-
Symmetric adjacency matrix
2370-
sigma : float
2371-
Standard deviation parameter
2372-
zero_sum_stdev : float
2373-
Controls how strongly to enforce the zero-sum constraint
2374-
size : tuple, optional
2375-
Size of the samples to generate
2376-
2377-
Returns
2378-
-------
2379-
ndarray
2380-
Samples from the ICAR distribution
2381-
"""
2382-
W = np.asarray(W)
2383-
N = W.shape[0]
2384-
2385-
# Construct the precision matrix (graph Laplacian)
2386-
D = np.diag(W.sum(axis=1))
2387-
Q = D - W
2388-
2389-
# Add regularization for the zero eigenvalue based on zero_sum_stdev
2390-
zero_sum_precision = 1.0 / (zero_sum_stdev * N)**2
2391-
Q_reg = Q + zero_sum_precision * np.ones((N, N)) / N
2392-
2393-
# Use eigendecomposition to handle the degenerate covariance
2394-
eigvals, eigvecs = np.linalg.eigh(Q_reg)
2395-
2396-
# Construct the covariance matrix
2397-
cov = eigvecs @ np.diag(1.0 / eigvals) @ eigvecs.T
2398-
2399-
# Scale by sigma^2
2400-
cov = sigma**2 * cov
2401-
2402-
# Generate samples
2403-
mean = np.zeros(N)
2404-
2405-
# Handle different size specifications
2406-
if size is None:
2407-
return rng.multivariate_normal(mean, cov)
2408-
else:
2409-
return rng.multivariate_normal(mean, cov, size=size)
2360+
"""Sample from the ICAR distribution.
2361+
2362+
Parameters
2363+
----------
2364+
rng : numpy.random.Generator
2365+
Random number generator
2366+
W : ndarray
2367+
Symmetric adjacency matrix
2368+
sigma : float
2369+
Standard deviation parameter
2370+
zero_sum_stdev : float
2371+
Controls how strongly to enforce the zero-sum constraint
2372+
size : tuple, optional
2373+
Size of the samples to generate
2374+
2375+
Returns
2376+
-------
2377+
ndarray
2378+
Samples from the ICAR distribution
2379+
"""
2380+
W = np.asarray(W)
2381+
N = W.shape[0]
2382+
2383+
# Construct the precision matrix (graph Laplacian)
2384+
D = np.diag(W.sum(axis=1))
2385+
Q = D - W
2386+
2387+
# Add regularization for the zero eigenvalue based on zero_sum_stdev
2388+
zero_sum_precision = 1.0 / (zero_sum_stdev * N)**2
2389+
Q_reg = Q + zero_sum_precision * np.ones((N, N)) / N
2390+
2391+
# Convert precision to covariance matrix
2392+
cov = np.linalg.inv(Q_reg)
2393+
2394+
# Generate samples using multivariate_normal with covariance matrix
2395+
mean = np.zeros(N)
2396+
2397+
return sigma * stats.multivariate_normal.rvs(
2398+
mean=mean,
2399+
cov=cov,
2400+
size=size,
2401+
random_state=rng
2402+
)
24102403

24112404

24122405
icar = ICARRV()

tests/distributions/test_multivariate.py

Lines changed: 66 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2249,6 +2249,42 @@ def check_draws_match_expected(self):
22492249
assert np.all(np.abs(draw(x, random_seed=rng) - np.array([0.5, 0, 2.0])) < 0.01)
22502250

22512251

2252+
class TestCAR:
2253+
def test_car_rng_fn(self):
2254+
"""Test the random number generator for the CAR distribution."""
2255+
# Create a simple adjacency matrix for a grid
2256+
W = np.array([
2257+
[0, 1, 0, 1],
2258+
[1, 0, 1, 0],
2259+
[0, 1, 0, 1],
2260+
[1, 0, 1, 0]
2261+
], dtype=np.int32) # Explicitly set dtype
2262+
2263+
rng = np.random.default_rng(42)
2264+
mu = np.array([1.0, 2.0, 3.0, 4.0])
2265+
alpha = 0.7
2266+
tau = 0.5
2267+
2268+
# Generate samples - use W directly instead of tensor
2269+
car_dist = pm.CAR.dist(mu=mu, W=W, alpha=alpha, tau=tau)
2270+
car_samples = np.array([draw(car_dist, random_seed=rng) for _ in range(1000)])
2271+
2272+
# Test shape
2273+
assert car_samples.shape == (1000, 4)
2274+
2275+
# Test mean
2276+
assert np.allclose(car_samples.mean(axis=0), mu, atol=0.1)
2277+
2278+
# Test covariance structure - neighbors should be more correlated
2279+
sample_corr = np.corrcoef(car_samples.T)
2280+
for i in range(4):
2281+
for j in range(4):
2282+
if i != j:
2283+
# Neighbors should have higher correlation than non-neighbors
2284+
if W[i, j] == 1:
2285+
assert sample_corr[i, j] > 0
2286+
2287+
22522288
class TestICAR(BaseTestDistributionRandom):
22532289
pymc_dist = pm.ICAR
22542290
pymc_dist_params = {"W": np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0]]), "sigma": 2}
@@ -2278,12 +2314,36 @@ def test_icar_logp(self):
22782314
).eval(), "logp inaccuracy"
22792315

22802316
def test_icar_rng_fn(self):
2281-
W = np.array([[0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1], [1, 0, 1, 0]])
2282-
2283-
RV = pm.ICAR.dist(W=W)
2284-
2285-
with pytest.raises(NotImplementedError, match="Cannot sample from ICAR prior"):
2286-
pm.draw(RV)
2317+
"""Test the random number generator for the ICAR distribution."""
2318+
# Create a simple adjacency matrix for a grid
2319+
W = np.array([
2320+
[0, 1, 0, 1],
2321+
[1, 0, 1, 0],
2322+
[0, 1, 0, 1],
2323+
[1, 0, 1, 0]
2324+
], dtype=np.int32) # Explicitly set dtype
2325+
2326+
rng = np.random.default_rng(42)
2327+
sigma = 2.0
2328+
zero_sum_stdev = 0.001
2329+
2330+
# Use W directly instead of converting to tensor
2331+
icar_dist = pm.ICAR.dist(W=W, sigma=sigma, zero_sum_stdev=zero_sum_stdev)
2332+
icar_samples = np.array([draw(icar_dist, random_seed=rng) for _ in range(1000)])
2333+
2334+
# Test shape
2335+
assert icar_samples.shape == (1000, 4)
2336+
2337+
# Test approximate zero-sum constraint
2338+
assert np.abs(icar_samples.sum(axis=1).mean()) < 0.1
2339+
2340+
# Test variance scale - expect variance ≈ sigma^2 * (N-1)/N due to constraint
2341+
var_scale = (W.shape[0] - 1) / W.shape[0] # Degrees of freedom adjustment
2342+
expected_var = sigma**2 * var_scale
2343+
observed_var = np.var(icar_samples, axis=1).mean()
2344+
2345+
# Use a more generous tolerance to account for the zero sum constraint's impact on variance
2346+
assert np.abs(observed_var - expected_var) < 2.0
22872347

22882348
@pytest.mark.parametrize(
22892349
"W,msg",

0 commit comments

Comments
 (0)