34
34
sigmoid ,
35
35
)
36
36
from pytensor .tensor .blockwise import Blockwise
37
+ from pytensor .tensor .einsum import _delta
37
38
from pytensor .tensor .elemwise import DimShuffle
38
39
from pytensor .tensor .exceptions import NotScalarConstantError
39
40
from pytensor .tensor .linalg import cholesky , det , eigh , solve_triangular , trace
76
77
)
77
78
from pymc .distributions .transforms import (
78
79
CholeskyCorrTransform ,
79
- Interval ,
80
80
ZeroSumTransform ,
81
81
_default_transform ,
82
82
)
@@ -1157,12 +1157,12 @@ def _lkj_normalizing_constant(eta, n):
1157
1157
if not isinstance (n , int ):
1158
1158
raise NotImplementedError ("n must be an integer" )
1159
1159
if eta == 1 :
1160
- result = gammaln (2.0 * pt .arange (1 , int ((n - 1 ) / 2 ) + 1 )).sum ()
1160
+ result = gammaln (2.0 * pt .arange (1 , ((n - 1 ) / 2 ) + 1 )).sum ()
1161
1161
if n % 2 == 1 :
1162
1162
result += (
1163
1163
0.25 * (n ** 2 - 1 ) * pt .log (np .pi )
1164
1164
- 0.25 * (n - 1 ) ** 2 * pt .log (2.0 )
1165
- - (n - 1 ) * gammaln (int (( n + 1 ) / 2 ) )
1165
+ - (n - 1 ) * gammaln (( n + 1 ) / 2 )
1166
1166
)
1167
1167
else :
1168
1168
result += (
@@ -1504,7 +1504,7 @@ def helper_deterministics(cls, n, packed_chol):
1504
1504
1505
1505
class LKJCorrRV (SymbolicRandomVariable ):
1506
1506
name = "lkjcorr"
1507
- extended_signature = "[rng],[size],(),()->[rng],(n)"
1507
+ extended_signature = "[rng],[size],(),()->[rng],(n,n )"
1508
1508
_print_name = ("LKJCorrRV" , "\\ operatorname{LKJCorrRV}" )
1509
1509
1510
1510
def make_node (self , rng , size , n , eta ):
@@ -1532,23 +1532,13 @@ def rv_op(cls, n: int, eta, *, rng=None, size=None):
1532
1532
flat_size = pt .prod (size , dtype = "int64" )
1533
1533
1534
1534
next_rng , C = cls ._random_corr_matrix (rng = rng , n = n , eta = eta , flat_size = flat_size )
1535
-
1536
- triu_idx = pt .triu_indices (n , k = 1 )
1537
- samples = C [..., triu_idx [0 ], triu_idx [1 ]]
1538
-
1539
- if rv_size_is_none (size ):
1540
- samples = samples [0 ]
1541
- else :
1542
- dist_shape = (n * (n - 1 )) // 2
1543
- samples = pt .reshape (samples , (* size , dist_shape ))
1535
+ C = C [0 ] if rv_size_is_none (size ) else C .reshape ((* size , n , n ))
1544
1536
1545
1537
return cls (
1546
1538
inputs = [rng , size , n , eta ],
1547
- outputs = [next_rng , samples ],
1539
+ outputs = [next_rng , C ],
1548
1540
)(rng , size , n , eta )
1549
1541
1550
- return samples
1551
-
1552
1542
@classmethod
1553
1543
def _random_corr_matrix (
1554
1544
cls , rng : Variable , n : int , eta : TensorVariable , flat_size : TensorVariable
@@ -1565,6 +1555,7 @@ def _random_corr_matrix(
1565
1555
P = P [..., 0 , 1 ].set (r12 )
1566
1556
P = P [..., 1 , 1 ].set (pt .sqrt (1.0 - r12 ** 2 ))
1567
1557
n = get_underlying_scalar_constant_value (n )
1558
+
1568
1559
for mp1 in range (2 , n ):
1569
1560
beta -= 0.5
1570
1561
next_rng , y = pt .random .beta (
@@ -1577,17 +1568,10 @@ def _random_corr_matrix(
1577
1568
P = P [..., 0 :mp1 , mp1 ].set (pt .sqrt (y [..., np .newaxis ]) * z )
1578
1569
P = P [..., mp1 , mp1 ].set (pt .sqrt (1.0 - y ))
1579
1570
C = pt .einsum ("...ji,...jk->...ik" , P , P .copy ())
1580
- return next_rng , C
1581
-
1582
1571
1583
- class MultivariateIntervalTransform (Interval ):
1584
- name = "interval"
1585
-
1586
- def log_jac_det (self , * args ):
1587
- return super ().log_jac_det (* args ).sum (- 1 )
1572
+ return next_rng , C
1588
1573
1589
1574
1590
- # Returns list of upper triangular values
1591
1575
class _LKJCorr (BoundedContinuous ):
1592
1576
rv_type = LKJCorrRV
1593
1577
rv_op = LKJCorrRV .rv_op
@@ -1598,10 +1582,15 @@ def dist(cls, n, eta, **kwargs):
1598
1582
eta = pt .as_tensor_variable (eta )
1599
1583
return super ().dist ([n , eta ], ** kwargs )
1600
1584
1601
- def support_point (rv , * args ):
1602
- return pt .zeros_like (rv )
1585
+ @staticmethod
1586
+ def support_point (rv : TensorVariable , * args ):
1587
+ ndim = rv .ndim
1603
1588
1604
- def logp (value , n , eta ):
1589
+ # Batched identity matrix
1590
+ return _delta (rv .shape , (ndim - 2 , ndim - 1 )).astype (int )
1591
+
1592
+ @staticmethod
1593
+ def logp (value : TensorVariable , n , eta ):
1605
1594
"""
1606
1595
Calculate logp of LKJ distribution at specified value.
1607
1596
@@ -1614,31 +1603,20 @@ def logp(value, n, eta):
1614
1603
-------
1615
1604
TensorVariable
1616
1605
"""
1617
- if value .ndim > 1 :
1618
- raise NotImplementedError ("LKJCorr logp is only implemented for vector values (ndim=1)" )
1619
-
1620
- # TODO: PyTensor does not have a `triu_indices`, so we can only work with constant
1621
- # n (or else find a different expression)
1606
+ # TODO: _lkj_normalizing_constant currently requires `eta` and `n` to be constants
1622
1607
try :
1623
1608
n = int (get_underlying_scalar_constant_value (n ))
1624
1609
except NotScalarConstantError :
1625
1610
raise NotImplementedError ("logp only implemented for constant `n`" )
1626
1611
1627
- shape = n * (n - 1 ) // 2
1628
- tri_index = np .zeros ((n , n ), dtype = "int32" )
1629
- tri_index [np .triu_indices (n , k = 1 )] = np .arange (shape )
1630
- tri_index [np .triu_indices (n , k = 1 )[::- 1 ]] = np .arange (shape )
1631
-
1632
- value = pt .take (value , tri_index )
1633
- value = pt .fill_diagonal (value , 1 )
1634
-
1635
- # TODO: _lkj_normalizing_constant currently requires `eta` and `n` to be constants
1636
1612
try :
1637
1613
eta = float (get_underlying_scalar_constant_value (eta ))
1638
1614
except NotScalarConstantError :
1639
1615
raise NotImplementedError ("logp only implemented for constant `eta`" )
1616
+
1640
1617
result = _lkj_normalizing_constant (eta , n )
1641
1618
result += (eta - 1.0 ) * pt .log (det (value ))
1619
+
1642
1620
return check_parameters (
1643
1621
result ,
1644
1622
value >= - 1 ,
@@ -1675,10 +1653,6 @@ class LKJCorr:
1675
1653
The shape parameter (eta > 0) of the LKJ distribution. eta = 1
1676
1654
implies a uniform distribution of the correlation matrices;
1677
1655
larger values put more weight on matrices with weak correlations (closer to the identity matrix).
1678
- return_matrix : bool, default=False
1679
- If True, returns the full correlation matrix.
1680
- If False, returns only the values of the upper triangular matrix excluding the
1681
- diagonal in a single vector of length n(n-1)/2 for memory efficiency
1682
1656
1683
1657
Notes
1684
1658
-----
@@ -1693,7 +1667,7 @@ class LKJCorr:
1693
1667
# Define the vector of fixed standard deviations
1694
1668
sds = 3 * np.ones(10)
1695
1669
1696
- corr = pm.LKJCorr("corr", eta=4, n=10, return_matrix=True )
1670
+ corr = pm.LKJCorr("corr", eta=4, n=10)
1697
1671
1698
1672
# Define a new MvNormal with the given correlation matrix
1699
1673
vals = sds * pm.MvNormal("vals", mu=np.zeros(10), cov=corr, shape=10)
@@ -1703,10 +1677,6 @@ class LKJCorr:
1703
1677
chol = pt.linalg.cholesky(corr)
1704
1678
vals = sds * pt.dot(chol, vals_raw)
1705
1679
1706
- # The matrix is internally still sampled as an upper triangular vector
1707
- # If you want access to it in matrix form in the trace, add
1708
- pm.Deterministic("corr_mat", corr)
1709
-
1710
1680
1711
1681
References
1712
1682
----------
@@ -1716,26 +1686,28 @@ class LKJCorr:
1716
1686
100(9), pp.1989-2001.
1717
1687
"""
1718
1688
1719
- def __new__ (cls , name , n , eta , * , return_matrix = False , ** kwargs ):
1720
- c_vec = _LKJCorr (name , eta = eta , n = n , ** kwargs )
1721
- if not return_matrix :
1722
- return c_vec
1723
- else :
1724
- return cls .vec_to_corr_mat (c_vec , n )
1725
-
1726
- @classmethod
1727
- def dist (cls , n , eta , * , return_matrix = False , ** kwargs ):
1728
- c_vec = _LKJCorr .dist (eta = eta , n = n , ** kwargs )
1729
- if not return_matrix :
1730
- return c_vec
1731
- else :
1732
- return cls .vec_to_corr_mat (c_vec , n )
1689
+ def __new__ (cls , name , n , eta , ** kwargs ):
1690
+ return_matrix = kwargs .pop ("return_matrix" , None )
1691
+ if return_matrix is not None :
1692
+ warnings .warn (
1693
+ "The `return_matrix` argument is deprecated and has no effect. "
1694
+ "LKJCorr always returns the correlation matrix." ,
1695
+ DeprecationWarning ,
1696
+ stacklevel = 2 ,
1697
+ )
1698
+ return _LKJCorr (name , eta = eta , n = n , ** kwargs )
1733
1699
1734
1700
@classmethod
1735
- def vec_to_corr_mat (cls , vec , n ):
1736
- tri = pt .zeros (pt .concatenate ([vec .shape [:- 1 ], (n , n )]))
1737
- tri = pt .subtensor .set_subtensor (tri [(..., * np .triu_indices (n , 1 ))], vec )
1738
- return tri + pt .moveaxis (tri , - 2 , - 1 ) + pt .diag (pt .ones (n ))
1701
+ def dist (cls , n , eta , ** kwargs ):
1702
+ return_matrix = kwargs .pop ("return_matrix" , None )
1703
+ if return_matrix is not None :
1704
+ warnings .warn (
1705
+ "The `return_matrix` argument is deprecated and has no effect. "
1706
+ "LKJCorr always returns the correlation matrix." ,
1707
+ DeprecationWarning ,
1708
+ stacklevel = 2 ,
1709
+ )
1710
+ return _LKJCorr .dist (eta = eta , n = n , ** kwargs )
1739
1711
1740
1712
1741
1713
class MatrixNormalRV (RandomVariable ):
0 commit comments