@@ -1543,6 +1543,52 @@ def test_zsn_variance(self, sigma, n):
1543
1543
1544
1544
np .testing .assert_allclose (empirical_var , theoretical_var , rtol = 1e-02 )
1545
1545
1546
+ @pytest .mark .parametrize (
1547
+ "sigma, shape, zerosum_axes, mvn_axes" ,
1548
+ [
1549
+ (5 , 3 , None , [- 1 ]),
1550
+ (2 , 6 , None , [- 1 ]),
1551
+ (5 , (7 , 3 ), None , [- 1 ]),
1552
+ (5 , (2 , 7 , 3 ), 2 , [1 , 2 ]),
1553
+ ],
1554
+ )
1555
+ def test_zsn_logp (self , sigma , shape , zerosum_axes , mvn_axes ):
1556
+
1557
+ zsn_dist = pm .ZeroSumNormal .dist (sigma = sigma , shape = shape , zerosum_axes = zerosum_axes )
1558
+ zsn_logp = pm .logp (zsn_dist , value = np .zeros (shape )).eval ()
1559
+ mvn_logp = self .logp_norm (value = np .zeros (shape ), sigma = sigma , axes = mvn_axes )
1560
+
1561
+ np .testing .assert_allclose (zsn_logp , mvn_logp )
1562
+
1563
+ def logp_norm (self , value , sigma , axes ):
1564
+ """
1565
+ Special case of the MvNormal, that's equivalent to the ZSN.
1566
+ Only to test the ZSN logp
1567
+ """
1568
+ axes = [ax if ax >= 0 else value .ndim + ax for ax in axes ]
1569
+ if len (set (axes )) < len (axes ):
1570
+ raise ValueError ("Must specify unique zero sum axes" )
1571
+ other_axes = [ax for ax in range (value .ndim ) if ax not in axes ]
1572
+ new_order = other_axes + axes
1573
+ reshaped_value = np .reshape (
1574
+ np .transpose (value , new_order ), [value .shape [ax ] for ax in other_axes ] + [- 1 ]
1575
+ )
1576
+
1577
+ degrees_of_freedom = np .prod ([value .shape [ax ] - 1 for ax in axes ])
1578
+ full_size = np .prod ([value .shape [ax ] for ax in axes ])
1579
+
1580
+ ns = value .shape [- 1 ]
1581
+ psdet = (0.5 * np .log (2 * np .pi ) + np .log (sigma )) * degrees_of_freedom / full_size
1582
+ exp = 0.5 * (reshaped_value / sigma ) ** 2
1583
+ inds = np .ones_like (value , dtype = "bool" )
1584
+ for ax in axes :
1585
+ inds = np .logical_and (inds , np .abs (np .mean (value , axis = ax , keepdims = True )) < 1e-9 )
1586
+ inds = np .reshape (
1587
+ np .transpose (inds , new_order ), [value .shape [ax ] for ax in other_axes ] + [- 1 ]
1588
+ )[..., 0 ]
1589
+
1590
+ return np .where (inds , np .sum (- psdet - exp , axis = - 1 ), - np .inf )
1591
+
1546
1592
1547
1593
class TestMvStudentTCov (BaseTestDistributionRandom ):
1548
1594
def mvstudentt_rng_fn (self , size , nu , mu , cov , rng ):
0 commit comments