     sigmoid,
 )
 from pytensor.tensor.blockwise import Blockwise
-from pytensor.tensor.einsum import _delta
 from pytensor.tensor.elemwise import DimShuffle
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.linalg import cholesky, det, eigh, solve_triangular, trace
@@ -1213,12 +1212,8 @@ def rv_op(cls, n, eta, sd_dist, *, size=None):
         D = sd_dist.type(name="D")  # Make sd_dist opaque to OpFromGraph
         size = D.shape[:-1]
 
-        # We flatten the size to make operations easier, and then rebuild it
-        flat_size = pt.prod(size, dtype="int64")
-
-        next_rng, C = LKJCorrRV._random_corr_matrix(rng=rng, n=n, eta=eta, flat_size=flat_size)
-        D_matrix = D.reshape((flat_size, n))
-        C *= D_matrix[..., :, None] * D_matrix[..., None, :]
+        next_rng, C = LKJCorrRV._random_corr_matrix(rng=rng, n=n, eta=eta, size=size)
+        C *= D[..., :, None] * D[..., None, :]
 
         tril_idx = pt.tril_indices(n, k=0)
         samples = pt.linalg.cholesky(C)[..., tril_idx[0], tril_idx[1]]
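The scaling step above now relies on broadcasting instead of reshaping through a flattened batch dimension: D[..., :, None] * D[..., None, :] is the batched outer product of the standard deviations, so multiplying it into C turns each correlation matrix into a covariance matrix, i.e. diag(D) @ C @ diag(D). A minimal NumPy sketch of the same pattern, with illustrative shapes that are not part of the diff:

# Sketch (NumPy, illustrative): the broadcasted outer product scales a
# correlation matrix into a covariance matrix without any reshape.
import numpy as np

batch, n = (2, 3), 4
rng = np.random.default_rng(0)
D = rng.uniform(0.5, 2.0, size=(*batch, n))            # standard deviations
C = np.broadcast_to(np.eye(n), (*batch, n, n)).copy()  # stand-in correlation matrices
C *= D[..., :, None] * D[..., None, :]                 # diag(D) @ C @ diag(D), batched
assert np.allclose(np.diagonal(C, axis1=-2, axis2=-1), D**2)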
@@ -1520,53 +1515,52 @@ def make_node(self, rng, size, n, eta):
 
     @classmethod
     def rv_op(cls, n: int, eta, *, rng=None, size=None):
-        # We flatten the size to make operations easier, and then rebuild it
+        # HACK: normalize_size_param doesn't handle size=() properly
+        if not size:
+            size = None
+
         n = pt.as_tensor(n, ndim=0, dtype=int)
         eta = pt.as_tensor(eta, ndim=0)
         rng = normalize_rng_param(rng)
         size = normalize_size_param(size)
 
-        if rv_size_is_none(size):
-            flat_size = 1
-        else:
-            flat_size = pt.prod(size, dtype="int64")
+        next_rng, C = cls._random_corr_matrix(rng=rng, n=n, eta=eta, size=size)
 
-        next_rng, C = cls._random_corr_matrix(rng=rng, n=n, eta=eta, flat_size=flat_size)
-        C = C[0] if rv_size_is_none(size) else C.reshape((*size, n, n))
-
-        return cls(
-            inputs=[rng, size, n, eta],
-            outputs=[next_rng, C],
-        )(rng, size, n, eta)
+        return cls(inputs=[rng, size, n, eta], outputs=[next_rng, C])(rng, size, n, eta)
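The guard at the top of rv_op leans on the fact that an empty tuple is falsy in Python, so size=() and size=None take the same path into normalize_size_param (per the HACK comment in the diff). A quick plain-Python illustration:

# Illustration (plain Python): () and None are both falsy, so the guard
# "if not size: size = None" maps an empty tuple to None before normalization.
for size in [None, (), (5,), (2, 3)]:
    print(repr(size), "->", repr(None if not size else size))
# None -> None, () -> None, (5,) -> (5,), (2, 3) -> (2, 3)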
 
     @classmethod
     def _random_corr_matrix(
-        cls, rng: Variable, n: int, eta: TensorVariable, flat_size: TensorVariable
+        cls, rng: Variable, n: int, eta: TensorVariable, size: TensorVariable
     ) -> tuple[Variable, TensorVariable]:
         # original implementation in R see:
         # https://github.com/rmcelreath/rethinking/blob/master/R/distributions.r
+        size = () if rv_size_is_none(size) else size
 
         beta = eta - 1.0 + n / 2.0
-        next_rng, beta_rvs = pt.random.beta(
-            alpha=beta, beta=beta, size=flat_size, rng=rng
-        ).owner.outputs
+        next_rng, beta_rvs = pt.random.beta(alpha=beta, beta=beta, size=size, rng=rng).owner.outputs
         r12 = 2.0 * beta_rvs - 1.0
-        P = pt.full((flat_size, n, n), pt.eye(n))
+
+        P = pt.full((*size, n, n), pt.eye(n))
         P = P[..., 0, 1].set(r12)
         P = P[..., 1, 1].set(pt.sqrt(1.0 - r12**2))
         n = get_underlying_scalar_constant_value(n)
 
         for mp1 in range(2, n):
             beta -= 0.5
+
             next_rng, y = pt.random.beta(
-                alpha=mp1 / 2.0, beta=beta, size=flat_size, rng=next_rng
+                alpha=mp1 / 2.0, beta=beta, size=size, rng=next_rng
             ).owner.outputs
+
             next_rng, z = pt.random.normal(
-                loc=0, scale=1, size=(flat_size, mp1), rng=next_rng
+                loc=0, scale=1, size=(*size, mp1), rng=next_rng
             ).owner.outputs
-            z = z / pt.sqrt(pt.einsum("ij,ij->i", z, z.copy()))[..., np.newaxis]
+
+            ein_sig_z = "i, i->" if z.ndim == 1 else "...ij, ...ij->...i"
+            z = z / pt.sqrt(pt.einsum(ein_sig_z, z, z.copy()))[..., np.newaxis]
             P = P[..., 0:mp1, mp1].set(pt.sqrt(y[..., np.newaxis]) * z)
             P = P[..., mp1, mp1].set(pt.sqrt(1.0 - y))
+
         C = pt.einsum("...ji,...jk->...ik", P, P.copy())
 
         return next_rng, C
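Two einsum signatures carry the batching in _random_corr_matrix: "...ij, ...ij->...i" computes squared row norms over the trailing axis for every batch element at once (used to normalize z onto the unit sphere), and "...ji,...jk->...ik" is a batched P.T @ P, the onion-method assembly of the correlation matrix. A NumPy sketch of both patterns, with illustrative shapes:

# Sketch (NumPy, illustrative): the two batched einsum patterns from the diff.
import numpy as np

rng = np.random.default_rng(1)

# "...ij, ...ij->...i": squared norms over the trailing axis, per batch element.
z = rng.normal(size=(2, 3, 5))
z_unit = z / np.sqrt(np.einsum("...ij, ...ij->...i", z, z))[..., np.newaxis]
assert np.allclose(np.linalg.norm(z_unit, axis=-1), 1.0)

# "...ji,...jk->...ik": batched transpose-times-self, i.e. P.T @ P.
P = rng.normal(size=(2, 3, 4, 4))
C = np.einsum("...ji,...jk->...ik", P, P)
assert np.allclose(C, np.swapaxes(P, -1, -2) @ P)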
@@ -1584,10 +1578,7 @@ def dist(cls, n, eta, **kwargs):
 
     @staticmethod
     def support_point(rv: TensorVariable, *args):
-        ndim = rv.ndim
-
-        # Batched identity matrix
-        return _delta(rv.shape, (ndim - 2, ndim - 1)).astype(int)
+        return pt.broadcast_to(pt.eye(rv.shape[-1]), rv.shape)
 
     @staticmethod
     def logp(value: TensorVariable, n, eta):
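support_point no longer needs the private _delta helper removed from the imports at the top of this diff: broadcasting an identity matrix against the variable's shape produces the same batched identity. The equivalent NumPy sketch, with an illustrative shape:

# Sketch (NumPy, illustrative): broadcast eye(n) to a batched shape, mirroring
# pt.broadcast_to(pt.eye(rv.shape[-1]), rv.shape).
import numpy as np

shape = (2, 5, 3, 3)  # leading batch dims, trailing (n, n) matrices
support = np.broadcast_to(np.eye(shape[-1]), shape)
assert support.shape == shape
assert np.allclose(support[1, 4], np.eye(3))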