@@ -1467,18 +1467,19 @@ def test_zsn_fail_axis(self, error, match, shape, support_shape, zerosum_axes):
1467
1467
1468
1468
@pytest .mark .parametrize (
1469
1469
"zerosum_axes" ,
1470
- [( - 1 ), ( - 2 ), ( 1 ), (( 0 , 1 )), (( - 2 , - 1 )) ],
1470
+ [1 , 2 ],
1471
1471
)
1472
1472
def test_zsn_change_dist_size (self , zerosum_axes ):
1473
1473
base_dist = pm .ZeroSumNormal .dist (shape = (4 , 9 ), zerosum_axes = zerosum_axes )
1474
1474
random_samples = pm .draw (base_dist , draws = 100 )
1475
1475
1476
- if not isinstance (zerosum_axes , (list , tuple )):
1477
- zerosum_axes = [zerosum_axes ]
1478
1476
self .assert_zerosum_axes (random_samples , zerosum_axes )
1479
1477
1480
1478
new_dist = change_dist_size (base_dist , new_size = (5 , 3 ), expand = False )
1481
- assert new_dist .eval ().shape == (5 , 3 )
1479
+ if zerosum_axes == 1 :
1480
+ assert new_dist .eval ().shape == (5 , 3 , 9 )
1481
+ elif zerosum_axes == 2 :
1482
+ assert new_dist .eval ().shape == (5 , 3 , 4 , 9 )
1482
1483
random_samples = pm .draw (new_dist , draws = 100 )
1483
1484
self .assert_zerosum_axes (random_samples , zerosum_axes )
1484
1485
@@ -1488,16 +1489,11 @@ def test_zsn_change_dist_size(self, zerosum_axes):
1488
1489
self .assert_zerosum_axes (random_samples , zerosum_axes )
1489
1490
1490
1491
def assert_zerosum_axes (self , random_samples , zerosum_axes ):
1492
+ zerosum_axes = np .arange (- zerosum_axes , 0 )
1491
1493
for ax in zerosum_axes :
1492
- if ax < 0 :
1493
- assert np .isclose (
1494
- random_samples .mean (axis = ax ), 0
1495
- ).all (), f"{ ax } is a zerosum_axis but is not summing to 0 across all samples."
1496
- else :
1497
- ax = ax + 1
1498
- assert np .isclose (
1499
- random_samples .mean (axis = ax ), 0
1500
- ).all (), f"{ ax } is a zerosum_axis but is not summing to 0 across all samples."
1494
+ assert np .isclose (
1495
+ random_samples .mean (axis = ax ), 0
1496
+ ).all (), f"{ ax } is a zerosum_axis but is not summing to 0 across all samples."
1501
1497
1502
1498
1503
1499
class TestMvStudentTCov (BaseTestDistributionRandom ):
0 commit comments