Skip to content

Commit fba64f0

Browse files
committed
Implement Dim ZeroSumNormal
1 parent c1be6bc commit fba64f0

File tree

5 files changed

+202
-6
lines changed

5 files changed

+202
-6
lines changed

pymc/dims/distributions/transforms.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import pytensor.tensor as pt
1415
import pytensor.xtensor as ptx
1516

1617
from pymc.logprob.transforms import Transform
@@ -51,3 +52,44 @@ def log_jac_det(self, value, *inputs):
5152

5253

5354
log_odds_transform = LogOddsTransform()
55+
56+
57+
class ZeroSumTransform(DimTransform):
    """Transform between zero-sum arrays and their unconstrained representation.

    ``forward`` removes one entry along every dimension in ``dims``;
    ``backward`` appends one entry along every dimension in ``dims`` such that
    the result sums to zero along that dimension.

    Parameters
    ----------
    dims : tuple of str or str
        Names of the dimensions along which the zero-sum constraint holds.
        A single name given as a bare string is wrapped into a 1-tuple.
    """

    name = "zerosum"

    def __init__(self, dims: tuple[str, ...]):
        # A bare string would otherwise be iterated character by character
        # in `forward` / `backward` (e.g. "time" -> "t", "i", "m", "e").
        if isinstance(dims, str):
            dims = (dims,)
        self.dims = dims

    @staticmethod
    def extend_dim(array, dim):
        """Append one entry along ``dim`` so that the output sums to zero along it."""
        # Length of `dim` *after* the extra entry is appended
        n = (array.sizes[dim] + 1).astype("floatX")
        sum_vals = array.sum(dim)
        norm = sum_vals / (pt.sqrt(n) + n)
        fill_val = norm - sum_vals / pt.sqrt(n)

        out = ptx.concat([array, fill_val], dim=dim)
        return out - norm

    @staticmethod
    def reduce_dim(array, dim):
        """Drop the last entry along ``dim``, undoing ``extend_dim``."""
        n = array.sizes[dim].astype("floatX")
        last = array.isel({dim: -1})

        # Recover the sum of the unconstrained values from the last entry
        sum_vals = -last * pt.sqrt(n)
        norm = sum_vals / (pt.sqrt(n) + n)
        return array.isel({dim: slice(None, -1)}) + norm

    def forward(self, value, *rv_inputs):
        """Map a zero-sum ``value`` to the unconstrained space, one dim at a time."""
        for dim in self.dims:
            value = self.reduce_dim(value, dim=dim)
        return value

    def backward(self, value, *rv_inputs):
        """Map an unconstrained ``value`` back to the zero-sum space, one dim at a time."""
        for dim in self.dims:
            value = self.extend_dim(value, dim=dim)
        return value

    def log_jac_det(self, value, *rv_inputs):
        """Return zeros with the dims of ``value`` minus ``self.dims``."""
        # Use following once broadcast_like is implemented
        # as_xtensor(0).broadcast_like(value, exclude=self.dims)
        return value.sum(self.dims) * 0

pymc/dims/distributions/vector.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,14 @@
1414
import pytensor.xtensor as ptx
1515
import pytensor.xtensor.random as ptxr
1616

17+
from pytensor.tensor import as_tensor
18+
from pytensor.xtensor import as_xtensor
1719
from pytensor.xtensor import random as pxr
1820

1921
from pymc.dims.distributions.core import VectorDimDistribution
22+
from pymc.dims.distributions.transforms import ZeroSumTransform
23+
from pymc.distributions.multivariate import ZeroSumNormalRV
24+
from pymc.util import UNSET
2025

2126

2227
class Categorical(VectorDimDistribution):
@@ -114,3 +119,80 @@ def dist(cls, mu, cov=None, *, chol=None, lower=True, core_dims=None, **kwargs):
114119
cov = chol.dot(chol.rename({d0: safe_name}), dim=d1).rename({safe_name: d1})
115120

