Skip to content

Commit 7f76f23

Browse files
committed
Implement Dim ZeroSumNormal
1 parent e3e9e42 commit 7f76f23

File tree

5 files changed

+183
-5
lines changed

5 files changed

+183
-5
lines changed

pymc/dims/distributions/vector.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,15 @@
1414
import pytensor.xtensor as ptx
1515
import pytensor.xtensor.random as ptxr
1616

17+
from pytensor.tensor import as_tensor
1718
from pytensor.tensor.random.utils import normalize_size_param
19+
from pytensor.xtensor import as_xtensor
1820
from pytensor.xtensor import random as pxr
1921

2022
from pymc.dims.distributions.core import VectorDimDistribution
23+
from pymc.dims.transforms import ZeroSumTransform
2124
from pymc.distributions.multivariate import ZeroSumNormalRV
25+
from pymc.util import UNSET
2226

2327

2428
class Categorical(VectorDimDistribution):
@@ -100,3 +104,62 @@ def make_node(self, rng, size, sigma, support_shape):
100104
# We need to rebuild the graph with new size type
101105
return self.rv_op(sigma, support_shape, size=size, rng=rng).owner
102106
return super().make_node(rng, size, sigma, support_shape)
107+
108+
109+
class ZeroSumNormal(VectorDimDistribution):
110+
@classmethod
111+
def __new__(
112+
cls, *args, core_dims=None, dims=None, default_transform=UNSET, observed=None, **kwargs
113+
):
114+
if core_dims is not None:
115+
if isinstance(core_dims, str):
116+
core_dims = (core_dims,)
117+
118+
# Create default_transform
119+
if observed is None and default_transform is UNSET:
120+
default_transform = ZeroSumTransform(dims=core_dims)
121+
122+
# If the user didn't specify dims, take it from core_dims
123+
# We need them to be forwarded to dist in the `dims_dict` argument
124+
if dims is None and core_dims is not None:
125+
dims = (..., *core_dims)
126+
127+
return super().__new__(
128+
*args,
129+
core_dims=core_dims,
130+
dims=dims,
131+
default_transform=default_transform,
132+
observed=observed,
133+
**kwargs,
134+
)
135+
136+
@classmethod
137+
def dist(cls, sigma=1.0, *, core_dims=None, dims_dict, **kwargs):
138+
if isinstance(core_dims, str):
139+
core_dims = (core_dims,)
140+
if core_dims is None or len(core_dims) == 0:
141+
raise ValueError("ZeroSumNormal requires atleast 1 core_dims")
142+
143+
support_dims = as_xtensor(
144+
as_tensor([dims_dict[core_dim] for core_dim in core_dims]), dims=("_",)
145+
)
146+
sigma = cls._as_xtensor(sigma)
147+
148+
return super().dist(
149+
[sigma, support_dims], core_dims=core_dims, dims_dict=dims_dict, **kwargs
150+
)
151+
152+
@classmethod
153+
def xrv_op(self, sigma, support_dims, core_dims, extra_dims=None, rng=None):
154+
sigma = as_xtensor(sigma)
155+
support_dims = as_xtensor(support_dims, dims=("_",))
156+
support_shape = support_dims.values
157+
core_rv = DimZeroSumNormalRV.rv_op(sigma=sigma.values, support_shape=support_shape).owner.op
158+
xop = pxr._as_xrv(
159+
core_rv,
160+
core_inps_dims_map=[(), (0,)],
161+
core_out_dims_map=tuple(range(1, len(core_dims) + 1)),
162+
)
163+
# Dummy "_" core dim to absorb the support_shape vector
164+
# If ZeroSumNormal expected a scalar per support dim, this wouldn't be needed
165+
return xop(sigma, support_dims, core_dims=("_", *core_dims), extra_dims=extra_dims, rng=rng)

pymc/dims/transforms.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import pytensor.tensor as pt
1415
import pytensor.xtensor as ptx
1516

1617
from pymc.logprob.transforms import Transform
@@ -51,3 +52,44 @@ def log_jac_det(self, value, *inputs):
5152

5253

5354
log_odds_transform = LogOddsTransform()
55+
56+
57+
class ZeroSumTransform(DimTransform):
58+
name = "zerosum"
59+
60+
def __init__(self, dims: tuple[str, ...]):
61+
self.dims = dims
62+
63+
@staticmethod
64+
def extend_dim(array, dim):
65+
n = (array.sizes[dim] + 1).astype("floatX")
66+
sum_vals = array.sum(dim)
67+
norm = sum_vals / (pt.sqrt(n) + n)
68+
fill_val = norm - sum_vals / pt.sqrt(n)
69+
70+
out = ptx.concat([array, fill_val], dim=dim)
71+
return out - norm
72+
73+
@staticmethod
74+
def reduce_dim(array, dim):
75+
n = array.sizes[dim].astype("floatX")
76+
last = array.isel({dim: -1})
77+
78+
sum_vals = -last * pt.sqrt(n)
79+
norm = sum_vals / (pt.sqrt(n) + n)
80+
return array.isel({dim: slice(None, -1)}) + norm
81+
82+
def forward(self, value, *rv_inputs):
83+
for dim in self.dims:
84+
value = self.reduce_dim(value, dim=dim)
85+
return value
86+
87+
def backward(self, value, *rv_inputs):
88+
for dim in self.dims:
89+
value = self.extend_dim(value, dim=dim)
90+
return value
91+
92+
def log_jac_det(self, value, *rv_inputs):
93+
# Use following once broadcast_like is implemented
94+
# as_xtensor(0).broadcast_like(value, exclude=self.dims)`
95+
return value.sum(self.dims) * 0

