Skip to content

Commit 35cd657

Browse files
committed
Allow batched scalar sigma in ZeroSumNormal
1 parent 0fd7b9e commit 35cd657

File tree

2 files changed

+46
-10
lines changed

2 files changed

+46
-10
lines changed

pymc/distributions/multivariate.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2591,8 +2591,7 @@ class ZeroSumNormal(Distribution):
25912591
sigma : tensor_like of float
25922592
Scale parameter (sigma > 0).
25932593
It's actually the standard deviation of the underlying, unconstrained Normal distribution.
2594-
Defaults to 1 if not specified.
2595-
For now, ``sigma`` has to be a scalar, to ensure the zero-sum constraint.
2594+
Defaults to 1 if not specified. ``sigma`` cannot have length > 1 across the zero-sum axes.
25962595
n_zerosum_axes: int, defaults to 1
25972596
Number of axes along which the zero-sum constraint is enforced, starting from the rightmost position.
25982597
Defaults to 1, i.e. the rightmost axis.
@@ -2606,8 +2605,7 @@ class ZeroSumNormal(Distribution):
26062605
26072606
Warnings
26082607
--------
2609-
``sigma`` has to be a scalar, to ensure the zero-sum constraint.
2610-
The ability to specify a vector of ``sigma`` may be added in future versions.
2608+
Currently, ``sigma`` cannot have length > 1 across the zero-sum axes to ensure the zero-sum constraint.
26112609
26122610
``n_zerosum_axes`` has to be > 0. If you want the behavior of ``n_zerosum_axes = 0``,
26132611
just use ``pm.Normal``.
@@ -2669,8 +2667,8 @@ def dist(cls, sigma=1, n_zerosum_axes=None, support_shape=None, **kwargs):
26692667
n_zerosum_axes = cls.check_zerosum_axes(n_zerosum_axes)
26702668

26712669
sigma = pt.as_tensor_variable(floatX(sigma))
2672-
if sigma.ndim > 0:
2673-
raise ValueError("sigma has to be a scalar")
2670+
if not all(sigma.type.broadcastable[-n_zerosum_axes:]):
2671+
raise ValueError("sigma must have length one across the zero-sum axes")
26742672

26752673
support_shape = get_support_shape(
26762674
support_shape=support_shape,
@@ -2681,9 +2679,7 @@ def dist(cls, sigma=1, n_zerosum_axes=None, support_shape=None, **kwargs):
26812679
if support_shape is None:
26822680
if n_zerosum_axes > 0:
26832681
raise ValueError("You must specify dims, shape or support_shape parameter")
2684-
# TODO: edge-case doesn't work for now, because pt.stack in get_support_shape fails
2685-
# else:
2686-
# support_shape = () # because it's just a Normal in that case
2682+
26872683
support_shape = pt.as_tensor_variable(intX(support_shape))
26882684

26892685
assert n_zerosum_axes == pt.get_vector_length(
@@ -2706,7 +2702,12 @@ def check_zerosum_axes(cls, n_zerosum_axes: Optional[int]) -> int:
27062702

27072703
@classmethod
27082704
def rv_op(cls, sigma, n_zerosum_axes, support_shape, size=None):
2709-
shape = to_tuple(size) + tuple(support_shape)
2705+
if size is not None:
2706+
shape = tuple(size) + tuple(support_shape)
2707+
else:
2708+
# Size is implied by shape of sigma
2709+
shape = tuple(sigma.shape[:-n_zerosum_axes]) + tuple(support_shape)
2710+
27102711
normal_dist = pm.Normal.dist(sigma=sigma, shape=shape)
27112712

27122713
if n_zerosum_axes > normal_dist.ndim:

tests/distributions/test_multivariate.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1704,6 +1704,41 @@ def test_does_not_upcast_to_float64(self):
17041704
pm.ZeroSumNormal("b", sigma=1, shape=(2,))
17051705
m.logp()
17061706

1707+
def test_batched_sigma(self):
1708+
sigma = pt.scalar("sigma")
1709+
core_zsn = pm.ZeroSumNormal.dist(sigma=sigma, n_zerosum_axes=2, support_shape=(3, 2))
1710+
core_test_value = pm.draw(core_zsn, random_seed=1709, givens={sigma: 2.5})
1711+
batch_test_value = np.broadcast_to(core_test_value, (5, 3, 2))
1712+
batch_test_sigma = np.arange(1, 6).astype(core_zsn.type.dtype)
1713+
ref_logp = pm.logp(core_zsn, core_test_value)
1714+
ref_logp_fn = pytensor.function([sigma], ref_logp)
1715+
expected_logp = np.stack([ref_logp_fn(test_sigma) for test_sigma in batch_test_sigma])
1716+
1717+
# Explicit batch dim from shape
1718+
batch_zsn = pm.ZeroSumNormal.dist(
1719+
sigma=batch_test_sigma[:, None, None], n_zerosum_axes=2, shape=(5, 3, 2)
1720+
)
1721+
assert pm.draw(batch_zsn).shape == (5, 3, 2)
1722+
np.testing.assert_allclose(
1723+
pm.logp(batch_zsn, batch_test_value).eval(),
1724+
expected_logp,
1725+
)
1726+
1727+
# Implicit batch dim from sigma
1728+
batch_zsn = pm.ZeroSumNormal.dist(
1729+
sigma=batch_test_sigma[:, None, None], n_zerosum_axes=2, support_shape=(3, 2)
1730+
)
1731+
assert pm.draw(batch_zsn).shape == (5, 3, 2)
1732+
np.testing.assert_allclose(
1733+
pm.logp(batch_zsn, batch_test_value).eval(),
1734+
expected_logp,
1735+
)
1736+
1737+
with pytest.raises(ValueError, match="sigma must have length one across the zero-sum axes"):
1738+
pm.ZeroSumNormal.dist(
1739+
sigma=batch_test_sigma[None, :, None], n_zerosum_axes=2, support_shape=(3, 2)
1740+
)
1741+
17071742

17081743
class TestMvStudentTCov(BaseTestDistributionRandom):
17091744
def mvstudentt_rng_fn(self, size, nu, mu, scale, rng):

0 commit comments

Comments
 (0)