Skip to content

Commit d93c1c9

Browse files
committed
.WIP changes needed to implement ZSN
1 parent 777cdd3 commit d93c1c9

File tree

5 files changed

+245
-18
lines changed

5 files changed

+245
-18
lines changed

pymc/dims/distribution_core.py

Lines changed: 76 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,23 @@
1414
from collections.abc import Callable, Sequence
1515
from itertools import chain
1616

17+
import numpy as np
18+
1719
from pytensor.graph import node_rewriter
1820
from pytensor.tensor.elemwise import DimShuffle
21+
from pytensor.tensor.random.op import RandomVariable
1922
from pytensor.xtensor import as_xtensor
2023
from pytensor.xtensor.basic import XTensorFromTensor, xtensor_from_tensor
2124
from pytensor.xtensor.type import XTensorVariable
2225

23-
from pymc import modelcontext
26+
from pymc import SymbolicRandomVariable, modelcontext
2427
from pymc.dims.model import with_dims
25-
from pymc.dims.transforms import log_odds_transform, log_transform
28+
from pymc.dims.transforms import DimTransform, log_odds_transform, log_transform
2629
from pymc.distributions.distribution import _support_point, support_point
2730
from pymc.distributions.shape_utils import DimsWithEllipsis, convert_dims
2831
from pymc.logprob.abstract import MeasurableOp, _logprob
2932
from pymc.logprob.rewriting import measurable_ir_rewrites_db
33+
from pymc.logprob.tensor import MeasurableDimShuffle
3034
from pymc.logprob.utils import filter_measurable_variables
3135
from pymc.util import UNSET
3236

@@ -46,24 +50,67 @@ def xtensor_from_tensor_support_point(xtensor_op, _, rv):
4650

4751

4852
class MeasurableXTensorFromTensor(MeasurableOp, XTensorFromTensor):
49-
pass
53+
__props__ = ("dims", "core_dims")
54+
55+
def __init__(self, dims, core_dims):
56+
super().__init__(dims=dims)
57+
self.core_dims = tuple(core_dims) if core_dims is not None else None
5058

5159

5260
@node_rewriter([XTensorFromTensor])
5361
def find_measurable_xtensor_from_tensor(fgraph, node) -> list[XTensorVariable] | None:
5462
if isinstance(node.op, MeasurableXTensorFromTensor):
5563
return None
5664

57-
if not filter_measurable_variables(node.inputs):
58-
return None
65+
xs = filter_measurable_variables(node.inputs)
66+
67+
if not xs:
68+
# Check if we have a transposition instead
69+
# The rewrite that introduces measurable tranpsoses refuses to apply to multivariate RVs
70+
# So we have a chance of inferring the core dims!
71+
[ds] = node.inputs
72+
ds_node = ds.owner
73+
if not (
74+
ds_node is not None
75+
and isinstance(ds_node.op, DimShuffle)
76+
and ds_node.op.is_transpose
77+
and filter_measurable_variables(ds_node.inputs)
78+
):
79+
return None
80+
[x] = ds_node.inputs
81+
if not (
82+
x.owner is not None and isinstance(x.owner.op, RandomVariable | SymbolicRandomVariable)
83+
):
84+
return None
85+
86+
measurable_x = MeasurableDimShuffle(**ds_node.op._props_dict())(x)
87+
88+
ndim_supp = x.owner.op.ndim_supp
89+
if ndim_supp:
90+
inverse_transpose = np.argsort(ds_node.op.shuffle)
91+
dims = node.op.dims
92+
dims_before_transpose = [dims[i] for i in inverse_transpose]
93+
core_dims = dims_before_transpose[-ndim_supp:]
94+
else:
95+
core_dims = ()
5996

60-
return [MeasurableXTensorFromTensor(dims=node.op.dims)(*node.inputs)]
97+
return [MeasurableXTensorFromTensor(dims=node.op.dims, core_dims=core_dims)(measurable_x)]
98+
else:
99+
# If this happens we know there's no measurable transpose in between and we can
100+
# safely infer the core_dims positionally when the inner logp is returned
101+
return [MeasurableXTensorFromTensor(dims=node.op.dims, core_dims=None)(*node.inputs)]
61102

62103

