@@ -1510,7 +1510,7 @@ def helper_deterministics(cls, n, packed_chol):
15101510
15111511class LKJCorrRV (RandomVariable ):
15121512 name = "lkjcorr"
1513- signature = "(),()->(n)"
1513+ signature = "(),()->(n,n )"
15141514 dtype = "floatX"
15151515 _print_name = ("LKJCorrRV" , "\\ operatorname{LKJCorrRV}" )
15161516
@@ -1527,8 +1527,8 @@ def make_node(self, rng, size, n, eta):
15271527
15281528 def _supp_shape_from_params (self , dist_params , ** kwargs ):
15291529 n = dist_params [0 ].squeeze ()
1530- dist_shape = ((n * (n - 1 )) // 2 ,)
1531- return dist_shape
1530+ # dist_shape = ((n * (n - 1)) // 2,)
1531+ return ( n , n )
15321532
15331533 @classmethod
15341534 def rng_fn (cls , rng , n , eta , size ):
@@ -1609,23 +1609,26 @@ def logp(value, n, eta):
16091609 -------
16101610 TensorVariable
16111611 """
1612- if value .ndim > 1 :
1613- raise NotImplementedError ("LKJCorr logp is only implemented for vector values (ndim=1)" )
1614-
1615- # TODO: PyTensor does not have a `triu_indices`, so we can only work with constant
1616- # n (or else find a different expression)
1612+ # if value.ndim > 1:
1613+ # raise NotImplementedError("LKJCorr logp is only implemented for vector values (ndim=1)")
1614+ #
16171615 try :
16181616 n = int (get_underlying_scalar_constant_value (n ))
16191617 except NotScalarConstantError :
16201618 raise NotImplementedError ("logp only implemented for constant `n`" )
16211619
1622- shape = n * (n - 1 ) // 2
1623- tri_index = np .zeros ((n , n ), dtype = "int32" )
1624- tri_index [np .triu_indices (n , k = 1 )] = np .arange (shape )
1625- tri_index [np .triu_indices (n , k = 1 )[::- 1 ]] = np .arange (shape )
1620+ # shape = n * (n - 1) // 2
1621+ # tri_index = np.zeros((n, n), dtype="int32")
1622+ # tri_index[np.triu_indices(n, k=1)] = np.arange(shape)
1623+ # tri_index[np.triu_indices(n, k=1)[::-1]] = np.arange(shape)
1624+
1625+ # value = pt.take(value, tri_index)
1626+ # value = pt.fill_diagonal(value, 1)
16261627
1627- value = pt .take (value , tri_index )
1628- value = pt .fill_diagonal (value , 1 )
1628+ # print(n, type(n))
1629+ # print(value.type.shape)
1630+ # value = value @ value.T
1631+ # print(value.type.shape)
16291632
16301633 # TODO: _lkj_normalizing_constant currently requires `eta` and `n` to be constants
16311634 try :
0 commit comments