116121
return super().dist([mu, cov], core_dims=core_dims, **kwargs)
122+
123+
124+
class ZeroSumNormal(VectorDimDistribution):
    """Zero-sum multivariate normal distribution.

    Parameters
    ----------
    sigma : xtensor_like, optional
        The standard deviation of the underlying unconstrained normal distribution.
        Defaults to 1.0. It cannot have core dimensions.
    core_dims : Sequence of str, optional
        The axes along which the zero-sum constraint is applied.
    **kwargs
        Additional keyword arguments used to define the distribution.

    Returns
    -------
    XTensorVariable
        An xtensor variable representing the zero-sum multivariate normal distribution.
    """

    # NOTE(review): decorating `__new__` with `@classmethod` means `args` below
    # starts with the class itself, which is what the parent `__new__` receives
    # as its first argument via `*args`. Kept as-is; confirm this is intended.
    @classmethod
    def __new__(
        cls, *args, core_dims=None, dims=None, default_transform=UNSET, observed=None, **kwargs
    ):
        """Register the variable, filling in ``default_transform`` and ``dims`` from ``core_dims``."""
        if core_dims is not None:
            if isinstance(core_dims, str):
                core_dims = (core_dims,)

            # Create default_transform
            if observed is None and default_transform is UNSET:
                default_transform = ZeroSumTransform(dims=core_dims)

        # If the user didn't specify dims, take it from core_dims
        # We need them to be forwarded to dist in the `dim_lengths` argument
        if dims is None and core_dims is not None:
            dims = (..., *core_dims)

        return super().__new__(
            *args,
            core_dims=core_dims,
            dims=dims,
            default_transform=default_transform,
            observed=observed,
            **kwargs,
        )

    @classmethod
    def dist(cls, sigma=1.0, *, core_dims=None, dim_lengths, **kwargs):
        """Build the distribution.

        Raises
        ------
        ValueError
            If ``core_dims`` is ``None`` or empty.
        """
        if isinstance(core_dims, str):
            core_dims = (core_dims,)
        if core_dims is None or len(core_dims) == 0:
            raise ValueError("ZeroSumNormal requires atleast 1 core_dims")

        # Lengths of the constrained dims, packed in a vector along a dummy "_" dim
        support_dims = as_xtensor(
            as_tensor([dim_lengths[core_dim] for core_dim in core_dims]), dims=("_",)
        )
        sigma = cls._as_xtensor(sigma)

        return super().dist(
            [sigma, support_dims], core_dims=core_dims, dim_lengths=dim_lengths, **kwargs
        )

    @classmethod
    def xrv_op(cls, sigma, support_dims, core_dims, extra_dims=None, rng=None):
        """Wrap the core ``ZeroSumNormalRV`` op as an xtensor random variable."""
        sigma = as_xtensor(sigma)
        support_dims = as_xtensor(support_dims, dims=("_",))
        support_shape = support_dims.values
        core_rv = ZeroSumNormalRV.rv_op(sigma=sigma.values, support_shape=support_shape).owner.op
        xop = pxr.as_xrv(
            core_rv,
            core_inps_dims_map=[(), (0,)],
            core_out_dims_map=tuple(range(1, len(core_dims) + 1)),
        )
        # Dummy "_" core dim to absorb the support_shape vector
        # If ZeroSumNormal expected a scalar per support dim, this wouldn't be needed
        return xop(sigma, support_dims, core_dims=("_", *core_dims), extra_dims=extra_dims, rng=rng)

pymc/distributions/multivariate.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2664,6 +2664,7 @@ def logp(value, alpha, K):
26642664
class ZeroSumNormalRV(SymbolicRandomVariable):
26652665
"""ZeroSumNormal random variable."""
26662666

2667+
name = "ZeroSumNormal"
26672668
_print_name = ("ZeroSumNormal", "\\operatorname{ZeroSumNormal}")
26682669

26692670
@classmethod
@@ -2687,12 +2688,12 @@ def rv_op(cls, sigma, support_shape, *, size=None, rng=None):
26872688
zerosum_rv -= zerosum_rv.mean(axis=-axis - 1, keepdims=True)
26882689

26892690
support_str = ",".join([f"d{i}" for i in range(n_zerosum_axes)])
2690-
extended_signature = f"[rng],(),(s),[size]->[rng],({support_str})"
2691-
return ZeroSumNormalRV(
2692-
inputs=[rng, sigma, support_shape, size],
2691+
extended_signature = f"[rng],[size],(),(s)->[rng],({support_str})"
2692+
return cls(
2693+
inputs=[rng, size, sigma, support_shape],
26932694
outputs=[next_rng, zerosum_rv],
26942695
extended_signature=extended_signature,
2695-
)(rng, sigma, support_shape, size)
2696+
)(rng, size, sigma, support_shape)
26962697

26972698

26982699
class ZeroSumNormal(Distribution):
@@ -2828,7 +2829,7 @@ def zerosum_default_transform(op, rv):
28282829

28292830

28302831
@_logprob.register(ZeroSumNormalRV)
2831-
def zerosumnormal_logp(op, values, rng, sigma, support_shape, size, **kwargs):
2832+
def zerosumnormal_logp(op, values, rng, size, sigma, support_shape, **kwargs):
28322833
(value,) = values
28332834
shape = value.shape
28342835
n_zerosum_axes = op.ndim_supp

tests/dims/distributions/test_vector.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import pymc.distributions as regular_distributions
2020

2121
from pymc import Model
22-
from pymc.dims import Categorical, MvNormal
22+
from pymc.dims import Categorical, MvNormal, ZeroSumNormal
2323
from tests.dims.utils import assert_equivalent_logp_graph, assert_equivalent_random_graph
2424

2525

