Skip to content

Commit f363118

Browse files
committed
Add ZSN logp test
1 parent 85da56c commit f363118

File tree

1 file changed

+46
-0
lines changed

1 file changed

+46
-0
lines changed

pymc/tests/distributions/test_multivariate.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1543,6 +1543,52 @@ def test_zsn_variance(self, sigma, n):
15431543

15441544
np.testing.assert_allclose(empirical_var, theoretical_var, rtol=1e-02)
15451545

1546+
@pytest.mark.parametrize(
1547+
"sigma, shape, zerosum_axes, mvn_axes",
1548+
[
1549+
(5, 3, None, [-1]),
1550+
(2, 6, None, [-1]),
1551+
(5, (7, 3), None, [-1]),
1552+
(5, (2, 7, 3), 2, [1, 2]),
1553+
],
1554+
)
1555+
def test_zsn_logp(self, sigma, shape, zerosum_axes, mvn_axes):
1556+
1557+
zsn_dist = pm.ZeroSumNormal.dist(sigma=sigma, shape=shape, zerosum_axes=zerosum_axes)
1558+
zsn_logp = pm.logp(zsn_dist, value=np.zeros(shape)).eval()
1559+
mvn_logp = self.logp_norm(value=np.zeros(shape), sigma=sigma, axes=mvn_axes)
1560+
1561+
np.testing.assert_allclose(zsn_logp, mvn_logp)
1562+
1563+
def logp_norm(self, value, sigma, axes):
1564+
"""
1565+
Special case of the MvNormal, that's equivalent to the ZSN.
1566+
Only to test the ZSN logp
1567+
"""
1568+
axes = [ax if ax >= 0 else value.ndim + ax for ax in axes]
1569+
if len(set(axes)) < len(axes):
1570+
raise ValueError("Must specify unique zero sum axes")
1571+
other_axes = [ax for ax in range(value.ndim) if ax not in axes]
1572+
new_order = other_axes + axes
1573+
reshaped_value = np.reshape(
1574+
np.transpose(value, new_order), [value.shape[ax] for ax in other_axes] + [-1]
1575+
)
1576+
1577+
degrees_of_freedom = np.prod([value.shape[ax] - 1 for ax in axes])
1578+
full_size = np.prod([value.shape[ax] for ax in axes])
1579+
1580+
ns = value.shape[-1]
1581+
psdet = (0.5 * np.log(2 * np.pi) + np.log(sigma)) * degrees_of_freedom / full_size
1582+
exp = 0.5 * (reshaped_value / sigma) ** 2
1583+
inds = np.ones_like(value, dtype="bool")
1584+
for ax in axes:
1585+
inds = np.logical_and(inds, np.abs(np.mean(value, axis=ax, keepdims=True)) < 1e-9)
1586+
inds = np.reshape(
1587+
np.transpose(inds, new_order), [value.shape[ax] for ax in other_axes] + [-1]
1588+
)[..., 0]
1589+
1590+
return np.where(inds, np.sum(-psdet - exp, axis=-1), -np.inf)
1591+
15461592

15471593
class TestMvStudentTCov(BaseTestDistributionRandom):
15481594
def mvstudentt_rng_fn(self, size, nu, mu, cov, rng):

0 commit comments

Comments
 (0)