Skip to content

Commit 3e86a3e

Browse files
committed
Fix get_support_shape
1 parent cf5b384 commit 3e86a3e

File tree

2 files changed

+8
-4
lines changed

2 files changed

+8
-4
lines changed

pymc/distributions/multivariate.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2487,7 +2487,7 @@ def dist(cls, sigma=1, zerosum_axes=None, support_shape=None, **kwargs):
24872487

24882488
if support_shape is None:
24892489
if zerosum_axes > 0:
2490-
raise ValueError("You must specify shape or support_shape parameter")
2490+
raise ValueError("You must specify dims, shape or support_shape parameter")
24912491
# edge-case doesn't work for now, because at.stack in get_support_shape fails
24922492
# else:
24932493
# support_shape = () # because it's just a Normal in that case
@@ -2553,7 +2553,6 @@ def rv_op(cls, sigma, zerosum_axes, support_shape, size=None):
25532553
)(normal_dist, sigma, support_shape)
25542554

25552555
# TODO:
2556-
# refactor ZSN tests
25572556
# test get_support_shape with 2D
25582557
# test ZSN logp
25592558
# test ZSN variance

pymc/distributions/shape_utils.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -741,8 +741,13 @@ def get_support_shape(
741741
inferred_support_shape = support_shape
742742
# If there are two sources of information for the support shapes, assert they are consistent:
743743
elif support_shape is not None:
744-
inferred_support_shape = Assert(msg="support_shape does not match last shape dimension")(
745-
inferred_support_shape, at.all(at.eq(inferred_support_shape, support_shape))
744+
inferred_support_shape = at.stack(
745+
[
746+
Assert(msg="support_shape does not match last shape dimension")(
747+
inferred, at.eq(inferred, explicit)
748+
)
749+
for inferred, explicit in zip(inferred_support_shape, support_shape)
750+
]
746751
)
747752

748753
return inferred_support_shape

0 commit comments

Comments
 (0)