@@ -1388,6 +1388,18 @@ def test_issue_3706(self):
1388
1388
1389
1389
1390
1390
class TestZeroSumNormal :
1391
+ def assert_zerosum_axes (self , random_samples , axes_to_check , check_zerosum_axes = True ):
1392
+ if check_zerosum_axes :
1393
+ for ax in axes_to_check :
1394
+ assert np .isclose (
1395
+ random_samples .mean (axis = ax ), 0
1396
+ ).all (), f"{ ax } is a zerosum_axis but is not summing to 0 across all samples."
1397
+ else :
1398
+ for ax in axes_to_check :
1399
+ assert not np .isclose (
1400
+ random_samples .mean (axis = ax ), 0
1401
+ ).all (), f"{ ax } is not a zerosum_axis, but is nonetheless summing to 0 across all samples."
1402
+
1391
1403
@pytest .mark .parametrize (
1392
1404
"dims, zerosum_axes" ,
1393
1405
[
@@ -1504,18 +1516,6 @@ def test_zsn_change_dist_size(self, zerosum_axes):
1504
1516
random_samples = pm .draw (new_dist , draws = 100 )
1505
1517
self .assert_zerosum_axes (random_samples , zerosum_axes )
1506
1518
1507
- def assert_zerosum_axes (self , random_samples , axes_to_check , check_zerosum_axes = True ):
1508
- if check_zerosum_axes :
1509
- for ax in axes_to_check :
1510
- assert np .isclose (
1511
- random_samples .mean (axis = ax ), 0
1512
- ).all (), f"{ ax } is a zerosum_axis but is not summing to 0 across all samples."
1513
- else :
1514
- for ax in axes_to_check :
1515
- assert not np .isclose (
1516
- random_samples .mean (axis = ax ), 0
1517
- ).all (), f"{ ax } is not a zerosum_axis, but is nonetheless summing to 0 across all samples."
1518
-
1519
1519
@pytest .mark .parametrize (
1520
1520
"sigma, n" ,
1521
1521
[
0 commit comments