@@ -1415,10 +1415,7 @@ def test_zsn_dims_shape(self, dims, zerosum_axes, shape):
1415
1415
s = pm .sample (10 , chains = 1 , tune = 100 )
1416
1416
1417
1417
# to test forward graph
1418
- random_samples = pm .draw (
1419
- v ,
1420
- draws = 10 ,
1421
- )
1418
+ random_samples = pm .draw (v , draws = 10 )
1422
1419
1423
1420
assert s .posterior .v .shape == (1 , 10 , len (COORDS ["regions" ]), len (COORDS ["answers" ]))
1424
1421
@@ -1475,14 +1472,39 @@ def test_zsn_fail_axis(self, dims, zerosum_axes):
1475
1472
with pm .Model (coords = COORDS ) as m :
1476
1473
_ = pm .ZeroSumNormal ("v" , dims = dims , zerosum_axes = zerosum_axes )
1477
1474
1478
- def test_zsn_change_dist_size (self ):
1479
- base_dist = pm .ZeroSumNormal .dist (shape = (4 , 9 ))
1475
+ @pytest .mark .parametrize (
1476
+ "zerosum_axes" ,
1477
+ [(- 1 ), (- 2 ), (1 ), ((0 , 1 )), ((- 2 , - 1 ))],
1478
+ )
1479
+ def test_zsn_change_dist_size (self , zerosum_axes ):
1480
+ base_dist = pm .ZeroSumNormal .dist (shape = (4 , 9 ), zerosum_axes = zerosum_axes )
1481
+ random_samples = pm .draw (base_dist , draws = 100 )
1482
+
1483
+ if not isinstance (zerosum_axes , (list , tuple )):
1484
+ zerosum_axes = [zerosum_axes ]
1485
+ self .assert_zerosum_axes (random_samples , zerosum_axes )
1480
1486
1481
1487
new_dist = change_dist_size (base_dist , new_size = (5 , 3 ), expand = False )
1482
1488
assert new_dist .eval ().shape == (5 , 3 )
1489
+ random_samples = pm .draw (new_dist , draws = 100 )
1490
+ self .assert_zerosum_axes (random_samples , zerosum_axes )
1483
1491
1484
1492
new_dist = change_dist_size (base_dist , new_size = (5 , 3 ), expand = True )
1485
1493
assert new_dist .eval ().shape == (5 , 3 , 4 , 9 )
1494
+ random_samples = pm .draw (new_dist , draws = 100 )
1495
+ self .assert_zerosum_axes (random_samples , zerosum_axes )
1496
+
1497
+ def assert_zerosum_axes (self , random_samples , zerosum_axes ):
1498
+ for ax in zerosum_axes :
1499
+ if ax < 0 :
1500
+ assert np .isclose (
1501
+ random_samples .mean (axis = ax ), 0
1502
+ ).all (), f"{ ax } is a zerosum_axis but is not summing to 0 across all samples."
1503
+ else :
1504
+ ax = ax + 1
1505
+ assert np .isclose (
1506
+ random_samples .mean (axis = ax ), 0
1507
+ ).all (), f"{ ax } is a zerosum_axis but is not summing to 0 across all samples."
1486
1508
1487
1509
1488
1510
class TestMvStudentTCov (BaseTestDistributionRandom ):
0 commit comments