
Commit 27fd53d

Don't flatten batch dims
1 parent 4a0f717 commit 27fd53d

File tree

2 files changed (+32, -42 lines)


pymc/distributions/multivariate.py

Lines changed: 22 additions & 31 deletions
@@ -34,7 +34,6 @@
     sigmoid,
 )
 from pytensor.tensor.blockwise import Blockwise
-from pytensor.tensor.einsum import _delta
 from pytensor.tensor.elemwise import DimShuffle
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.linalg import cholesky, det, eigh, solve_triangular, trace
@@ -1213,12 +1212,8 @@ def rv_op(cls, n, eta, sd_dist, *, size=None):
         D = sd_dist.type(name="D")  # Make sd_dist opaque to OpFromGraph
         size = D.shape[:-1]
 
-        # We flatten the size to make operations easier, and then rebuild it
-        flat_size = pt.prod(size, dtype="int64")
-
-        next_rng, C = LKJCorrRV._random_corr_matrix(rng=rng, n=n, eta=eta, flat_size=flat_size)
-        D_matrix = D.reshape((flat_size, n))
-        C *= D_matrix[..., :, None] * D_matrix[..., None, :]
+        next_rng, C = LKJCorrRV._random_corr_matrix(rng=rng, n=n, eta=eta, size=size)
+        C *= D[..., :, None] * D[..., None, :]
 
         tril_idx = pt.tril_indices(n, k=0)
         samples = pt.linalg.cholesky(C)[..., tril_idx[0], tril_idx[1]]
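
The flatten/reshape dance can go because the outer product of the sd vector broadcasts directly against the batched correlation matrices. A minimal NumPy sketch of that broadcasting (shapes and names here are illustrative, not PyMC code):

import numpy as np

batch_shape, n = (2, 3), 4
rng = np.random.default_rng(0)
D = rng.uniform(0.5, 2.0, size=(*batch_shape, n))            # batched sd draws
C = np.broadcast_to(np.eye(n), (*batch_shape, n, n)).copy()  # stand-in for corr matrices

# D[..., :, None] * D[..., None, :] is a batched outer product, so no
# flattening to (flat_size, n) and reshaping back is needed.
cov = C * D[..., :, None] * D[..., None, :]
assert cov.shape == (*batch_shape, n, n)
assert np.allclose(cov[0, 0], np.diag(D[0, 0] ** 2))  # C = I here, so cov is diag(D**2)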
@@ -1520,53 +1515,52 @@ def make_node(self, rng, size, n, eta):
 
     @classmethod
     def rv_op(cls, n: int, eta, *, rng=None, size=None):
-        # We flatten the size to make operations easier, and then rebuild it
+        # HACK: normalize_size_param doesn't handle size=() properly
+        if not size:
+            size = None
+
         n = pt.as_tensor(n, ndim=0, dtype=int)
         eta = pt.as_tensor(eta, ndim=0)
         rng = normalize_rng_param(rng)
         size = normalize_size_param(size)
 
-        if rv_size_is_none(size):
-            flat_size = 1
-        else:
-            flat_size = pt.prod(size, dtype="int64")
+        next_rng, C = cls._random_corr_matrix(rng=rng, n=n, eta=eta, size=size)
 
-        next_rng, C = cls._random_corr_matrix(rng=rng, n=n, eta=eta, flat_size=flat_size)
-        C = C[0] if rv_size_is_none(size) else C.reshape((*size, n, n))
-
-        return cls(
-            inputs=[rng, size, n, eta],
-            outputs=[next_rng, C],
-        )(rng, size, n, eta)
+        return cls(inputs=[rng, size, n, eta], outputs=[next_rng, C])(rng, size, n, eta)
 
     @classmethod
     def _random_corr_matrix(
-        cls, rng: Variable, n: int, eta: TensorVariable, flat_size: TensorVariable
+        cls, rng: Variable, n: int, eta: TensorVariable, size: TensorVariable
     ) -> tuple[Variable, TensorVariable]:
         # original implementation in R see:
         # https://github.com/rmcelreath/rethinking/blob/master/R/distributions.r
+        size = () if rv_size_is_none(size) else size
 
         beta = eta - 1.0 + n / 2.0
-        next_rng, beta_rvs = pt.random.beta(
-            alpha=beta, beta=beta, size=flat_size, rng=rng
-        ).owner.outputs
+        next_rng, beta_rvs = pt.random.beta(alpha=beta, beta=beta, size=size, rng=rng).owner.outputs
         r12 = 2.0 * beta_rvs - 1.0
-        P = pt.full((flat_size, n, n), pt.eye(n))
+
+        P = pt.full((*size, n, n), pt.eye(n))
         P = P[..., 0, 1].set(r12)
         P = P[..., 1, 1].set(pt.sqrt(1.0 - r12**2))
         n = get_underlying_scalar_constant_value(n)
 
         for mp1 in range(2, n):
             beta -= 0.5
+
             next_rng, y = pt.random.beta(
-                alpha=mp1 / 2.0, beta=beta, size=flat_size, rng=next_rng
+                alpha=mp1 / 2.0, beta=beta, size=size, rng=next_rng
             ).owner.outputs
+
             next_rng, z = pt.random.normal(
-                loc=0, scale=1, size=(flat_size, mp1), rng=next_rng
+                loc=0, scale=1, size=(*size, mp1), rng=next_rng
             ).owner.outputs
-            z = z / pt.sqrt(pt.einsum("ij,ij->i", z, z.copy()))[..., np.newaxis]
+
+            ein_sig_z = "i, i->" if z.ndim == 1 else "...ij, ...ij->...i"
+            z = z / pt.sqrt(pt.einsum(ein_sig_z, z, z.copy()))[..., np.newaxis]
             P = P[..., 0:mp1, mp1].set(pt.sqrt(y[..., np.newaxis]) * z)
             P = P[..., mp1, mp1].set(pt.sqrt(1.0 - y))
+
         C = pt.einsum("...ji,...jk->...ik", P, P.copy())
 
         return next_rng, C
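
The einsum signature is picked per ndim because with no batch dims z is a plain vector. A quick NumPy stand-in (illustrative only, not the PyTensor graph) showing that both signatures compute the squared norm over the last axis:

import numpy as np

rng = np.random.default_rng(1)

# No batch dims: "i, i->" contracts a vector with itself to a scalar squared norm.
z = rng.normal(size=5)
z_unit = z / np.sqrt(np.einsum("i, i->", z, z))
assert np.isclose(np.linalg.norm(z_unit), 1.0)

# With batch dims: "...ij, ...ij->...i" reduces only the last axis, keeping
# leading batch dims intact instead of requiring a flattened (flat_size, mp1).
zb = rng.normal(size=(2, 3, 5))
zb_unit = zb / np.sqrt(np.einsum("...ij, ...ij->...i", zb, zb))[..., None]
assert np.allclose(np.linalg.norm(zb_unit, axis=-1), 1.0)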
@@ -1584,10 +1578,7 @@ def dist(cls, n, eta, **kwargs):
 
     @staticmethod
     def support_point(rv: TensorVariable, *args):
-        ndim = rv.ndim
-
-        # Batched identity matrix
-        return _delta(rv.shape, (ndim - 2, ndim - 1)).astype(int)
+        return pt.broadcast_to(pt.eye(rv.shape[-1]), rv.shape)
 
     @staticmethod
     def logp(value: TensorVariable, n, eta):
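
support_point now builds the batched identity by broadcasting pt.eye over the leading dims, which is what let the _delta import go at the top of the file. A NumPy sketch of the equivalent operation (illustrative shapes):

import numpy as np

n, batch_shape = 3, (2, 4)
support = np.broadcast_to(np.eye(n), (*batch_shape, n, n))
assert support.shape == (2, 4, 3, 3)
assert (support == np.eye(n)).all()  # every batch entry is the identity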

tests/distributions/test_multivariate.py

Lines changed: 10 additions & 11 deletions
@@ -1312,17 +1312,17 @@ def test_kronecker_normal_support_point(self, mu, covs, size, expected):
     @pytest.mark.parametrize(
         "n, eta, size, expected",
         [
-            (3, 1, None, np.zeros(3)),
-            (5, 1, None, np.zeros(10)),
-            pytest.param(3, 1, 1, np.zeros((1, 3))),
-            pytest.param(5, 1, (2, 3), np.zeros((2, 3, 10))),
+            (3, 1, None, np.eye(3)),
+            (5, 1, None, np.eye(5)),
+            (3, 1, (1,), np.broadcast_to(np.eye(3), (1, 3, 3))),
+            (5, 1, (2, 3), np.broadcast_to(np.eye(5), (2, 3, 5, 5))),
         ],
+        ids=["n=3", "n=5", "batch_1", "batch_2"],
     )
     def test_lkjcorr_support_point(self, n, eta, size, expected):
         with pm.Model() as model:
-            pm.LKJCorr("x", n=n, eta=eta, size=size, return_matrix=False)
-            # LKJCorr logp is only implemented for vector values (size=None)
-            assert_support_point_is_expected(model, expected, check_finite_logp=size is None)
+            pm.LKJCorr("x", n=n, eta=eta, size=size)
+            assert_support_point_is_expected(model, expected, check_finite_logp=True)
 
     @pytest.mark.parametrize(
         "n, eta, size, expected",
@@ -1466,13 +1466,12 @@ def test_with_lkjcorr_matrix(
         self,
     ):
         with pm.Model() as model:
-            corr = pm.LKJCorr("corr", n=3, eta=2, return_matrix=True)
-            pm.Deterministic("corr_mat", corr)
-            mv = pm.MvNormal("mv", 0.0, cov=corr, size=4)
+            corr_mat = pm.LKJCorr("corr_mat", n=3, eta=2)
+            mv = pm.MvNormal("mv", 0.0, cov=corr_mat, size=4)
             prior = pm.sample_prior_predictive(draws=10, return_inferencedata=False)
 
         assert prior["corr_mat"].shape == (10, 3, 3)  # square
-        assert (prior["corr_mat"][:, [0, 1, 2], [0, 1, 2]] == 1.0).all()  # 1.0 on diagonal
+        assert np.allclose(prior["corr_mat"][:, [0, 1, 2], [0, 1, 2]], 1.0)  # 1.0 on diagonal
         assert (prior["corr_mat"] == prior["corr_mat"].transpose(0, 2, 1)).all()  # symmetric
         assert (
             prior["corr_mat"].max() <= 1.0 and prior["corr_mat"].min() >= -1.0
