@@ -1418,27 +1418,15 @@ def test_zsn_dims_shape(self, dims, zerosum_axes, shape):
1418
1418
1419
1419
assert s .posterior .v .shape == (1 , 10 , len (COORDS ["regions" ]), len (COORDS ["answers" ]))
1420
1420
1421
- zerosum_axes = np .arange (- v .owner .op .ndim_supp , 0 )
1422
- nonzero_axes = np .arange (v .ndim - v .owner .op .ndim_supp )
1423
-
1424
- for ax in zerosum_axes :
1425
- for samples in [
1426
- s .posterior .v .mean (axis = ax ),
1427
- random_samples .mean (axis = ax ),
1428
- ]:
1429
- assert np .isclose (
1430
- samples , 0
1431
- ).all (), f"{ ax } is a zerosum_axis but is not summing to 0 across all samples."
1432
-
1433
- if nonzero_axes :
1434
- for ax in nonzero_axes :
1435
- for samples in [
1436
- s .posterior .v .mean (axis = ax ),
1437
- random_samples .mean (axis = ax ),
1438
- ]:
1439
- assert not np .isclose (
1440
- samples , 0
1441
- ).all (), f"{ ax } is not a zerosum_axis, but is nonetheless summing to 0 across all samples."
1421
+ ndim_supp = v .owner .op .ndim_supp
1422
+ zerosum_axes = np .arange (- ndim_supp , 0 )
1423
+ nonzero_axes = np .arange (v .ndim - ndim_supp )
1424
+ for samples in [
1425
+ s .posterior .v ,
1426
+ random_samples ,
1427
+ ]:
1428
+ self .assert_zerosum_axes (samples , zerosum_axes )
1429
+ self .assert_zerosum_axes (samples , nonzero_axes , check_zerosum_axes = False )
1442
1430
1443
1431
@pytest .mark .parametrize (
1444
1432
"error, match, shape, support_shape, zerosum_axes" ,
@@ -1473,6 +1461,7 @@ def test_zsn_change_dist_size(self, zerosum_axes):
1473
1461
base_dist = pm .ZeroSumNormal .dist (shape = (4 , 9 ), zerosum_axes = zerosum_axes )
1474
1462
random_samples = pm .draw (base_dist , draws = 100 )
1475
1463
1464
+ zerosum_axes = np .arange (- zerosum_axes , 0 )
1476
1465
self .assert_zerosum_axes (random_samples , zerosum_axes )
1477
1466
1478
1467
new_dist = change_dist_size (base_dist , new_size = (5 , 3 ), expand = False )
@@ -1488,12 +1477,17 @@ def test_zsn_change_dist_size(self, zerosum_axes):
1488
1477
random_samples = pm .draw (new_dist , draws = 100 )
1489
1478
self .assert_zerosum_axes (random_samples , zerosum_axes )
1490
1479
1491
- def assert_zerosum_axes (self , random_samples , zerosum_axes ):
1492
- zerosum_axes = np .arange (- zerosum_axes , 0 )
1493
- for ax in zerosum_axes :
1494
- assert np .isclose (
1495
- random_samples .mean (axis = ax ), 0
1496
- ).all (), f"{ ax } is a zerosum_axis but is not summing to 0 across all samples."
1480
+ def assert_zerosum_axes (self , random_samples , axes_to_check , check_zerosum_axes = True ):
1481
+ if check_zerosum_axes :
1482
+ for ax in axes_to_check :
1483
+ assert np .isclose (
1484
+ random_samples .mean (axis = ax ), 0
1485
+ ).all (), f"{ ax } is a zerosum_axis but is not summing to 0 across all samples."
1486
+ else :
1487
+ for ax in axes_to_check :
1488
+ assert not np .isclose (
1489
+ random_samples .mean (axis = ax ), 0
1490
+ ).all (), f"{ ax } is not a zerosum_axis, but is nonetheless summing to 0 across all samples."
1497
1491
1498
1492
1499
1493
class TestMvStudentTCov (BaseTestDistributionRandom ):
0 commit comments