@@ -2437,13 +2437,16 @@ class ZeroSumNormal(Distribution):
2437
2437
"answers": ["yes", "no", "whatever", "don't understand question"],
2438
2438
}
2439
2439
with pm.Model(coords=COORDS) as m:
2440
- ...: v = pm.ZeroSumNormal("v", dims=("regions", "answers"), zerosum_axes="answers")
2440
+ # the zero sum axis will be 'answers'
2441
+ ...: v = pm.ZeroSumNormal("v", dims=("regions", "answers"))
2441
2442
2442
2443
with pm.Model(coords=COORDS) as m:
2443
- ...: v = pm.ZeroSumNormal("v", dims=("regions", "answers"), zerosum_axes=("regions", "answers"))
2444
+ # the zero sum axes will be 'answers' and 'regions'
2445
+ ...: v = pm.ZeroSumNormal("v", dims=("regions", "answers"), zerosum_axes=2)
2444
2446
2445
2447
with pm.Model(coords=COORDS) as m:
2446
- ...: v = pm.ZeroSumNormal("v", dims=("regions", "answers"), zerosum_axes=1)
2448
+ # the zero sum axes will be the last two
2449
+ ...: v = pm.ZeroSumNormal("v", shape=(3, 4, 5), zerosum_axes=2)
2447
2450
"""
2448
2451
rv_type = ZeroSumNormalRV
2449
2452
@@ -2525,18 +2528,13 @@ def check_zerosum_axes(cls, zerosum_axes: Optional[int]) -> int:
2525
2528
2526
2529
@classmethod
2527
2530
def rv_op (cls , sigma , zerosum_axes , support_shape , size = None ):
2528
- # if size is None:
2529
- # zerosum_axes_ = np.asarray(zerosum_axes)
2530
- # # just a placeholder size to infer minimum shape
2531
- # size = np.ones(
2532
- # max((max(np.abs(zerosum_axes_) - 1), max(zerosum_axes_))) + 1, dtype=int
2533
- # ).tolist()
2534
-
2535
- # check if zerosum_axes is valid
2536
- # normalize_axis_tuple(zerosum_axes, len(size))
2537
2531
2538
2532
shape = to_tuple (size ) + tuple (support_shape )
2539
2533
normal_dist = ignore_logprob (pm .Normal .dist (sigma = sigma , shape = shape ))
2534
+
2535
+ if zerosum_axes > normal_dist .ndim :
2536
+ raise ValueError ("Shape of distribution is too small for the number of zerosum axes" )
2537
+
2540
2538
normal_dist_ , sigma_ , support_shape_ = (
2541
2539
normal_dist .type (),
2542
2540
sigma .type (),
@@ -2555,7 +2553,6 @@ def rv_op(cls, sigma, zerosum_axes, support_shape, size=None):
2555
2553
)(normal_dist , sigma , support_shape )
2556
2554
2557
2555
# TODO:
2558
- # write __new__
2559
2556
# refactor ZSN tests
2560
2557
# test get_support_shape with 2D
2561
2558
# test ZSN logp
0 commit comments