@@ -2133,6 +2133,7 @@ def logp(value, rng, size, mu, sigma, *covs):
21332133 return a
21342134
21352135
2136+ # TODO: change this code
21362137class CARRV (RandomVariable ):
21372138 name = "car"
21382139 signature = "(m),(m,m),(),(),()->(m)"
@@ -2344,21 +2345,56 @@ def logp(value, mu, W, alpha, tau, W_is_valid):
23442345 )
23452346
23462347
2347- class ICARRV (RandomVariable ):
2348+ class ICARRV (SymbolicMVNormalUsedInternally ):
2349+ r"""A SymbolicRandomVariable representing an Intrinsic Conditional Autoregressive (ICAR) distribution.
2350+
2351+ This class contains the symbolic logic for the ICAR distribution, which is used by the
2352+ user-facing `pm.ICAR` class to generate random samples and compute log-probabilities.
2353+ """
2354+
23482355 name = "icar"
2349- signature = "(m,m),(),()->(m)"
2350- dtype = "floatX"
2356+ extended_signature = "[rng],[size],(n,n),(),()->[rng],(n)"
23512357 _print_name = ("ICAR" , "\\ operatorname{ICAR}" )
23522358
2353- def __call__ (self , W , sigma , zero_sum_stdev , size = None , ** kwargs ):
2354- return super ().__call__ (W , sigma , zero_sum_stdev , size = size , ** kwargs )
2355-
23562359 @classmethod
2357- def rng_fn (cls , rng , size , W , sigma , zero_sum_stdev ):
2358- raise NotImplementedError ("Cannot sample from ICAR prior" )
2360+ def rv_op (cls , W , sigma , zero_sum_stdev , method = "eigh" , rng = None , size = None ):
2361+ W = pt .as_tensor (W )
2362+ sigma = pt .as_tensor (sigma )
2363+ zero_sum_stdev = pt .as_tensor (zero_sum_stdev )
2364+ rng = normalize_rng_param (rng )
2365+ size = normalize_size_param (size )
2366+
2367+ if rv_size_is_none (size ):
2368+ size = implicit_size_from_params (
2369+ W , sigma , zero_sum_stdev , ndims_params = cls .ndims_params
2370+ )
2371+
2372+ N = W .shape [0 ]
2373+
2374+ # Construct the precision matrix (graph Laplacian)
2375+ D = pt .diag (W .sum (axis = 1 ))
2376+ Q = (D - W ) / (sigma * sigma )
2377+
2378+ # Add regularization for the zero eigenvalue based on zero_sum_stdev
2379+ zero_sum_precision = 1.0 / (zero_sum_stdev * zero_sum_stdev )
2380+ Q_reg = Q + zero_sum_precision * pt .ones ((N , N )) / N
2381+
2382+ # Convert precision to covariance matrix
2383+ cov = pt .linalg .inv (Q_reg ) # TODO: Should this be matrix_inverse(Q_reg)
23592384
2385+ next_rng , mv_draws = multivariate_normal (
2386+ mean = pt .zeros (N ),
2387+ cov = cov ,
2388+ size = size ,
2389+ rng = rng ,
2390+ method = method ,
2391+ ).owner .outputs
23602392
2361- icar = ICARRV ()
2393+ return cls (
2394+ inputs = [rng , size , W , sigma , zero_sum_stdev ],
2395+ outputs = [next_rng , mv_draws ],
2396+ method = method ,
2397+ )(rng , size , W , sigma , zero_sum_stdev )
23622398
23632399
23642400class ICAR (Continuous ):
@@ -2449,7 +2485,8 @@ class ICAR(Continuous):
24492485
24502486 """
24512487
2452- rv_op = icar
2488+ rv_type = ICARRV
2489+ rv_op = ICARRV .rv_op
24532490
24542491 @classmethod
24552492 def dist (cls , W , sigma = 1 , zero_sum_stdev = 0.001 , ** kwargs ):
0 commit comments