pymc/distributions/distribution.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,7 @@ def __init__(
367367

368368
kwargs.setdefault("inline", True)
369369
kwargs.setdefault("strict", True)
370+
kwargs.setdefault("on_unused_input", "ignore")
370371
super().__init__(*args, **kwargs)
371372

372373
def update(self, node: Apply) -> dict[Variable, Variable]:

pymc/distributions/multivariate.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2664,6 +2664,7 @@ def logp(value, alpha, K):
26642664
class ZeroSumNormalRV(SymbolicRandomVariable):
26652665
"""ZeroSumNormal random variable."""
26662666

2667+
name = "ZeroSumNormal"
26672668
_print_name = ("ZeroSumNormal", "\\operatorname{ZeroSumNormal}")
26682669

26692670
@classmethod
@@ -2687,12 +2688,12 @@ def rv_op(cls, sigma, support_shape, *, size=None, rng=None):
26872688
zerosum_rv -= zerosum_rv.mean(axis=-axis - 1, keepdims=True)
26882689

26892690
support_str = ",".join([f"d{i}" for i in range(n_zerosum_axes)])
2690-
extended_signature = f"[rng],(),(s),[size]->[rng],({support_str})"
2691-
return ZeroSumNormalRV(
2692-
inputs=[rng, sigma, support_shape, size],
2691+
extended_signature = f"[rng],[size],(),(s)->[rng],({support_str})"
2692+
return cls(
2693+
inputs=[rng, size, sigma, support_shape],
26932694
outputs=[next_rng, zerosum_rv],
26942695
extended_signature=extended_signature,
2695-
)(rng, sigma, support_shape, size)
2696+
)(rng, size, sigma, support_shape)
26962697

26972698

26982699
class ZeroSumNormal(Distribution):
@@ -2828,7 +2829,7 @@ def zerosum_default_transform(op, rv):
28282829

28292830

28302831
@_logprob.register(ZeroSumNormalRV)
2831-
def zerosumnormal_logp(op, values, rng, sigma, support_shape, size, **kwargs):
2832+
def zerosumnormal_logp(op, values, rng, size, sigma, support_shape, **kwargs):
28322833
(value,) = values
28332834
shape = value.shape
28342835
n_zerosum_axes = op.ndim_supp

tests/dims/test_distributions.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# Copyright 2025 - present The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import numpy as np
15+
16+
from pymc import Model, draw
17+
from pymc.dims import ZeroSumNormal
18+
from pymc.distributions import ZeroSumNormal as RegularZeroSumNormal
19+
20+
21+
def test_zerosumnormal():
22+
coords = {"time": range(5), "item": range(3)}
23+
24+
with Model(coords=coords) as model:
25+
zsn_item = ZeroSumNormal("zsn_item", core_dims="item", dims=("time", "item"))
26+
zsn_time = ZeroSumNormal("zsn_time", core_dims="time", dims=("time", "item"))
27+
zsn_item_time = ZeroSumNormal("zsn_item_time", core_dims=("item", "time"))
28+
assert zsn_item.type.dims == ("time", "item")
29+
assert zsn_time.type.dims == ("time", "item")
30+
assert zsn_item_time.type.dims == ("item", "time")
31+
32+
zsn_item_draw, zsn_time_draw, zsn_item_time_draw = draw(
33+
[zsn_item, zsn_time, zsn_item_time], random_seed=1
34+
)
35+
assert zsn_item_draw.shape == (5, 3)
36+
np.testing.assert_allclose(zsn_item_draw.mean(-1), 0, atol=1e-13)
37+
assert not np.allclose(zsn_item_draw.mean(0), 0, atol=1e-13)
38+
39+
assert zsn_time_draw.shape == (5, 3)
40+
np.testing.assert_allclose(zsn_time_draw.mean(0), 0, atol=1e-13)
41+
assert not np.allclose(zsn_time_draw.mean(-1), 0, atol=1e-13)
42+
43+
assert zsn_item_time_draw.shape == (3, 5)
44+
np.testing.assert_allclose(zsn_item_time_draw.mean(), 0, atol=1e-13)
45+
46+
with Model(coords=coords) as ref_model:
47+
# Check that the ZeroSumNormal can be used in a model
48+
RegularZeroSumNormal("zsn_item", dims=("time", "item"))
49+
RegularZeroSumNormal("zsn_time", dims=("item", "time"))
50+
RegularZeroSumNormal("zsn_item_time", n_zerosum_axes=2, dims=("item", "time"))
51+
52+
# Check initial_point and logp
53+
ip = model.initial_point()
54+
ref_ip = ref_model.initial_point()
55+
assert ip.keys() == ref_ip.keys()
56+
for i, (ip_value, ref_ip_value) in enumerate(zip(ip.values(), ref_ip.values())):
57+
if i == 1:
58+
# zsn_time is actually transposed in the original model
59+
ip_value = ip_value.T
60+
np.testing.assert_allclose(ip_value, ref_ip_value)
61+
62+
logp_fn = model.compile_logp()
63+
ref_logp_fn = ref_model.compile_logp()
64+
np.testing.assert_allclose(logp_fn(ip), ref_logp_fn(ref_ip))
65+
66+
# Test a new point
67+
rng = np.random.default_rng(68)
68+
new_ip = ip.copy()
69+
for key in new_ip:
70+
new_ip[key] += rng.uniform(size=new_ip[key].shape)
71+
np.testing.assert_allclose(logp_fn(new_ip), ref_logp_fn(new_ip))

0 commit comments

Comments
 (0)