Skip to content

Commit 9a04236

Browse files
committed
Allow registering XTensorVariables directly in model
1 parent 6c0a5d1 commit 9a04236

File tree

12 files changed

+284
-118
lines changed

12 files changed

+284
-118
lines changed

pymc/dims/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@ def __init__():
3939
logprob_rewrites_db.register(
4040
"pre_lower_xtensor", optdb.query("+lower_xtensor"), "basic", position=0.1
4141
)
42+
logprob_rewrites_db.register(
43+
"post_lower_xtensor", optdb.query("+lower_xtensor"), "cleanup", position=5.1
44+
)
4245
initial_point_rewrites_db.register(
4346
"lower_xtensor", optdb.query("+lower_xtensor"), "basic", position=0.1
4447
)

pymc/dims/distributions/core.py

Lines changed: 110 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,18 +13,26 @@
1313
# limitations under the License.
1414
from collections.abc import Callable, Sequence
1515
from itertools import chain
16+
from typing import cast
1617

18+
import numpy as np
19+
20+
from pytensor.graph import node_rewriter
1721
from pytensor.graph.basic import Variable
1822
from pytensor.tensor.elemwise import DimShuffle
23+
from pytensor.tensor.random.op import RandomVariable
1924
from pytensor.xtensor import as_xtensor
25+
from pytensor.xtensor.basic import XTensorFromTensor, xtensor_from_tensor
2026
from pytensor.xtensor.type import XTensorVariable
2127

22-
from pymc import modelcontext
23-
from pymc.dims.model import with_dims
24-
from pymc.distributions import transforms
28+
from pymc import SymbolicRandomVariable, modelcontext
29+
from pymc.dims.distributions.transforms import DimTransform, log_odds_transform, log_transform
2530
from pymc.distributions.distribution import _support_point, support_point
2631
from pymc.distributions.shape_utils import DimsWithEllipsis, convert_dims_with_ellipsis
27-
from pymc.logprob.transforms import Transform
32+
from pymc.logprob.abstract import MeasurableOp, _logprob
33+
from pymc.logprob.rewriting import measurable_ir_rewrites_db
34+
from pymc.logprob.tensor import MeasurableDimShuffle
35+
from pymc.logprob.utils import filter_measurable_variables
2836
from pymc.util import UNSET
2937

3038

@@ -36,25 +44,98 @@ def dimshuffle_support_point(ds_op, _, rv):
3644
return ds_op(support_point(rv))
3745

3846

47+
@_support_point.register(XTensorFromTensor)
48+
def xtensor_from_tensor_support_point(xtensor_op, _, rv):
49+
# We remove the xtensor_from_tensor operation, so initial_point doesn't have to do a further lowering
50+
return xtensor_op(support_point(rv))
51+
52+
53+
class MeasurableXTensorFromTensor(MeasurableOp, XTensorFromTensor):
54+
__props__ = ("dims", "core_dims") # type: ignore[assignment]
55+
56+
def __init__(self, dims, core_dims):
57+
super().__init__(dims=dims)
58+
self.core_dims = tuple(core_dims) if core_dims is not None else None
59+
60+
61+
@node_rewriter([XTensorFromTensor])
62+
def find_measurable_xtensor_from_tensor(fgraph, node) -> list[XTensorVariable] | None:
63+
if isinstance(node.op, MeasurableXTensorFromTensor):
64+
return None
65+
66+
xs = filter_measurable_variables(node.inputs)
67+
68+
if not xs:
69+
# Check if we have a transposition instead
70+
# The rewrite that introduces measurable tranpsoses refuses to apply to multivariate RVs
71+
# So we have a chance of inferring the core dims!
72+
[ds] = node.inputs
73+
ds_node = ds.owner
74+
if not (
75+
ds_node is not None
76+
and isinstance(ds_node.op, DimShuffle)
77+
and ds_node.op.is_transpose
78+
and filter_measurable_variables(ds_node.inputs)
79+
):
80+
return None
81+
[x] = ds_node.inputs
82+
if not (
83+
x.owner is not None and isinstance(x.owner.op, RandomVariable | SymbolicRandomVariable)
84+
):
85+
return None
86+
87+
measurable_x = MeasurableDimShuffle(**ds_node.op._props_dict())(x) # type: ignore[attr-defined]
88+
89+
ndim_supp = x.owner.op.ndim_supp
90+
if ndim_supp:
91+
inverse_transpose = np.argsort(ds_node.op.shuffle)
92+
dims = node.op.dims
93+
dims_before_transpose = tuple(dims[i] for i in inverse_transpose)
94+
core_dims = dims_before_transpose[-ndim_supp:]
95+
else:
96+
core_dims = ()
97+
98+
new_out = MeasurableXTensorFromTensor(dims=node.op.dims, core_dims=core_dims)(measurable_x)
99+
else:
100+
# If this happens we know there's no measurable transpose in between and we can
101+
# safely infer the core_dims positionally when the inner logp is returned
102+
new_out = MeasurableXTensorFromTensor(dims=node.op.dims, core_dims=None)(*node.inputs)
103+
return [cast(XTensorVariable, new_out)]
104+
105+
106+
@_logprob.register(MeasurableXTensorFromTensor)
107+
def measurable_xtensor_from_tensor(op, values, rv, **kwargs):
108+
rv_logp = _logprob(rv.owner.op, tuple(v.values for v in values), *rv.owner.inputs, **kwargs)
109+
if op.core_dims is None:
110+
# The core_dims of the inner rv are on the right
111+
dims = op.dims[: rv_logp.ndim]
112+
else:
113+
# We inferred where the core_dims are!
114+
dims = [d for d in op.dims if d not in op.core_dims]
115+
return xtensor_from_tensor(rv_logp, dims=dims)
116+
117+
118+
measurable_ir_rewrites_db.register(
119+
"measurable_xtensor_from_tensor", find_measurable_xtensor_from_tensor, "basic", "xtensor"
120+
)
121+
122+
39123
class DimDistribution:
40124
"""Base class for PyMC distribution that wrap pytensor.xtensor.random operations, and follow xarray-like semantics."""
41125

