Skip to content

Commit 48dafe9

Browse files
committed
Nicer format for ZSN logp test
1 parent ba5f3a1 commit 48dafe9

File tree

1 file changed

+28
-30
lines changed

1 file changed

+28
-30
lines changed

pymc/tests/distributions/test_multivariate.py

Lines changed: 28 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1543,41 +1543,39 @@ def test_zsn_variance(self, sigma, n):
15431543
],
15441544
)
15451545
def test_zsn_logp(self, sigma, shape, zerosum_axes, mvn_axes):
1546+
def logp_norm(value, sigma, axes):
1547+
"""
1548+
Special case of the MvNormal, that's equivalent to the ZSN.
1549+
Only to test the ZSN logp
1550+
"""
1551+
axes = [ax if ax >= 0 else value.ndim + ax for ax in axes]
1552+
if len(set(axes)) < len(axes):
1553+
raise ValueError("Must specify unique zero sum axes")
1554+
other_axes = [ax for ax in range(value.ndim) if ax not in axes]
1555+
new_order = other_axes + axes
1556+
reshaped_value = np.reshape(
1557+
np.transpose(value, new_order), [value.shape[ax] for ax in other_axes] + [-1]
1558+
)
15461559

1547-
zsn_dist = pm.ZeroSumNormal.dist(sigma=sigma, shape=shape, zerosum_axes=zerosum_axes)
1548-
zsn_logp = pm.logp(zsn_dist, value=np.zeros(shape)).eval()
1549-
mvn_logp = self.logp_norm(value=np.zeros(shape), sigma=sigma, axes=mvn_axes)
1560+
degrees_of_freedom = np.prod([value.shape[ax] - 1 for ax in axes])
1561+
full_size = np.prod([value.shape[ax] for ax in axes])
15501562

1551-
np.testing.assert_allclose(zsn_logp, mvn_logp)
1563+
psdet = (0.5 * np.log(2 * np.pi) + np.log(sigma)) * degrees_of_freedom / full_size
1564+
exp = 0.5 * (reshaped_value / sigma) ** 2
1565+
inds = np.ones_like(value, dtype="bool")
1566+
for ax in axes:
1567+
inds = np.logical_and(inds, np.abs(np.mean(value, axis=ax, keepdims=True)) < 1e-9)
1568+
inds = np.reshape(
1569+
np.transpose(inds, new_order), [value.shape[ax] for ax in other_axes] + [-1]
1570+
)[..., 0]
15521571

1553-
def logp_norm(self, value, sigma, axes):
1554-
"""
1555-
Special case of the MvNormal, that's equivalent to the ZSN.
1556-
Only to test the ZSN logp
1557-
"""
1558-
axes = [ax if ax >= 0 else value.ndim + ax for ax in axes]
1559-
if len(set(axes)) < len(axes):
1560-
raise ValueError("Must specify unique zero sum axes")
1561-
other_axes = [ax for ax in range(value.ndim) if ax not in axes]
1562-
new_order = other_axes + axes
1563-
reshaped_value = np.reshape(
1564-
np.transpose(value, new_order), [value.shape[ax] for ax in other_axes] + [-1]
1565-
)
1566-
1567-
degrees_of_freedom = np.prod([value.shape[ax] - 1 for ax in axes])
1568-
full_size = np.prod([value.shape[ax] for ax in axes])
1572+
return np.where(inds, np.sum(-psdet - exp, axis=-1), -np.inf)
15691573

1570-
ns = value.shape[-1]
1571-
psdet = (0.5 * np.log(2 * np.pi) + np.log(sigma)) * degrees_of_freedom / full_size
1572-
exp = 0.5 * (reshaped_value / sigma) ** 2
1573-
inds = np.ones_like(value, dtype="bool")
1574-
for ax in axes:
1575-
inds = np.logical_and(inds, np.abs(np.mean(value, axis=ax, keepdims=True)) < 1e-9)
1576-
inds = np.reshape(
1577-
np.transpose(inds, new_order), [value.shape[ax] for ax in other_axes] + [-1]
1578-
)[..., 0]
1574+
zsn_dist = pm.ZeroSumNormal.dist(sigma=sigma, shape=shape, zerosum_axes=zerosum_axes)
1575+
zsn_logp = pm.logp(zsn_dist, value=np.zeros(shape)).eval()
1576+
mvn_logp = logp_norm(value=np.zeros(shape), sigma=sigma, axes=mvn_axes)
15791577

1580-
return np.where(inds, np.sum(-psdet - exp, axis=-1), -np.inf)
1578+
np.testing.assert_allclose(zsn_logp, mvn_logp)
15811579

15821580

15831581
class TestMvStudentTCov(BaseTestDistributionRandom):

0 commit comments

Comments
 (0)