1414from collections .abc import Callable , Sequence
1515from itertools import chain
1616
17+ from pytensor .graph import node_rewriter
1718from pytensor .tensor .elemwise import DimShuffle
1819from pytensor .xtensor import as_xtensor
20+ from pytensor .xtensor .basic import XTensorFromTensor , xtensor_from_tensor
1921from pytensor .xtensor .type import XTensorVariable
2022
2123from pymc import modelcontext
2224from pymc .dims .model import with_dims
23- from pymc .distributions import transforms
25+ from pymc .dims . transforms import log_odds_transform , log_transform
2426from pymc .distributions .distribution import _support_point , support_point
2527from pymc .distributions .shape_utils import DimsWithEllipsis , convert_dims
28+ from pymc .logprob .abstract import MeasurableOp , _logprob
29+ from pymc .logprob .rewriting import measurable_ir_rewrites_db
30+ from pymc .logprob .utils import filter_measurable_variables
2631from pymc .util import UNSET
2732
2833
@@ -34,6 +39,38 @@ def dimshuffle_support_point(ds_op, _, rv):
3439 return ds_op (support_point (rv ))
3540
3641
42+ @_support_point .register (XTensorFromTensor )
43+ def xtensor_from_tensor_support_point (xtensor_op , _ , rv ):
44+ # We remove the xtensor_from_tensor operation, so initial_point doesn't have to do a further lowering
45+ return xtensor_op (support_point (rv ))
46+
47+
48+ class MeasurableXTensorFromTensor (MeasurableOp , XTensorFromTensor ):
49+ pass
50+
51+
52+ @node_rewriter ([XTensorFromTensor ])
53+ def find_measurable_xtensor_from_tensor (fgraph , node ) -> list [XTensorVariable ] | None :
54+ if isinstance (node .op , MeasurableXTensorFromTensor ):
55+ return None
56+
57+ if not filter_measurable_variables (node .inputs ):
58+ return None
59+
60+ return [MeasurableXTensorFromTensor (dims = node .op .dims )(* node .inputs )]
61+
62+
63+ @_logprob .register (MeasurableXTensorFromTensor )
64+ def measurable_xtensor_from_tensor (op , values , rv , ** kwargs ):
65+ rv_logp = _logprob (rv .owner .op , tuple (v .values for v in values ), * rv .owner .inputs , ** kwargs )
66+ return xtensor_from_tensor (rv_logp , dims = op .dims )
67+
68+
69+ measurable_ir_rewrites_db .register (
70+ "measurable_xtensor_from_tensor" , find_measurable_xtensor_from_tensor , "basic" , "xtensor"
71+ )
72+
73+
3774class DimDistribution :
3875 """Base class for PyMC distribution that wrap pytensor.xtensor.random operations, and follow xarray-like semantics."""
3976
@@ -117,10 +154,10 @@ def __new__(
117154 else :
118155 # Align observed dims with those of the RV
119156 # TODO: If this fails give a more informative error message
120- observed = observed .transpose (* rv_dims ). values
157+ observed = observed .transpose (* rv_dims )
121158
122159 rv = model .register_rv (
123- rv . values ,
160+ rv ,
124161 name = name ,
125162 observed = observed ,
126163 total_size = total_size ,
@@ -177,10 +214,10 @@ def dist(self, *args, core_dims: str | Sequence[str] | None = None, **kwargs):
177214class PositiveDimDistribution (DimDistribution ):
178215 """Base class for positive continuous distributions."""
179216
180- default_transform = transforms . log
217+ default_transform = log_transform
181218
182219
183220class UnitDimDistribution (DimDistribution ):
184221 """Base class for unit-valued distributions."""
185222
186- default_transform = transforms . logodds
223+ default_transform = log_odds_transform
0 commit comments