     sigmoid,
 )
 from pytensor.tensor.blockwise import Blockwise
-from pytensor.tensor.einsum import _delta
 from pytensor.tensor.elemwise import DimShuffle
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.linalg import cholesky, det, eigh, solve_triangular, trace
@@ -1520,53 +1519,50 @@ def make_node(self, rng, size, n, eta):
 
     @classmethod
     def rv_op(cls, n: int, eta, *, rng=None, size=None):
-        # We flatten the size to make operations easier, and then rebuild it
         n = pt.as_tensor(n, ndim=0, dtype=int)
         eta = pt.as_tensor(eta, ndim=0)
         rng = normalize_rng_param(rng)
         size = normalize_size_param(size)
 
-        if rv_size_is_none(size):
-            flat_size = 1
-        else:
-            flat_size = pt.prod(size, dtype="int64")
+        next_rng, C = cls._random_corr_matrix(rng=rng, n=n, eta=eta, size=size)
 
-        next_rng, C = cls._random_corr_matrix(rng=rng, n=n, eta=eta, flat_size=flat_size)
-        C = C[0] if rv_size_is_none(size) else C.reshape((*size, n, n))
-
-        return cls(
-            inputs=[rng, size, n, eta],
-            outputs=[next_rng, C],
-        )(rng, size, n, eta)
+        return cls(inputs=[rng, size, n, eta], outputs=[next_rng, C])(rng, size, n, eta)
 
     @classmethod
     def _random_corr_matrix(
-        cls, rng: Variable, n: int, eta: TensorVariable, flat_size: TensorVariable
+        cls, rng: Variable, n: int, eta: TensorVariable, size: TensorVariable
     ) -> tuple[Variable, TensorVariable]:
         # original implementation in R see:
         # https://github.com/rmcelreath/rethinking/blob/master/R/distributions.r
 
+        size_is_none = rv_size_is_none(size)
+        size = () if size_is_none else size
+        ein_sig_z = "i, i->" if size_is_none else "...ij, ...ij->...i"
+
         beta = eta - 1.0 + n / 2.0
-        next_rng, beta_rvs = pt.random.beta(
-            alpha=beta, beta=beta, size=flat_size, rng=rng
-        ).owner.outputs
+        next_rng, beta_rvs = pt.random.beta(alpha=beta, beta=beta, size=size, rng=rng).owner.outputs
         r12 = 2.0 * beta_rvs - 1.0
-        P = pt.full((flat_size, n, n), pt.eye(n))
+
+        P = pt.full((*size, n, n), pt.eye(n))
         P = P[..., 0, 1].set(r12)
         P = P[..., 1, 1].set(pt.sqrt(1.0 - r12**2))
         n = get_underlying_scalar_constant_value(n)
 
         for mp1 in range(2, n):
             beta -= 0.5
+
             next_rng, y = pt.random.beta(
-                alpha=mp1 / 2.0, beta=beta, size=flat_size, rng=next_rng
+                alpha=mp1 / 2.0, beta=beta, size=size, rng=next_rng
             ).owner.outputs
+
             next_rng, z = pt.random.normal(
-                loc=0, scale=1, size=(flat_size, mp1), rng=next_rng
+                loc=0, scale=1, size=(*size, mp1), rng=next_rng
             ).owner.outputs
-            z = z / pt.sqrt(pt.einsum("ij,ij->i", z, z.copy()))[..., np.newaxis]
+
+            z = z / pt.sqrt(pt.einsum(ein_sig_z, z, z.copy()))[..., np.newaxis]
             P = P[..., 0:mp1, mp1].set(pt.sqrt(y[..., np.newaxis]) * z)
             P = P[..., mp1, mp1].set(pt.sqrt(1.0 - y))
+
         C = pt.einsum("...ji,...jk->...ik", P, P.copy())
 
         return next_rng, C
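
Note on the `ein_sig_z` change above: the old code always flattened the batch dimensions and normalized each row of `z` with the fixed signature `"ij,ij->i"`; the new code picks a signature matching whatever batch shape `size` implies. A minimal NumPy sketch of the batched case (the batch shape and seed here are hypothetical, not from the patch):

    import numpy as np

    rng = np.random.default_rng(0)
    size, mp1 = (2, 3), 4                    # hypothetical batch shape, vector length
    z = rng.normal(size=(*size, mp1))

    # "...ij, ...ij->...i" sums z * z over the last axis for every batch entry,
    # i.e. the squared Euclidean norm of each length-mp1 row.
    sq_norms = np.einsum("...ij, ...ij->...i", z, z)
    z_unit = z / np.sqrt(sq_norms)[..., np.newaxis]

    assert np.allclose(np.linalg.norm(z_unit, axis=-1), 1.0)

The unbatched signature `"i, i->"` is the same contraction for a single vector, yielding a scalar squared norm.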
@@ -1584,10 +1580,7 @@ def dist(cls, n, eta, **kwargs):
 
     @staticmethod
     def support_point(rv: TensorVariable, *args):
-        ndim = rv.ndim
-
-        # Batched identity matrix
-        return _delta(rv.shape, (ndim - 2, ndim - 1)).astype(int)
+        return pt.broadcast_to(pt.eye(rv.shape[-1]), rv.shape)
 
     @staticmethod
     def logp(value: TensorVariable, n, eta):
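
Note on the `support_point` change: instead of building a batched identity via the private `_delta` helper, the new code broadcasts a single n-by-n identity across the batch dimensions. A NumPy sketch of the equivalent operation (shapes are hypothetical):

    import numpy as np

    batch_shape, n = (5,), 3                 # hypothetical batch shape and matrix size
    rv_shape = (*batch_shape, n, n)

    # Equivalent of pt.broadcast_to(pt.eye(rv.shape[-1]), rv.shape):
    # one identity matrix per batch entry, without materializing copies.
    point = np.broadcast_to(np.eye(n), rv_shape)

    assert point.shape == rv_shape
    assert np.allclose(point, np.eye(n))     # every batch entry is I_n

One behavioral difference worth noting: the old `_delta(...).astype(int)` returned an integer array, while `pt.eye` defaults to a float dtype.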