Commit 8ab2d1f

Allow registering XTensorVariables directly in model

1 parent d818f27

11 files changed: +242 additions, −79 deletions

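For orientation (not part of the diff): after this change, dim-aware distributions can hand their XTensorVariable outputs straight to the model instead of first lowering them to plain tensors. A hypothetical sketch, assuming pymc.dims exports a Normal dim-distribution and xarray-style reductions like .mean(dim=...):

import pymc as pm
import pymc.dims as pmd

with pm.Model(coords={"city": ["A", "B", "C"]}):
    # The dim-distribution returns an XTensorVariable; model.register_rv
    # now accepts it as-is (previously the lowered .values tensor was passed)
    mu = pmd.Normal("mu", dims="city")
    # Deterministic likewise keeps the dims metadata end to end
    pmd.Deterministic("mu_centered", mu - mu.mean(dim="city"))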

pymc/dims/__init__.py
Lines changed: 10 additions & 4 deletions
@@ -36,9 +36,15 @@ def __init__():
 
     # Make PyMC aware of xtensor functionality
    MeasurableOp.register(XRV)
-    lower_xtensor_query = optdb.query("+lower_xtensor")
-    logprob_rewrites_db.register("lower_xtensor", lower_xtensor_query, "basic", position=0.1)
-    initial_point_rewrites_db.register("lower_xtensor", lower_xtensor_query, "basic", position=0.1)
+    logprob_rewrites_db.register(
+        "pre_lower_xtensor", optdb.query("+lower_xtensor"), "basic", position=0.1
+    )
+    logprob_rewrites_db.register(
+        "post_lower_xtensor", optdb.query("+lower_xtensor"), "cleanup", position=5.1
+    )
+    initial_point_rewrites_db.register(
+        "lower_xtensor", optdb.query("+lower_xtensor"), "basic", position=0.1
+    )
 
     # TODO: Better model of probability of bugs
     day_of_conception = datetime.date(2025, 6, 17)
@@ -64,4 +70,4 @@ def __init__():
 
 from pymc.dims import math
 from pymc.dims.distributions import *
-from pymc.dims.model import Data, Deterministic, Potential, with_dims
+from pymc.dims.model import Data, Deterministic, Potential

pymc/dims/distributions/core.py
Lines changed: 117 additions & 22 deletions
@@ -14,15 +14,23 @@
 from collections.abc import Callable, Sequence
 from itertools import chain
 
+import numpy as np
+
+from pytensor.graph import node_rewriter
 from pytensor.tensor.elemwise import DimShuffle
+from pytensor.tensor.random.op import RandomVariable
 from pytensor.xtensor import as_xtensor
+from pytensor.xtensor.basic import XTensorFromTensor, xtensor_from_tensor
 from pytensor.xtensor.type import XTensorVariable
 
-from pymc import modelcontext
-from pymc.dims.model import with_dims
-from pymc.distributions import transforms
+from pymc import SymbolicRandomVariable, modelcontext
+from pymc.dims.transforms import DimTransform, log_odds_transform, log_transform
 from pymc.distributions.distribution import _support_point, support_point
 from pymc.distributions.shape_utils import DimsWithEllipsis, convert_dims
+from pymc.logprob.abstract import MeasurableOp, _logprob
+from pymc.logprob.rewriting import measurable_ir_rewrites_db
+from pymc.logprob.tensor import MeasurableDimShuffle
+from pymc.logprob.utils import filter_measurable_variables
 from pymc.util import UNSET
 
 
@@ -34,25 +42,97 @@ def dimshuffle_support_point(ds_op, _, rv):
     return ds_op(support_point(rv))
 
 
+@_support_point.register(XTensorFromTensor)
+def xtensor_from_tensor_support_point(xtensor_op, _, rv):
+    # We remove the xtensor_from_tensor operation, so initial_point doesn't have to do a further lowering
+    return xtensor_op(support_point(rv))
+
+
+class MeasurableXTensorFromTensor(MeasurableOp, XTensorFromTensor):
+    __props__ = ("dims", "core_dims")
+
+    def __init__(self, dims, core_dims):
+        super().__init__(dims=dims)
+        self.core_dims = tuple(core_dims) if core_dims is not None else None
+
+
+@node_rewriter([XTensorFromTensor])
+def find_measurable_xtensor_from_tensor(fgraph, node) -> list[XTensorVariable] | None:
+    if isinstance(node.op, MeasurableXTensorFromTensor):
+        return None
+
+    xs = filter_measurable_variables(node.inputs)
+
+    if not xs:
+        # Check if we have a transposition instead
+        # The rewrite that introduces measurable transposes refuses to apply to multivariate RVs
+        # So we have a chance of inferring the core dims!
+        [ds] = node.inputs
+        ds_node = ds.owner
+        if not (
+            ds_node is not None
+            and isinstance(ds_node.op, DimShuffle)
+            and ds_node.op.is_transpose
+            and filter_measurable_variables(ds_node.inputs)
+        ):
+            return None
+        [x] = ds_node.inputs
+        if not (
+            x.owner is not None and isinstance(x.owner.op, RandomVariable | SymbolicRandomVariable)
+        ):
+            return None
+
+        measurable_x = MeasurableDimShuffle(**ds_node.op._props_dict())(x)
+
+        ndim_supp = x.owner.op.ndim_supp
+        if ndim_supp:
+            inverse_transpose = np.argsort(ds_node.op.shuffle)
+            dims = node.op.dims
+            dims_before_transpose = [dims[i] for i in inverse_transpose]
+            core_dims = dims_before_transpose[-ndim_supp:]
+        else:
+            core_dims = ()
+
+        return [MeasurableXTensorFromTensor(dims=node.op.dims, core_dims=core_dims)(measurable_x)]
+    else:
+        # If this happens we know there's no measurable transpose in between and we can
+        # safely infer the core_dims positionally when the inner logp is returned
+        return [MeasurableXTensorFromTensor(dims=node.op.dims, core_dims=None)(*node.inputs)]
+
+
+@_logprob.register(MeasurableXTensorFromTensor)
+def measurable_xtensor_from_tensor(op, values, rv, **kwargs):
+    rv_logp = _logprob(rv.owner.op, tuple(v.values for v in values), *rv.owner.inputs, **kwargs)
+    if op.core_dims is None:
+        # The core_dims of the inner rv are on the right
+        dims = op.dims[: rv_logp.ndim]
+    else:
+        # We inferred where the core_dims are!
+        dims = [d for d in op.dims if d not in op.core_dims]
+    return xtensor_from_tensor(rv_logp, dims=dims)
+
+
+measurable_ir_rewrites_db.register(
+    "measurable_xtensor_from_tensor", find_measurable_xtensor_from_tensor, "basic", "xtensor"
+)
+
+
 class DimDistribution:
     """Base class for PyMC distributions that wrap pytensor.xtensor.random operations, and follow xarray-like semantics."""
 
     xrv_op: Callable
-    default_transform: Callable | None = None
+    default_transform: DimTransform | None = None
 
     @staticmethod
     def _as_xtensor(x):
         try:
             return as_xtensor(x)
         except TypeError:
-            try:
-                return with_dims(x)
-            except ValueError:
-                raise ValueError(
-                    f"Variable {x} must have dims associated with it.\n"
-                    "To avoid subtle bugs, PyMC does not make any assumptions about the dims of parameters.\n"
-                    "Use `as_xtensor` with the `dims` keyword argument to specify the dims explicitly."
-                )
+            raise ValueError(
+                f"Variable {x} must have dims associated with it.\n"
+                "To avoid subtle bugs, PyMC does not make any assumptions about the dims of parameters.\n"
+                "Use `as_xtensor` with the `dims` keyword argument to specify the dims explicitly."
+            )
 
     def __new__(
         cls,
@@ -117,10 +197,22 @@ def __new__(
             else:
                 # Align observed dims with those of the RV
                 # TODO: If this fails give a more informative error message
-                observed = observed.transpose(*rv_dims).values
+                observed = observed.transpose(*rv_dims)
+
+        # Check user didn't pass regular transforms
+        if transform not in (UNSET, None):
+            if not isinstance(transform, DimTransform):
+                raise TypeError(
+                    f"Transform must be a DimTransform, from pymc.dims.transforms, but got {type(transform)}."
+                )
+        if default_transform not in (UNSET, None):
+            if not isinstance(default_transform, DimTransform):
+                raise TypeError(
+                    f"default_transform must be a DimTransform, from pymc.dims.transforms, but got {type(default_transform)}."
+                )
 
         rv = model.register_rv(
-            rv.values,
+            rv,
             name=name,
             observed=observed,
             total_size=total_size,
@@ -151,13 +243,16 @@ def dist(
        if dims_dict is None:
            extra_dims = None
        else:
-            parameter_implied_dims = set(
-                chain.from_iterable(param.type.dims for param in dist_params)
-            )
+            # Exclude dims that are implied by the parameters or core_dims
+            implied_dims = set(chain.from_iterable(param.type.dims for param in dist_params))
+            if core_dims is not None:
+                if isinstance(core_dims, str):
+                    implied_dims.add(core_dims)
+                else:
+                    implied_dims.update(core_dims)
+
            extra_dims = {
-                dim: length
-                for dim, length in dims_dict.items()
-                if dim not in parameter_implied_dims
+                dim: length for dim, length in dims_dict.items() if dim not in implied_dims
            }
        return cls.xrv_op(*dist_params, extra_dims=extra_dims, core_dims=core_dims, **kwargs)
 
@@ -177,10 +272,10 @@ def dist(self, *args, core_dims: str | Sequence[str] | None = None, **kwargs):
 class PositiveDimDistribution(DimDistribution):
     """Base class for positive continuous distributions."""
 
-    default_transform = transforms.log
+    default_transform = log_transform
 
 
 class UnitDimDistribution(DimDistribution):
     """Base class for unit-valued distributions."""
 
-    default_transform = transforms.logodds
+    default_transform = log_odds_transform
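
To make the core-dims inference above concrete, here is a self-contained numpy sketch of the argsort trick in find_measurable_xtensor_from_tensor: inverting the transpose permutation recovers the dims order of the underlying RV, whose support (core) dims always sit on the right. The dim names are made up for illustration:

import numpy as np

shuffle = (2, 0, 1)               # the DimShuffle transpose pattern
dims = ("core", "batch", "time")  # dims as labelled after the transpose
ndim_supp = 1                     # the RV has one support (core) dimension

inverse_transpose = np.argsort(shuffle)  # array([1, 2, 0]), undoes the permutation
dims_before_transpose = [dims[i] for i in inverse_transpose]
core_dims = dims_before_transpose[-ndim_supp:]
print(dims_before_transpose)  # ['batch', 'time', 'core']
print(core_dims)              # ['core']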

pymc/dims/model.py
Lines changed: 0 additions & 26 deletions
@@ -15,7 +15,6 @@
 
 from pytensor.tensor import TensorVariable
 from pytensor.xtensor import as_xtensor
-from pytensor.xtensor.basic import TensorFromXTensor
 from pytensor.xtensor.type import XTensorVariable
 
 from pymc.data import Data as RegularData
@@ -25,29 +24,6 @@
 from pymc.model.core import Potential as RegularPotential
 
 
-def with_dims(x: TensorVariable | XTensorVariable, model: Model | None = None) -> XTensorVariable:
-    """Recover the dims of a variable that was registered in the Model."""
-    if isinstance(x, XTensorVariable):
-        return x
-
-    if (x.owner is not None) and isinstance(x.owner.op, TensorFromXTensor):
-        dims = x.owner.inputs[0].type.dims
-        return as_xtensor(x, dims=dims, name=x.name)
-
-    # Try accessing the model context to get dims
-    try:
-        model = modelcontext(model)
-        if (
-            model.named_vars.get(x.name, None) is x
-            and (dims := model.named_vars_to_dims.get(x.name, None)) is not None
-        ):
-            return as_xtensor(x, dims=dims, name=x.name)
-    except TypeError:
-        pass
-
-    raise ValueError(f"variable {x} doesn't have dims associated with it")
-
-
 def Data(
     name: str, value, dims: DimsWithEllipsis = None, model: Model | None = None, **kwargs
 ) -> XTensorVariable:
@@ -79,8 +55,6 @@ def _register_and_return_xtensor_variable(
            value = value.transpose(*dims)
        # Regardless of whether dims are provided, we now have them
        dims = value.type.dims
-        # Register the equivalent TensorVariable with the model so it doesn't see XTensorVariables directly.
-        value = value.values
 
    value = registration_func(name, value, dims=dims, model=model)
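
Since with_dims is removed, the supported way to attach dims to a pre-existing TensorVariable is the explicit as_xtensor call that the error message in core.py points to. A minimal sketch (dim names illustrative):

import pytensor.tensor as pt
from pytensor.xtensor import as_xtensor

x = pt.matrix("x")                           # plain TensorVariable, no dims
xx = as_xtensor(x, dims=("obs", "feature"))  # explicitly attach dims
print(xx.type.dims)                          # ('obs', 'feature')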

pymc/dims/transforms.py
Lines changed: 53 additions & 0 deletions (new file)
@@ -0,0 +1,53 @@
+# Copyright 2025 - present The PyMC Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import pytensor.xtensor as ptx
+
+from pymc.logprob.transforms import Transform
+
+
+class DimTransform(Transform):
+    """Base class for transforms that are applied to dim distributions."""
+
+
+class LogTransform(DimTransform):
+    name = "log"
+
+    def forward(self, value, *inputs):
+        return ptx.math.log(value)
+
+    def backward(self, value, *inputs):
+        return ptx.math.exp(value)
+
+    def log_jac_det(self, value, *inputs):
+        return value
+
+
+log_transform = LogTransform()
+
+
+class LogOddsTransform(DimTransform):
+    name = "logodds"
+
+    def backward(self, value, *inputs):
+        return ptx.math.expit(value)
+
+    def forward(self, value, *inputs):
+        return ptx.math.log(value / (1 - value))
+
+    def log_jac_det(self, value, *inputs):
+        sigmoid_value = ptx.math.sigmoid(value)
+        return ptx.math.log(sigmoid_value) + ptx.math.log1p(-sigmoid_value)
+
+
+log_odds_transform = LogOddsTransform()
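
As a quick numerical cross-check of the transform algebra (not part of the commit), plain numpy stand-ins for the ptx.math calls confirm that backward inverts forward and that the log_jac_det expression matches the derivative of backward. For LogTransform the same logic is immediate: backward is exp, so log|d exp(y)/dy| = y, which is why its log_jac_det returns value unchanged.

import numpy as np

def expit(y):
    return 1 / (1 + np.exp(-y))

p = 0.3
y = np.log(p / (1 - p))         # LogOddsTransform.forward
assert np.isclose(expit(y), p)  # backward inverts forward

# log_jac_det: log|d expit(y)/dy| = log(s) + log1p(-s) with s = expit(y)
s = expit(y)
analytic = np.log(s) + np.log1p(-s)
numeric = np.log((expit(y + 1e-6) - expit(y - 1e-6)) / 2e-6)
assert np.isclose(analytic, numeric)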

pymc/initial_point.py
Lines changed: 15 additions & 3 deletions
@@ -20,7 +20,8 @@
 import pytensor
 import pytensor.tensor as pt
 
-from pytensor.graph.basic import Constant, Variable
+from pytensor.compile.ops import TypeCastingOp
+from pytensor.graph.basic import Apply, Constant, Variable
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.rewriting.db import RewriteDatabaseQuery, SequenceDB
 from pytensor.tensor.variable import TensorVariable
@@ -195,6 +196,14 @@ def inner(seed, *args, **kwargs):
    return make_seeded_function(func)
 
 
+class InitialPoint(TypeCastingOp):
+    def make_node(self, var):
+        return Apply(self, [var], [var.type()])
+
+
+initial_point_op = InitialPoint()
+
+
 def make_initial_point_expression(
     *,
     free_rvs: Sequence[TensorVariable],
@@ -235,6 +244,9 @@ def make_initial_point_expression(
 
    # Clone free_rvs so we don't modify the original graph
    initial_point_fgraph = FunctionGraph(outputs=free_rvs, clone=True)
+    # Wrap each rv in an initial_point Operation to avoid losing dependency between the RVs
+    replacements = tuple((rv, initial_point_op(rv)) for rv in initial_point_fgraph.outputs)
+    toposort_replace(initial_point_fgraph, replacements, reverse=True)
 
    # Apply any rewrites necessary to compute the initial points.
    initial_point_rewriter = initial_point_rewrites_db.query(initial_point_basic_query)
@@ -254,10 +266,10 @@
        if isinstance(strategy, str):
            if strategy == "support_point":
                try:
-                    value = support_point(variable)
+                    value = support_point(variable.owner.inputs[0])
                except NotImplementedError:
                    warnings.warn(
-                        f"Moment not defined for variable {variable} of type "
+                        f"support_point not defined for variable {variable} of type "
                        f"{variable.owner.op.__class__.__name__}, defaulting to "
                        f"a draw from the prior. This can lead to difficulties "
                        f"during tuning. You can manually define an initval or "
