Skip to content

Commit 0fbccf4

Browse files
committed
feat #7713: implement ICARRV as SymbolicRV
1 parent 011fb35 commit 0fbccf4

File tree

2 files changed

+83
-14
lines changed

2 files changed

+83
-14
lines changed

pymc/distributions/multivariate.py

Lines changed: 47 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2133,6 +2133,7 @@ def logp(value, rng, size, mu, sigma, *covs):
21332133
return a
21342134

21352135

2136+
# TODO: change this code
21362137
class CARRV(RandomVariable):
21372138
name = "car"
21382139
signature = "(m),(m,m),(),(),()->(m)"
@@ -2344,21 +2345,56 @@ def logp(value, mu, W, alpha, tau, W_is_valid):
23442345
)
23452346

23462347

2347-
class ICARRV(RandomVariable):
2348+
class ICARRV(SymbolicMVNormalUsedInternally):
2349+
r"""A SymbolicRandomVariable representing an Intrinsic Conditional Autoregressive (ICAR) distribution.
2350+
2351+
This class contains the symbolic logic for the ICAR distribution, which is used by the
2352+
user-facing `pm.ICAR` class to generate random samples and compute log-probabilities.
2353+
"""
2354+
23482355
name = "icar"
2349-
signature = "(m,m),(),()->(m)"
2350-
dtype = "floatX"
2356+
extended_signature = "[rng],[size],(n,n),(),()->[rng],(n)"
23512357
_print_name = ("ICAR", "\\operatorname{ICAR}")
23522358

2353-
def __call__(self, W, sigma, zero_sum_stdev, size=None, **kwargs):
2354-
return super().__call__(W, sigma, zero_sum_stdev, size=size, **kwargs)
2355-
23562359
@classmethod
2357-
def rng_fn(cls, rng, size, W, sigma, zero_sum_stdev):
2358-
raise NotImplementedError("Cannot sample from ICAR prior")
2360+
def rv_op(cls, W, sigma, zero_sum_stdev, method="eigh", rng=None, size=None):
2361+
W = pt.as_tensor(W)
2362+
sigma = pt.as_tensor(sigma)
2363+
zero_sum_stdev = pt.as_tensor(zero_sum_stdev)
2364+
rng = normalize_rng_param(rng)
2365+
size = normalize_size_param(size)
2366+
2367+
if rv_size_is_none(size):
2368+
size = implicit_size_from_params(
2369+
W, sigma, zero_sum_stdev, ndims_params=cls.ndims_params
2370+
)
2371+
2372+
N = W.shape[0]
2373+
2374+
# Construct the precision matrix (graph Laplacian)
2375+
D = pt.diag(W.sum(axis=1))
2376+
Q = (D - W) / (sigma * sigma)
2377+
2378+
# Add regularization for the zero eigenvalue based on zero_sum_stdev
2379+
zero_sum_precision = 1.0 / (zero_sum_stdev * zero_sum_stdev)
2380+
Q_reg = Q + zero_sum_precision * pt.ones((N, N)) / N
2381+
2382+
# Convert precision to covariance matrix
2383+
cov = pt.linalg.inv(Q_reg) # TODO: Should this be matrix_inverse(Q_reg)
23592384

2385+
next_rng, mv_draws = multivariate_normal(
2386+
mean=pt.zeros(N),
2387+
cov=cov,
2388+
size=size,
2389+
rng=rng,
2390+
method=method,
2391+
).owner.outputs
23602392

2361-
icar = ICARRV()
2393+
return cls(
2394+
inputs=[rng, size, W, sigma, zero_sum_stdev],
2395+
outputs=[next_rng, mv_draws],
2396+
method=method,
2397+
)(rng, size, W, sigma, zero_sum_stdev)
23622398

23632399

23642400
class ICAR(Continuous):
@@ -2449,7 +2485,8 @@ class ICAR(Continuous):
24492485
24502486
"""
24512487

2452-
rv_op = icar
2488+
rv_type = ICARRV
2489+
rv_op = ICARRV.rv_op
24532490

24542491
@classmethod
24552492
def dist(cls, W, sigma=1, zero_sum_stdev=0.001, **kwargs):

tests/distributions/test_multivariate.py

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2279,12 +2279,43 @@ def test_icar_logp(self):
22792279
).eval(), "logp inaccuracy"
22802280

22812281
def test_icar_rng_fn(self):
2282-
W = np.array([[0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1], [1, 0, 1, 0]])
2282+
delta = 0.05 # limit for KS p-value
2283+
n_fails = 20 # Allows the KS fails a certain number of times
2284+
size = (100,)
2285+
2286+
W_val = np.array(
2287+
[[0.0, 1.0, 0.0, 1.0], [1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0], [1.0, 0.0, 1.0, 0.0]]
2288+
)
2289+
sigma = 2.0
2290+
zero_sum_stdev = 0.1
2291+
N = W_val.shape[0]
2292+
2293+
D = np.diag(W_val.sum(axis=1))
2294+
Q = (D - W_val) / (sigma * sigma)
2295+
zero_sum_precision = 1.0 / (zero_sum_stdev**2)
2296+
Q_reg = Q + zero_sum_precision * np.ones((N, N)) / N
2297+
cov = np.linalg.inv(Q_reg)
2298+
2299+
# TODO: Should W be a pt.tensor ?
2300+
with pm.Model():
2301+
icar = pm.ICAR("icar", W=W_val, sigma=sigma, zero_sum_stdev=zero_sum_stdev, size=size)
2302+
mn = pm.MvNormal("mn", mu=0.0, cov=cov, size=size)
2303+
# Draw n_fails samples
2304+
check = pm.sample_prior_predictive(n_fails, return_inferencedata=False, random_seed=42)
22832305

2284-
RV = pm.ICAR.dist(W=W)
2306+
p, f = delta, n_fails
2307+
while p <= delta and f > 0:
2308+
icar_smp, mn_smp = check["icar"][f - 1, :, :], check["mn"][f - 1, :, :]
2309+
p = min(
2310+
st.ks_2samp(
2311+
np.atleast_1d(icar_smp[..., idx]).flatten(),
2312+
np.atleast_1d(mn_smp[..., idx]).flatten(),
2313+
)[1]
2314+
for idx in range(icar_smp.shape[-1])
2315+
)
2316+
f -= 1
22852317

2286-
with pytest.raises(NotImplementedError, match="Cannot sample from ICAR prior"):
2287-
pm.draw(RV)
2318+
assert p > delta
22882319

22892320
@pytest.mark.parametrize(
22902321
"W,msg",
@@ -2307,6 +2338,7 @@ def test_icar_matrix_checks(self, W, msg):
23072338
pm.ICAR("phi", W=W)
23082339

23092340

2341+
# TODO: Fix this after updating the rng approach
23102342
@pytest.mark.parametrize("sparse", [True, False])
23112343
def test_car_rng_fn(sparse):
23122344
delta = 0.05 # limit for KS p-value

0 commit comments

Comments
 (0)