Skip to content

Commit 7e4ed0a

Browse files
committed
Fix examples in ZSN docstrings
1 parent 4c52737 commit 7e4ed0a

File tree

1 file changed

+10
-13
lines changed

1 file changed

+10
-13
lines changed

pymc/distributions/multivariate.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2437,13 +2437,16 @@ class ZeroSumNormal(Distribution):
24372437
"answers": ["yes", "no", "whatever", "don't understand question"],
24382438
}
24392439
with pm.Model(coords=COORDS) as m:
2440-
...: v = pm.ZeroSumNormal("v", dims=("regions", "answers"), zerosum_axes="answers")
2440+
# the zero sum axis will be 'answers'
2441+
...: v = pm.ZeroSumNormal("v", dims=("regions", "answers"))
24412442
24422443
with pm.Model(coords=COORDS) as m:
2443-
...: v = pm.ZeroSumNormal("v", dims=("regions", "answers"), zerosum_axes=("regions", "answers"))
2444+
# the zero sum axes will be 'answers' and 'regions'
2445+
...: v = pm.ZeroSumNormal("v", dims=("regions", "answers"), zerosum_axes=2)
24442446
24452447
with pm.Model(coords=COORDS) as m:
2446-
...: v = pm.ZeroSumNormal("v", dims=("regions", "answers"), zerosum_axes=1)
2448+
# the zero sum axes will be the last two
2449+
...: v = pm.ZeroSumNormal("v", shape=(3, 4, 5), zerosum_axes=2)
24472450
"""
24482451
rv_type = ZeroSumNormalRV
24492452

@@ -2525,18 +2528,13 @@ def check_zerosum_axes(cls, zerosum_axes: Optional[int]) -> int:
25252528

25262529
@classmethod
25272530
def rv_op(cls, sigma, zerosum_axes, support_shape, size=None):
2528-
# if size is None:
2529-
# zerosum_axes_ = np.asarray(zerosum_axes)
2530-
# # just a placeholder size to infer minimum shape
2531-
# size = np.ones(
2532-
# max((max(np.abs(zerosum_axes_) - 1), max(zerosum_axes_))) + 1, dtype=int
2533-
# ).tolist()
2534-
2535-
# check if zerosum_axes is valid
2536-
# normalize_axis_tuple(zerosum_axes, len(size))
25372531

25382532
shape = to_tuple(size) + tuple(support_shape)
25392533
normal_dist = ignore_logprob(pm.Normal.dist(sigma=sigma, shape=shape))
2534+
2535+
if zerosum_axes > normal_dist.ndim:
2536+
raise ValueError("Shape of distribution is too small for the number of zerosum axes")
2537+
25402538
normal_dist_, sigma_, support_shape_ = (
25412539
normal_dist.type(),
25422540
sigma.type(),
@@ -2555,7 +2553,6 @@ def rv_op(cls, sigma, zerosum_axes, support_shape, size=None):
25552553
)(normal_dist, sigma, support_shape)
25562554

25572555
# TODO:
2558-
# write __new__
25592556
# refactor ZSN tests
25602557
# test get_support_shape with 2D
25612558
# test ZSN logp

0 commit comments

Comments
 (0)