63104
@_logprob.register(MeasurableXTensorFromTensor)
64105
def measurable_xtensor_from_tensor(op, values, rv, **kwargs):
65106
rv_logp = _logprob(rv.owner.op, tuple(v.values for v in values), *rv.owner.inputs, **kwargs)
66-
return xtensor_from_tensor(rv_logp, dims=op.dims)
107+
if op.core_dims is None:
108+
# The core_dims of the inner rv are on the right
109+
dims = op.dims[: rv_logp.ndim]
110+
else:
111+
# We inferred where the core_dims are!
112+
dims = [d for d in op.dims if d not in op.core_dims]
113+
return xtensor_from_tensor(rv_logp, dims=dims)
67114

68115

69116
measurable_ir_rewrites_db.register(
@@ -75,7 +122,7 @@ class DimDistribution:
75122
"""Base class for PyMC distribution that wrap pytensor.xtensor.random operations, and follow xarray-like semantics."""
76123

77124
xrv_op: Callable
78-
default_transform: Callable | None = None
125+
default_transform: DimTransform | None = None
79126

80127
@staticmethod
81128
def _as_xtensor(x):
@@ -156,6 +203,18 @@ def __new__(
156203
# TODO: If this fails give a more informative error message
157204
observed = observed.transpose(*rv_dims)
158205

206+
# Check user didn't pass regular transforms
207+
if transform not in (UNSET, None):
208+
if not isinstance(transform, DimTransform):
209+
raise TypeError(
210+
f"Transform must be a DimTransform, form pymc.dims.transforms, but got {type(transform)}."
211+
)
212+
if default_transform not in (UNSET, None):
213+
if not isinstance(default_transform, DimTransform):
214+
raise TypeError(
215+
f"default_transform must be a DimTransform, from pymc.dims.transforms, but got {type(default_transform)}."
216+
)
217+
159218
rv = model.register_rv(
160219
rv,
161220
name=name,
@@ -188,13 +247,16 @@ def dist(
188247
if dims_dict is None:
189248
extra_dims = None
190249
else:
191-
parameter_implied_dims = set(
192-
chain.from_iterable(param.type.dims for param in dist_params)
193-
)
250+
# Exclude dims that are implied by the parameters or core_dims
251+
implied_dims = set(chain.from_iterable(param.type.dims for param in dist_params))
252+
if core_dims is not None:
253+
if isinstance(core_dims, str):
254+
implied_dims.add(core_dims)
255+
else:
256+
implied_dims.update(core_dims)
257+
194258
extra_dims = {
195-
dim: length
196-
for dim, length in dims_dict.items()
197-
if dim not in parameter_implied_dims
259+
dim: length for dim, length in dims_dict.items() if dim not in implied_dims
198260
}
199261
return cls.xrv_op(*dist_params, extra_dims=extra_dims, core_dims=core_dims, **kwargs)
200262

pymc/dims/distributions.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,23 @@
1414
import pytensor.xtensor as ptx
1515
import pytensor.xtensor.random as pxr
1616

17+
from pytensor.tensor import as_tensor
18+
from pytensor.tensor.random.utils import normalize_size_param
19+
from pytensor.xtensor import as_xtensor
20+
1721
from pymc.dims.distribution_core import (
1822
DimDistribution,
1923
MultivariateDimDistribution,
2024
PositiveDimDistribution,
2125
UnitDimDistribution,
2226
)
27+
from pymc.dims.transforms import ZeroSumTransform
2328
from pymc.distributions.continuous import Beta as RegularBeta
2429
from pymc.distributions.continuous import Gamma as RegularGamma
2530
from pymc.distributions.continuous import HalfStudentTRV, flat, halfflat
2631
from pymc.distributions.continuous import InverseGamma as RegularInverseGamma
32+
from pymc.distributions.multivariate import ZeroSumNormalRV
33+
from pymc.util import UNSET
2734

2835

2936
def _get_sigma_from_either_sigma_or_tau(*, sigma, tau):
@@ -221,3 +228,93 @@ def dist(cls, mu, cov=None, *, chol=None, lower=True, core_dims=None, **kwargs):
221228
cov = chol.dot(chol.rename({d0: safe_name}), dim=d1).rename({safe_name: d1})
222229

223230
return super().dist([mu, cov], core_dims=core_dims, **kwargs)
231+
232+
233+
class DimZeroSumNormalRV(ZeroSumNormalRV):
234+
def make_node(self, rng, size, sigma, support_shape):
235+
if not self.input_types[1].in_same_class(normalize_size_param(size).type):
236+
# We need to rebuild the graph with new size type
237+
return self.rv_op(sigma, support_shape, size=size, rng=rng).owner
238+
return super().make_node(rng, size, sigma, support_shape)
239+
240+
241+
class ZeroSumNormal(MultivariateDimDistribution):
242+
@classmethod
243+
def __new__(
244+
cls, *args, core_dims=None, dims=None, default_transform=UNSET, observed=None, **kwargs
245+
):
246+
if core_dims is not None:
247+
if isinstance(core_dims, str):
248+
core_dims = (core_dims,)
249+
250+
# Create default_transform
251+
if observed is None and default_transform is UNSET:
252+
default_transform = ZeroSumTransform(dims=core_dims)
253+
254+
# If the user didn't specify dims, take it from core_dims
255+
# We need them to be forwarded to dist in the `dims_dict` argument
256+
if dims is None and core_dims is not None:
257+
dims = (..., *core_dims)
258+
259+
return super().__new__(
260+
*args,
261+
core_dims=core_dims,
262+
dims=dims,
263+
default_transform=default_transform,
264+
observed=observed,
265+
**kwargs,
266+
)
267+
268+
@classmethod
269+
def dist(cls, sigma=1.0, *, core_dims=None, dims_dict, **kwargs):
270+
if isinstance(core_dims, str):
271+
core_dims = (core_dims,)
272+
if core_dims is None or len(core_dims) == 0:
273+
raise ValueError("ZeroSumNormal requires atleast 1 core_dims")
274+
275+
support_dims = as_xtensor(
276+
as_tensor([dims_dict[core_dim] for core_dim in core_dims]), dims=("_",)
277+
)
278+
sigma = cls._as_xtensor(sigma)
279+
280+
return super().dist(
281+
[sigma, support_dims], core_dims=core_dims, dims_dict=dims_dict, **kwargs
282+
)
283+
284+
# def multivariate_normal(
285+
# mean,
286+
# cov,
287+
# *,
288+
# core_dims: Sequence[str],
289+
# extra_dims=None,
290+
# rng=None,
291+
# method: Literal["cholesky", "svd", "eigh"] = "cholesky",
292+
# ):
293+
# mean = as_xtensor(mean)
294+
# if len(core_dims) != 2:
295+
# raise ValueError(
296+
# f"multivariate_normal requires 2 core_dims, got {len(core_dims)}"
297+
# )
298+
#
299+
# # Align core_dims, so that the dim that exists in mean comes before the one that only exists in cov
300+
# # This will be the core dimension of the output
301+
# if core_dims[0] not in mean.type.dims:
302+
# core_dims = core_dims[::-1]
303+
#
304+
# xop = _as_xrv(ptr.MvNormalRV(method=method))
305+
# return xop(mean, cov, core_dims=core_dims, extra_dims=extra_dims, rng=rng)
306+
307+
@classmethod
308+
def xrv_op(self, sigma, support_dims, core_dims, extra_dims=None, rng=None):
309+
sigma = as_xtensor(sigma)
310+
support_dims = as_xtensor(support_dims, dims=("_",))
311+
support_shape = support_dims.values
312+
core_rv = DimZeroSumNormalRV.rv_op(sigma=sigma.values, support_shape=support_shape).owner.op
313+
xop = pxr._as_xrv(
314+
core_rv,
315+
core_inps_dims_map=[(), (0,)],
316+
core_out_dims_map=tuple(range(1, len(core_dims) + 1)),
317+
)
318+
# Dummy "_" core dim to absorb the support_shape vector
319+
# If ZeroSumNormal expected a scalar per support dim, this wouldn't be needed
320+
return xop(sigma, support_dims, core_dims=("_", *core_dims), extra_dims=extra_dims, rng=rng)

pymc/distributions/distribution.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,7 @@ def __init__(
367367

368368
kwargs.setdefault("inline", True)
369369
kwargs.setdefault("strict", True)
370+
kwargs.setdefault("on_unused_input", "ignore")
370371
super().__init__(*args, **kwargs)
371372

372373
def update(self, node: Apply) -> dict[Variable, Variable]:

pymc/distributions/multivariate.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2664,6 +2664,7 @@ def logp(value, alpha, K):
26642664
class ZeroSumNormalRV(SymbolicRandomVariable):
26652665
"""ZeroSumNormal random variable."""
26662666

2667+
name = "ZeroSumNormal"
26672668
_print_name = ("ZeroSumNormal", "\\operatorname{ZeroSumNormal}")
26682669

26692670
@classmethod
@@ -2687,12 +2688,12 @@ def rv_op(cls, sigma, support_shape, *, size=None, rng=None):
26872688
zerosum_rv -= zerosum_rv.mean(axis=-axis - 1, keepdims=True)
26882689

26892690
support_str = ",".join([f"d{i}" for i in range(n_zerosum_axes)])
2690-
extended_signature = f"[rng],(),(s),[size]->[rng],({support_str})"
2691-
return ZeroSumNormalRV(
2692-
inputs=[rng, sigma, support_shape, size],
2691+
extended_signature = f"[rng],[size],(),(s)->[rng],({support_str})"
2692+
return cls(
2693+
inputs=[rng, size, sigma, support_shape],
26932694
outputs=[next_rng, zerosum_rv],
26942695
extended_signature=extended_signature,
2695-
)(rng, sigma, support_shape, size)
2696+
)(rng, size, sigma, support_shape)
26962697

26972698

26982699
class ZeroSumNormal(Distribution):

tests/dims/test_distributions.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# Copyright 2025 - present The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import numpy as np
15+
16+
from pymc import Model, draw
17+
from pymc.dims import ZeroSumNormal
18+
from pymc.distributions import ZeroSumNormal as RegularZeroSumNormal
19+
20+
21+
def test_zerosumnormal():
22+
coords = {"time": range(5), "item": range(3)}
23+
24+
with Model(coords=coords) as model:
25+
zsn_item = ZeroSumNormal("zsn_item", core_dims="item", dims=("time", "item"))
26+
zsn_time = ZeroSumNormal("zsn_time", core_dims="time", dims=("time", "item"))
27+
zsn_item_time = ZeroSumNormal("zsn_item_time", core_dims=("item", "time"))
28+
assert zsn_item.type.dims == ("time", "item")
29+
assert zsn_time.type.dims == ("time", "item")
30+
assert zsn_item_time.type.dims == ("item", "time")
31+
32+
zsn_item_draw, zsn_time_draw, zsn_item_time_draw = draw(
33+
[zsn_item, zsn_time, zsn_item_time], random_seed=1
34+
)
35+
assert zsn_item_draw.shape == (5, 3)
36+
np.testing.assert_allclose(zsn_item_draw.mean(-1), 0, atol=1e-13)
37+
assert not np.allclose(zsn_item_draw.mean(0), 0, atol=1e-13)
38+
39+
assert zsn_time_draw.shape == (5, 3)
40+
np.testing.assert_allclose(zsn_time_draw.mean(0), 0, atol=1e-13)
41+
assert not np.allclose(zsn_time_draw.mean(-1), 0, atol=1e-13)
42+
43+
assert zsn_item_time_draw.shape == (3, 5)
44+
np.testing.assert_allclose(zsn_item_time_draw.mean(), 0, atol=1e-13)
45+
46+
with Model(coords=coords) as ref_model:
47+
# Check that the ZeroSumNormal can be used in a model
48+
RegularZeroSumNormal("zsn_item", dims=("time", "item"))
49+
RegularZeroSumNormal("zsn_time", dims=("item", "time"))
50+
RegularZeroSumNormal("zsn_item_time", n_zerosum_axes=2, dims=("item", "time"))
51+
52+
# Check initial_point and logp
53+
ip = model.initial_point()
54+
ref_ip = ref_model.initial_point()
55+
assert ip.keys() == ref_ip.keys()
56+
for i, (ip_value, ref_ip_value) in enumerate(zip(ip.values(), ref_ip.values())):
57+
if i == 1:
58+
# zsn_time is actually transposed in the original model
59+
ip_value = ip_value.T
60+
np.testing.assert_allclose(ip_value, ref_ip_value)
61+
62+
logp_fn = model.compile_logp()
63+
ref_logp_fn = ref_model.compile_logp()
64+
logp_fn(ip)
65+
# np.testing.assert_allclose(logp_fn(ip), ref_logp_fn(ref_ip))
66+
# Test a new

0 commit comments

Comments
 (0)