42126
xrv_op: Callable
43-
default_transform: Transform | None = None
127+
default_transform: DimTransform | None = None
44128

45129
@staticmethod
46130
def _as_xtensor(x):
47131
try:
48132
return as_xtensor(x)
49133
except TypeError:
50-
try:
51-
return with_dims(x)
52-
except ValueError:
53-
raise ValueError(
54-
f"Variable {x} must have dims associated with it.\n"
55-
"To avoid subtle bugs, PyMC does not make any assumptions about the dims of parameters.\n"
56-
"Use `as_xtensor` with the `dims` keyword argument to specify the dims explicitly."
57-
)
134+
raise ValueError(
135+
f"Variable {x} must have dims associated with it.\n"
136+
"To avoid subtle bugs, PyMC does not make any assumptions about the dims of parameters.\n"
137+
"Use `as_xtensor` with the `dims` keyword argument to specify the dims explicitly."
138+
)
58139

59140
def __new__(
60141
cls,
@@ -119,10 +200,22 @@ def __new__(
119200
else:
120201
# Align observed dims with those of the RV
121202
# TODO: If this fails give a more informative error message
122-
observed = observed.transpose(*rv_dims).values
203+
observed = observed.transpose(*rv_dims)
204+
205+
# Check user didn't pass regular transforms
206+
if transform not in (UNSET, None):
207+
if not isinstance(transform, DimTransform):
208+
raise TypeError(
209+
f"Transform must be a DimTransform, form pymc.dims.transforms, but got {type(transform)}."
210+
)
211+
if default_transform not in (UNSET, None):
212+
if not isinstance(default_transform, DimTransform):
213+
raise TypeError(
214+
f"default_transform must be a DimTransform, from pymc.dims.transforms, but got {type(default_transform)}."
215+
)
123216

124217
rv = model.register_rv(
125-
rv.values,
218+
rv,
126219
name=name,
127220
observed=observed,
128221
total_size=total_size,
@@ -182,10 +275,10 @@ def dist(self, *args, core_dims: str | Sequence[str] | None = None, **kwargs):
182275
class PositiveDimDistribution(DimDistribution):
183276
"""Base class for positive continuous distributions."""
184277

185-
default_transform = transforms.log
278+
default_transform = log_transform
186279

187280

188281
class UnitDimDistribution(DimDistribution):
189282
"""Base class for unit-valued distributions."""
190283

191-
default_transform = transforms.logodds
284+
default_transform = log_odds_transform

pymc/dims/distributions/transforms.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# Copyright 2025 - present The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import pytensor.xtensor as ptx
15+
16+
from pymc.logprob.transforms import Transform
17+
18+
19+
class DimTransform(Transform):
20+
"""Base class for transforms that are applied to dim distriubtions."""
21+
22+
23+
class LogTransform(DimTransform):
24+
name = "log"
25+
26+
def forward(self, value, *inputs):
27+
return ptx.math.log(value)
28+
29+
def backward(self, value, *inputs):
30+
return ptx.math.exp(value)
31+
32+
def log_jac_det(self, value, *inputs):
33+
return value
34+
35+
36+
log_transform = LogTransform()
37+
38+
39+
class LogOddsTransform(DimTransform):
40+
name = "logodds"
41+
42+
def backward(self, value, *inputs):
43+
return ptx.math.expit(value)
44+
45+
def forward(self, value, *inputs):
46+
return ptx.math.log(value / (1 - value))
47+
48+
def log_jac_det(self, value, *inputs):
49+
sigmoid_value = ptx.math.sigmoid(value)
50+
return ptx.math.log(sigmoid_value) + ptx.math.log1p(-sigmoid_value)
51+
52+
53+
log_odds_transform = LogOddsTransform()

pymc/dims/model.py

Lines changed: 3 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
from pytensor.tensor import TensorVariable
1717
from pytensor.xtensor import as_xtensor
18-
from pytensor.xtensor.basic import TensorFromXTensor
1918
from pytensor.xtensor.type import XTensorVariable
2019

2120
from pymc.data import Data as RegularData
@@ -30,38 +29,13 @@
3029
from pymc.model.core import Potential as RegularPotential
3130

3231

33-
def with_dims(x: TensorVariable | XTensorVariable, model: Model | None = None) -> XTensorVariable:
34-
"""Recover the dims of a variable that was registered in the Model."""
35-
if isinstance(x, XTensorVariable):
36-
return x
37-
38-
if (x.owner is not None) and isinstance(x.owner.op, TensorFromXTensor):
39-
dims = x.owner.inputs[0].type.dims
40-
return as_xtensor(x, dims=dims, name=x.name)
41-
42-
# Try accessing the model context to get dims
43-
try:
44-
model = modelcontext(model)
45-
if (
46-
model.named_vars.get(x.name, None) is x
47-
and (dims := model.named_vars_to_dims.get(x.name, None)) is not None
48-
):
49-
return as_xtensor(x, dims=dims, name=x.name)
50-
except TypeError:
51-
pass
52-
53-
raise ValueError(f"variable {x} doesn't have dims associated with it")
54-
55-
5632
def Data(
5733
name: str, value, dims: Dims = None, model: Model | None = None, **kwargs
5834
) -> XTensorVariable:
5935
"""Wrapper around pymc.Data that returns an XtensorVariable.
6036
6137
Dimensions are required if the input is not a scalar.
6238
These are always forwarded to the model object.
63-
64-
The respective TensorVariable is registered in the model
6539
"""
6640
model = modelcontext(model)
6741
dims = convert_dims(dims) # type: ignore[assignment]
@@ -90,12 +64,9 @@ def _register_and_return_xtensor_variable(
9064
value = value.transpose(*dims)
9165
# Regardless of whether dims are provided, we now have them
9266
dims = value.type.dims
93-
# Register the equivalent TensorVariable with the model so it doesn't see XTensorVariables directly.
94-
value = value.values # type: ignore[union-attr]
95-
96-
value = registration_func(name, value, dims=dims, model=model)
97-
98-
return as_xtensor(value, dims=dims, name=name) # type: ignore[arg-type]
67+
else:
68+
value = as_xtensor(value, dims=dims, name=name) # type: ignore[arg-type]
69+
return registration_func(name, value, dims=dims, model=model)
9970

10071

10172
def Deterministic(
@@ -107,8 +78,6 @@ def Deterministic(
10778
If the input is not an XTensorVariable, it is converted to one using `as_xtensor`. Dims are required if the input is not a scalar.
10879
10980
The dimensions of the resulting XTensorVariable are always forwarded to the model object.
110-
111-
The respective TensorVariable is registered in the model
11281
"""
11382
return _register_and_return_xtensor_variable(name, value, dims, model, RegularDeterministic)
11483

@@ -122,7 +91,5 @@ def Potential(
12291
If the input is not an XTensorVariable, it is converted to one using `as_xtensor`. Dims are required if the input is not a scalar.
12392
12493
The dimensions of the resulting XTensorVariable are always forwarded to the model object.
125-
126-
The respective TensorVariable is registered in the model.
12794
"""
12895
return _register_and_return_xtensor_variable(name, value, dims, model, RegularPotential)

0 commit comments

Comments
 (0)