@@ -1543,41 +1543,39 @@ def test_zsn_variance(self, sigma, n):
1543
1543
],
1544
1544
)
1545
1545
def test_zsn_logp (self , sigma , shape , zerosum_axes , mvn_axes ):
1546
+ def logp_norm (value , sigma , axes ):
1547
+ """
1548
+ Special case of the MvNormal, that's equivalent to the ZSN.
1549
+ Only to test the ZSN logp
1550
+ """
1551
+ axes = [ax if ax >= 0 else value .ndim + ax for ax in axes ]
1552
+ if len (set (axes )) < len (axes ):
1553
+ raise ValueError ("Must specify unique zero sum axes" )
1554
+ other_axes = [ax for ax in range (value .ndim ) if ax not in axes ]
1555
+ new_order = other_axes + axes
1556
+ reshaped_value = np .reshape (
1557
+ np .transpose (value , new_order ), [value .shape [ax ] for ax in other_axes ] + [- 1 ]
1558
+ )
1546
1559
1547
- zsn_dist = pm .ZeroSumNormal .dist (sigma = sigma , shape = shape , zerosum_axes = zerosum_axes )
1548
- zsn_logp = pm .logp (zsn_dist , value = np .zeros (shape )).eval ()
1549
- mvn_logp = self .logp_norm (value = np .zeros (shape ), sigma = sigma , axes = mvn_axes )
1560
+ degrees_of_freedom = np .prod ([value .shape [ax ] - 1 for ax in axes ])
1561
+ full_size = np .prod ([value .shape [ax ] for ax in axes ])
1550
1562
1551
- np .testing .assert_allclose (zsn_logp , mvn_logp )
1563
+ psdet = (0.5 * np .log (2 * np .pi ) + np .log (sigma )) * degrees_of_freedom / full_size
1564
+ exp = 0.5 * (reshaped_value / sigma ) ** 2
1565
+ inds = np .ones_like (value , dtype = "bool" )
1566
+ for ax in axes :
1567
+ inds = np .logical_and (inds , np .abs (np .mean (value , axis = ax , keepdims = True )) < 1e-9 )
1568
+ inds = np .reshape (
1569
+ np .transpose (inds , new_order ), [value .shape [ax ] for ax in other_axes ] + [- 1 ]
1570
+ )[..., 0 ]
1552
1571
1553
- def logp_norm (self , value , sigma , axes ):
1554
- """
1555
- Special case of the MvNormal, that's equivalent to the ZSN.
1556
- Only to test the ZSN logp
1557
- """
1558
- axes = [ax if ax >= 0 else value .ndim + ax for ax in axes ]
1559
- if len (set (axes )) < len (axes ):
1560
- raise ValueError ("Must specify unique zero sum axes" )
1561
- other_axes = [ax for ax in range (value .ndim ) if ax not in axes ]
1562
- new_order = other_axes + axes
1563
- reshaped_value = np .reshape (
1564
- np .transpose (value , new_order ), [value .shape [ax ] for ax in other_axes ] + [- 1 ]
1565
- )
1566
-
1567
- degrees_of_freedom = np .prod ([value .shape [ax ] - 1 for ax in axes ])
1568
- full_size = np .prod ([value .shape [ax ] for ax in axes ])
1572
+ return np .where (inds , np .sum (- psdet - exp , axis = - 1 ), - np .inf )
1569
1573
1570
- ns = value .shape [- 1 ]
1571
- psdet = (0.5 * np .log (2 * np .pi ) + np .log (sigma )) * degrees_of_freedom / full_size
1572
- exp = 0.5 * (reshaped_value / sigma ) ** 2
1573
- inds = np .ones_like (value , dtype = "bool" )
1574
- for ax in axes :
1575
- inds = np .logical_and (inds , np .abs (np .mean (value , axis = ax , keepdims = True )) < 1e-9 )
1576
- inds = np .reshape (
1577
- np .transpose (inds , new_order ), [value .shape [ax ] for ax in other_axes ] + [- 1 ]
1578
- )[..., 0 ]
1574
+ zsn_dist = pm .ZeroSumNormal .dist (sigma = sigma , shape = shape , zerosum_axes = zerosum_axes )
1575
+ zsn_logp = pm .logp (zsn_dist , value = np .zeros (shape )).eval ()
1576
+ mvn_logp = logp_norm (value = np .zeros (shape ), sigma = sigma , axes = mvn_axes )
1579
1577
1580
- return np .where ( inds , np . sum ( - psdet - exp , axis = - 1 ), - np . inf )
1578
+ np .testing . assert_allclose ( zsn_logp , mvn_logp )
1581
1579
1582
1580
1583
1581
class TestMvStudentTCov (BaseTestDistributionRandom ):
0 commit comments