@@ -1432,14 +1432,14 @@ def test_zsn_dims_shape(self, dims, zerosum_axes, shape):
1432
1432
"error, match, shape, support_shape, zerosum_axes" ,
1433
1433
[
1434
1434
(IndexError , "index out of range" , (3 , 4 , 5 ), None , 4 ),
1435
- (AssertionError , "does not match" , (3 , 4 ), 3 , None ), # support_shape should be 4
1435
+ (AssertionError , "does not match" , (3 , 4 ), ( 3 ,) , None ), # support_shape should be 4
1436
1436
(
1437
1437
AssertionError ,
1438
1438
"does not match" ,
1439
1439
(3 , 4 ),
1440
1440
(3 , 4 ),
1441
1441
None ,
1442
- ), # doesn't work because zerosum_axes = 1
1442
+ ), # doesn't work because zerosum_axes = 1 by default
1443
1443
],
1444
1444
)
1445
1445
def test_zsn_fail_axis (self , error , match , shape , support_shape , zerosum_axes ):
@@ -1449,9 +1449,20 @@ def test_zsn_fail_axis(self, error, match, shape, support_shape, zerosum_axes):
1449
1449
"v" , shape = shape , support_shape = support_shape , zerosum_axes = zerosum_axes
1450
1450
)
1451
1451
1452
- # v = pm.ZeroSumNormal("v", support_shape=(3, 4), zerosum_axes=2) # should work
1452
+ @pytest .mark .parametrize (
1453
+ "shape, support_shape" ,
1454
+ [
1455
+ (None , (3 , 4 )),
1456
+ ((3 , 4 ), (3 , 4 )),
1457
+ ],
1458
+ )
1459
+ def test_zsn_support_shape (self , shape , support_shape ):
1460
+ with pm .Model () as m :
1461
+ v = pm .ZeroSumNormal ("v" , shape = shape , support_shape = support_shape , zerosum_axes = 2 )
1453
1462
1454
- # v = pm.ZeroSumNormal("v", shape=(3, 4), support_shape=(3, 4), zerosum_axes=2) this should work but doesn't
1463
+ random_samples = pm .draw (v , draws = 10 )
1464
+ zerosum_axes = np .arange (- 2 , 0 )
1465
+ self .assert_zerosum_axes (random_samples , zerosum_axes )
1455
1466
1456
1467
@pytest .mark .parametrize (
1457
1468
"zerosum_axes" ,
@@ -1465,9 +1476,9 @@ def test_zsn_change_dist_size(self, zerosum_axes):
1465
1476
self .assert_zerosum_axes (random_samples , zerosum_axes )
1466
1477
1467
1478
new_dist = change_dist_size (base_dist , new_size = (5 , 3 ), expand = False )
1468
- if zerosum_axes == 1 :
1479
+ try :
1469
1480
assert new_dist .eval ().shape == (5 , 3 , 9 )
1470
- elif zerosum_axes == 2 :
1481
+ except AssertionError :
1471
1482
assert new_dist .eval ().shape == (5 , 3 , 4 , 9 )
1472
1483
random_samples = pm .draw (new_dist , draws = 100 )
1473
1484
self .assert_zerosum_axes (random_samples , zerosum_axes )
0 commit comments