@@ -1401,12 +1401,12 @@ class TestZeroSumNormal:
1401
1401
@pytest .mark .parametrize (
1402
1402
"dims, zerosum_axes, shape" ,
1403
1403
[
1404
- (("regions" , "answers" ), "answers" , None ),
1405
- (("regions" , "answers" ), ( "regions" , "answers" ) , None ),
1406
- (("regions" , "answers" ), 0 , None ),
1407
- (( "regions" , "answers" ), - 1 , None ),
1408
- (( "regions" , "answers" ), ( 0 , 1 ), None ),
1409
- (None , - 2 , (len (COORDS ["regions" ]), len (COORDS ["answers" ]))),
1404
+ (("regions" , "answers" ), None , None ),
1405
+ (("regions" , "answers" ), 1 , None ),
1406
+ (("regions" , "answers" ), 2 , None ),
1407
+ (None , None , ( len ( COORDS [ "regions" ]), len ( COORDS [ "answers" ])) ),
1408
+ (None , 1 , ( len ( COORDS [ "regions" ]), len ( COORDS [ "answers" ])) ),
1409
+ (None , 2 , (len (COORDS ["regions" ]), len (COORDS ["answers" ]))),
1410
1410
],
1411
1411
)
1412
1412
def test_zsn_dims_shape (self , dims , zerosum_axes , shape ):
@@ -1419,41 +1419,27 @@ def test_zsn_dims_shape(self, dims, zerosum_axes, shape):
1419
1419
1420
1420
assert s .posterior .v .shape == (1 , 10 , len (COORDS ["regions" ]), len (COORDS ["answers" ]))
1421
1421
1422
- if not isinstance (zerosum_axes , (list , tuple )):
1423
- zerosum_axes = [zerosum_axes ]
1422
+ zerosum_axes = np .arange (- v .owner .op .ndim_supp , 0 )
1423
+ nonzero_axes = np .arange (v .ndim - v .owner .op .ndim_supp )
1424
+
1425
+ for ax in zerosum_axes :
1426
+ for samples in [
1427
+ s .posterior .v .mean (axis = ax ),
1428
+ random_samples .mean (axis = ax ),
1429
+ ]:
1430
+ assert np .isclose (
1431
+ samples , 0
1432
+ ).all (), f"{ ax } is a zerosum_axis but is not summing to 0 across all samples."
1424
1433
1425
- if isinstance ( zerosum_axes [ 0 ], str ) :
1426
- for ax in zerosum_axes :
1434
+ if nonzero_axes :
1435
+ for ax in nonzero_axes :
1427
1436
for samples in [
1428
- s .posterior .v .mean (dim = ax ),
1429
- random_samples .mean (axis = dims . index ( ax ) + 1 ),
1437
+ s .posterior .v .mean (axis = ax ),
1438
+ random_samples .mean (axis = ax ),
1430
1439
]:
1431
- assert np .isclose (
1440
+ assert not np .isclose (
1432
1441
samples , 0
1433
- ).all (), f"{ ax } is a zerosum_axis but is not summing to 0 across all samples."
1434
-
1435
- nonzero_axes = list (set (dims ).difference (zerosum_axes ))
1436
- if nonzero_axes :
1437
- for ax in nonzero_axes :
1438
- for samples in [
1439
- s .posterior .v .mean (dim = ax ),
1440
- random_samples .mean (axis = dims .index (ax ) + 1 ),
1441
- ]:
1442
- assert not np .isclose (
1443
- samples , 0
1444
- ).all (), f"{ ax } is not a zerosum_axis, but is nonetheless summing to 0 across all samples."
1445
-
1446
- else :
1447
- for ax in zerosum_axes :
1448
- if ax < 0 :
1449
- assert np .isclose (
1450
- s .posterior .v .mean (axis = ax ), 0
1451
- ).all (), f"{ ax } is a zerosum_axis but is not summing to 0 across all samples."
1452
- else :
1453
- ax = ax + 2 # because 'chain' and 'draw' are added as new axes after sampling
1454
- assert np .isclose (
1455
- s .posterior .v .mean (axis = ax ), 0
1456
- ).all (), f"{ ax } is a zerosum_axis but is not summing to 0 across all samples."
1442
+ ).all (), f"{ ax } is not a zerosum_axis, but is nonetheless summing to 0 across all samples."
1457
1443
1458
1444
@pytest .mark .parametrize (
1459
1445
"dims, zerosum_axes" ,
0 commit comments