@@ -1399,19 +1399,44 @@ def test_issue_3706(self):
1399
1399
1400
1400
class TestZeroSumNormal :
1401
1401
@pytest .mark .parametrize (
1402
- "dims, zerosum_axes, shape " ,
1402
+ "dims, zerosum_axes" ,
1403
1403
[
1404
- (("regions" , "answers" ), None , None ),
1405
- (("regions" , "answers" ), 1 , None ),
1406
- (("regions" , "answers" ), 2 , None ),
1407
- (None , None , (len (COORDS ["regions" ]), len (COORDS ["answers" ]))),
1408
- (None , 1 , (len (COORDS ["regions" ]), len (COORDS ["answers" ]))),
1409
- (None , 2 , (len (COORDS ["regions" ]), len (COORDS ["answers" ]))),
1404
+ (("regions" , "answers" ), None ),
1405
+ (("regions" , "answers" ), 1 ),
1406
+ (("regions" , "answers" ), 2 ),
1410
1407
],
1411
1408
)
1412
- def test_zsn_dims_shape (self , dims , zerosum_axes , shape ):
1409
+ def test_zsn_dims (self , dims , zerosum_axes ):
1413
1410
with pm .Model (coords = COORDS ) as m :
1414
- v = pm .ZeroSumNormal ("v" , dims = dims , shape = shape , zerosum_axes = zerosum_axes )
1411
+ v = pm .ZeroSumNormal ("v" , dims = dims , zerosum_axes = zerosum_axes )
1412
+ s = pm .sample (10 , chains = 1 , tune = 100 )
1413
+
1414
+ # to test forward graph
1415
+ random_samples = pm .draw (v , draws = 10 )
1416
+
1417
+ assert s .posterior .v .shape == (1 , 10 , len (COORDS ["regions" ]), len (COORDS ["answers" ]))
1418
+
1419
+ ndim_supp = v .owner .op .ndim_supp
1420
+ zerosum_axes = np .arange (- ndim_supp , 0 )
1421
+ nonzero_axes = np .arange (v .ndim - ndim_supp )
1422
+ for samples in [
1423
+ s .posterior .v ,
1424
+ random_samples ,
1425
+ ]:
1426
+ self .assert_zerosum_axes (samples , zerosum_axes )
1427
+ self .assert_zerosum_axes (samples , nonzero_axes , check_zerosum_axes = False )
1428
+
1429
+ @pytest .mark .parametrize (
1430
+ "zerosum_axes, shape" ,
1431
+ [
1432
+ (None , (len (COORDS ["regions" ]), len (COORDS ["answers" ]))),
1433
+ (1 , (len (COORDS ["regions" ]), len (COORDS ["answers" ]))),
1434
+ (2 , (len (COORDS ["regions" ]), len (COORDS ["answers" ]))),
1435
+ ],
1436
+ )
1437
+ def test_zsn_shape (self , shape , zerosum_axes ):
1438
+ with pm .Model (coords = COORDS ) as m :
1439
+ v = pm .ZeroSumNormal ("v" , shape = shape , zerosum_axes = zerosum_axes )
1415
1440
s = pm .sample (10 , chains = 1 , tune = 100 )
1416
1441
1417
1442
# to test forward graph
0 commit comments