Skip to content

Commit c120f7e

Browse files
committed
Put assert_zerosum_axes at top of test class
1 parent 08c9df0 commit c120f7e

File tree

1 file changed

+12
-12
lines changed

1 file changed

+12
-12
lines changed

pymc/tests/distributions/test_multivariate.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1388,6 +1388,18 @@ def test_issue_3706(self):
13881388

13891389

13901390
class TestZeroSumNormal:
1391+
def assert_zerosum_axes(self, random_samples, axes_to_check, check_zerosum_axes=True):
1392+
if check_zerosum_axes:
1393+
for ax in axes_to_check:
1394+
assert np.isclose(
1395+
random_samples.mean(axis=ax), 0
1396+
).all(), f"{ax} is a zerosum_axis but is not summing to 0 across all samples."
1397+
else:
1398+
for ax in axes_to_check:
1399+
assert not np.isclose(
1400+
random_samples.mean(axis=ax), 0
1401+
).all(), f"{ax} is not a zerosum_axis, but is nonetheless summing to 0 across all samples."
1402+
13911403
@pytest.mark.parametrize(
13921404
"dims, zerosum_axes",
13931405
[
@@ -1504,18 +1516,6 @@ def test_zsn_change_dist_size(self, zerosum_axes):
15041516
random_samples = pm.draw(new_dist, draws=100)
15051517
self.assert_zerosum_axes(random_samples, zerosum_axes)
15061518

1507-
def assert_zerosum_axes(self, random_samples, axes_to_check, check_zerosum_axes=True):
1508-
if check_zerosum_axes:
1509-
for ax in axes_to_check:
1510-
assert np.isclose(
1511-
random_samples.mean(axis=ax), 0
1512-
).all(), f"{ax} is a zerosum_axis but is not summing to 0 across all samples."
1513-
else:
1514-
for ax in axes_to_check:
1515-
assert not np.isclose(
1516-
random_samples.mean(axis=ax), 0
1517-
).all(), f"{ax} is not a zerosum_axis, but is nonetheless summing to 0 across all samples."
1518-
15191519
@pytest.mark.parametrize(
15201520
"sigma, n",
15211521
[

0 commit comments

Comments
 (0)