@@ -60,3 +60,21 @@ def test_mvnormal():
6060

6161
assert_equivalent_random_graph(model, reference_model)
6262
assert_equivalent_logp_graph(model, reference_model)
63+
64+
65+
def test_zerosumnormal():
    """Dims ZeroSumNormal builds the same random graph as the regular one."""
    coords = {"a": range(3), "b": range(2)}
    dims = ("a", "b")

    with Model(coords=coords) as model:
        ZeroSumNormal("x", core_dims=("b",), dims=dims)
        ZeroSumNormal("y", sigma=3, core_dims=("b",), dims=dims)
        ZeroSumNormal("z", core_dims=("a", "b"), dims=dims)

    with Model(coords=coords) as reference_model:
        regular_distributions.ZeroSumNormal("x", dims=dims)
        regular_distributions.ZeroSumNormal("y", sigma=3, n_zerosum_axes=1, dims=dims)
        regular_distributions.ZeroSumNormal("z", n_zerosum_axes=2, dims=dims)

    assert_equivalent_random_graph(model, reference_model)
    # Logp is correct, but we have join(..., -1) and join(..., 1), that don't get canonicalized to the same
    # Should work once https://github.com/pymc-devs/pytensor/issues/1505 is fixed
    # assert_equivalent_logp_graph(model, reference_model)

tests/dims/test_model.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,3 +172,56 @@ def test_complex_model():
172172
tune=200, chains=2, draws=50, compute_convergence_checks=False, progressbar=False
173173
)
174174
pm.sample_posterior_predictive(idata, progressbar=False)
175+
176+
177+
def test_zerosumnormal_model():
    """Dims ZeroSumNormal draws, initial point and logp match the regular distribution."""
    coords = {"time": range(5), "item": range(3)}

    with pm.Model(coords=coords) as model:
        zsn_item = pmd.ZeroSumNormal("zsn_item", core_dims="item", dims=("time", "item"))
        zsn_time = pmd.ZeroSumNormal("zsn_time", core_dims="time", dims=("time", "item"))
        zsn_item_time = pmd.ZeroSumNormal("zsn_item_time", core_dims=("item", "time"))
    assert zsn_item.type.dims == ("time", "item")
    assert zsn_time.type.dims == ("time", "item")
    assert zsn_item_time.type.dims == ("item", "time")

    zsn_item_draw, zsn_time_draw, zsn_item_time_draw = pm.draw(
        [zsn_item, zsn_time, zsn_item_time], random_seed=1
    )
    assert zsn_item_draw.shape == (5, 3)
    np.testing.assert_allclose(zsn_item_draw.mean(-1), 0, atol=1e-13)
    assert not np.allclose(zsn_item_draw.mean(0), 0, atol=1e-13)

    assert zsn_time_draw.shape == (5, 3)
    np.testing.assert_allclose(zsn_time_draw.mean(0), 0, atol=1e-13)
    assert not np.allclose(zsn_time_draw.mean(-1), 0, atol=1e-13)

    assert zsn_item_time_draw.shape == (3, 5)
    np.testing.assert_allclose(zsn_item_time_draw.mean(), 0, atol=1e-13)

    with pm.Model(coords=coords) as ref_model:
        # Check that the ZeroSumNormal can be used in a model
        pm.ZeroSumNormal("zsn_item", dims=("time", "item"))
        pm.ZeroSumNormal("zsn_time", dims=("item", "time"))
        pm.ZeroSumNormal("zsn_item_time", n_zerosum_axes=2, dims=("item", "time"))

    def to_ref_layout(point):
        # zsn_time is declared as ("time", "item") in `model` but as
        # ("item", "time") in `ref_model`, so its value must be transposed
        # before it is compared with, or fed to, the reference model.
        # Assumes the value variable name keeps the "zsn_time" prefix.
        return {
            key: value.T if key.startswith("zsn_time") else value
            for key, value in point.items()
        }

    # Check initial_point and logp
    ip = model.initial_point()
    ref_ip = ref_model.initial_point()
    assert ip.keys() == ref_ip.keys()
    for key, ip_value in to_ref_layout(ip).items():
        np.testing.assert_allclose(ip_value, ref_ip[key])

    logp_fn = model.compile_logp()
    ref_logp_fn = ref_model.compile_logp()
    np.testing.assert_allclose(logp_fn(ip), ref_logp_fn(ref_ip))

    # Test a new point
    rng = np.random.default_rng(68)
    # Build fresh arrays: `ip.copy()` is shallow, so an in-place `+=` would
    # also modify the arrays held by `ip`.
    new_ip = {
        key: (value + rng.uniform(size=value.shape)).astype(value.dtype)
        for key, value in ip.items()
    }
    np.testing.assert_allclose(logp_fn(new_ip), ref_logp_fn(to_ref_layout(new_ip)))

0 commit comments

Comments
 (0)