1414from collections .abc import Callable , Sequence
1515from itertools import chain
1616
17+ import numpy as np
18+
19+ from pytensor .graph import node_rewriter
1720from pytensor .tensor .elemwise import DimShuffle
21+ from pytensor .tensor .random .op import RandomVariable
1822from pytensor .xtensor import as_xtensor
23+ from pytensor .xtensor .basic import XTensorFromTensor , xtensor_from_tensor
1924from pytensor .xtensor .type import XTensorVariable
2025
21- from pymc import modelcontext
22- from pymc .dims .model import with_dims
23- from pymc .distributions import transforms
26+ from pymc import SymbolicRandomVariable , modelcontext
27+ from pymc .dims .transforms import DimTransform , log_odds_transform , log_transform
2428from pymc .distributions .distribution import _support_point , support_point
2529from pymc .distributions .shape_utils import DimsWithEllipsis , convert_dims
30+ from pymc .logprob .abstract import MeasurableOp , _logprob
31+ from pymc .logprob .rewriting import measurable_ir_rewrites_db
32+ from pymc .logprob .tensor import MeasurableDimShuffle
33+ from pymc .logprob .utils import filter_measurable_variables
2634from pymc .util import UNSET
2735
2836
@@ -34,25 +42,97 @@ def dimshuffle_support_point(ds_op, _, rv):
3442 return ds_op (support_point (rv ))
3543
3644
45+ @_support_point .register (XTensorFromTensor )
46+ def xtensor_from_tensor_support_point (xtensor_op , _ , rv ):
47+ # We remove the xtensor_from_tensor operation, so initial_point doesn't have to do a further lowering
48+ return xtensor_op (support_point (rv ))
49+
50+
51+ class MeasurableXTensorFromTensor (MeasurableOp , XTensorFromTensor ):
52+ __props__ = ("dims" , "core_dims" )
53+
54+ def __init__ (self , dims , core_dims ):
55+ super ().__init__ (dims = dims )
56+ self .core_dims = tuple (core_dims ) if core_dims is not None else None
57+
58+
59+ @node_rewriter ([XTensorFromTensor ])
60+ def find_measurable_xtensor_from_tensor (fgraph , node ) -> list [XTensorVariable ] | None :
61+ if isinstance (node .op , MeasurableXTensorFromTensor ):
62+ return None
63+
64+ xs = filter_measurable_variables (node .inputs )
65+
66+ if not xs :
67+ # Check if we have a transposition instead
68+ # The rewrite that introduces measurable tranpsoses refuses to apply to multivariate RVs
69+ # So we have a chance of inferring the core dims!
70+ [ds ] = node .inputs
71+ ds_node = ds .owner
72+ if not (
73+ ds_node is not None
74+ and isinstance (ds_node .op , DimShuffle )
75+ and ds_node .op .is_transpose
76+ and filter_measurable_variables (ds_node .inputs )
77+ ):
78+ return None
79+ [x ] = ds_node .inputs
80+ if not (
81+ x .owner is not None and isinstance (x .owner .op , RandomVariable | SymbolicRandomVariable )
82+ ):
83+ return None
84+
85+ measurable_x = MeasurableDimShuffle (** ds_node .op ._props_dict ())(x )
86+
87+ ndim_supp = x .owner .op .ndim_supp
88+ if ndim_supp :
89+ inverse_transpose = np .argsort (ds_node .op .shuffle )
90+ dims = node .op .dims
91+ dims_before_transpose = [dims [i ] for i in inverse_transpose ]
92+ core_dims = dims_before_transpose [- ndim_supp :]
93+ else :
94+ core_dims = ()
95+
96+ return [MeasurableXTensorFromTensor (dims = node .op .dims , core_dims = core_dims )(measurable_x )]
97+ else :
98+ # If this happens we know there's no measurable transpose in between and we can
99+ # safely infer the core_dims positionally when the inner logp is returned
100+ return [MeasurableXTensorFromTensor (dims = node .op .dims , core_dims = None )(* node .inputs )]
101+
102+
103+ @_logprob .register (MeasurableXTensorFromTensor )
104+ def measurable_xtensor_from_tensor (op , values , rv , ** kwargs ):
105+ rv_logp = _logprob (rv .owner .op , tuple (v .values for v in values ), * rv .owner .inputs , ** kwargs )
106+ if op .core_dims is None :
107+ # The core_dims of the inner rv are on the right
108+ dims = op .dims [: rv_logp .ndim ]
109+ else :
110+ # We inferred where the core_dims are!
111+ dims = [d for d in op .dims if d not in op .core_dims ]
112+ return xtensor_from_tensor (rv_logp , dims = dims )
113+
114+
115+ measurable_ir_rewrites_db .register (
116+ "measurable_xtensor_from_tensor" , find_measurable_xtensor_from_tensor , "basic" , "xtensor"
117+ )
118+
119+
37120class DimDistribution :
38121 """Base class for PyMC distribution that wrap pytensor.xtensor.random operations, and follow xarray-like semantics."""
39122
40123 xrv_op : Callable
41- default_transform : Callable | None = None
124+ default_transform : DimTransform | None = None
42125
43126 @staticmethod
44127 def _as_xtensor (x ):
45128 try :
46129 return as_xtensor (x )
47130 except TypeError :
48- try :
49- return with_dims (x )
50- except ValueError :
51- raise ValueError (
52- f"Variable { x } must have dims associated with it.\n "
53- "To avoid subtle bugs, PyMC does not make any assumptions about the dims of parameters.\n "
54- "Use `as_xtensor` with the `dims` keyword argument to specify the dims explicitly."
55- )
131+ raise ValueError (
132+ f"Variable { x } must have dims associated with it.\n "
133+ "To avoid subtle bugs, PyMC does not make any assumptions about the dims of parameters.\n "
134+ "Use `as_xtensor` with the `dims` keyword argument to specify the dims explicitly."
135+ )
56136
57137 def __new__ (
58138 cls ,
@@ -117,10 +197,22 @@ def __new__(
117197 else :
118198 # Align observed dims with those of the RV
119199 # TODO: If this fails give a more informative error message
120- observed = observed .transpose (* rv_dims ).values
200+ observed = observed .transpose (* rv_dims )
201+
202+ # Check user didn't pass regular transforms
203+ if transform not in (UNSET , None ):
204+ if not isinstance (transform , DimTransform ):
205+ raise TypeError (
206+ f"Transform must be a DimTransform, form pymc.dims.transforms, but got { type (transform )} ."
207+ )
208+ if default_transform not in (UNSET , None ):
209+ if not isinstance (default_transform , DimTransform ):
210+ raise TypeError (
211+ f"default_transform must be a DimTransform, from pymc.dims.transforms, but got { type (default_transform )} ."
212+ )
121213
122214 rv = model .register_rv (
123- rv . values ,
215+ rv ,
124216 name = name ,
125217 observed = observed ,
126218 total_size = total_size ,
@@ -151,13 +243,16 @@ def dist(
151243 if dims_dict is None :
152244 extra_dims = None
153245 else :
154- parameter_implied_dims = set (
155- chain .from_iterable (param .type .dims for param in dist_params )
156- )
246+ # Exclude dims that are implied by the parameters or core_dims
247+ implied_dims = set (chain .from_iterable (param .type .dims for param in dist_params ))
248+ if core_dims is not None :
249+ if isinstance (core_dims , str ):
250+ implied_dims .add (core_dims )
251+ else :
252+ implied_dims .update (core_dims )
253+
157254 extra_dims = {
158- dim : length
159- for dim , length in dims_dict .items ()
160- if dim not in parameter_implied_dims
255+ dim : length for dim , length in dims_dict .items () if dim not in implied_dims
161256 }
162257 return cls .xrv_op (* dist_params , extra_dims = extra_dims , core_dims = core_dims , ** kwargs )
163258
@@ -177,10 +272,10 @@ def dist(self, *args, core_dims: str | Sequence[str] | None = None, **kwargs):
177272class PositiveDimDistribution (DimDistribution ):
178273 """Base class for positive continuous distributions."""
179274
180- default_transform = transforms . log
275+ default_transform = log_transform
181276
182277
183278class UnitDimDistribution (DimDistribution ):
184279 """Base class for unit-valued distributions."""
185280
186- default_transform = transforms . logodds
281+ default_transform = log_odds_transform
0 commit comments