13
13
# limitations under the License.
14
14
from collections .abc import Callable , Sequence
15
15
from itertools import chain
16
+ from typing import cast
16
17
18
+ import numpy as np
19
+
20
+ from pytensor .graph import node_rewriter
17
21
from pytensor .graph .basic import Variable
18
22
from pytensor .tensor .elemwise import DimShuffle
23
+ from pytensor .tensor .random .op import RandomVariable
19
24
from pytensor .xtensor import as_xtensor
25
+ from pytensor .xtensor .basic import XTensorFromTensor , xtensor_from_tensor
20
26
from pytensor .xtensor .type import XTensorVariable
21
27
22
- from pymc import modelcontext
23
- from pymc .dims .model import with_dims
24
- from pymc .distributions import transforms
28
+ from pymc import SymbolicRandomVariable , modelcontext
29
+ from pymc .dims .distributions .transforms import DimTransform , log_odds_transform , log_transform
25
30
from pymc .distributions .distribution import _support_point , support_point
26
31
from pymc .distributions .shape_utils import DimsWithEllipsis , convert_dims_with_ellipsis
27
- from pymc .logprob .transforms import Transform
32
+ from pymc .logprob .abstract import MeasurableOp , _logprob
33
+ from pymc .logprob .rewriting import measurable_ir_rewrites_db
34
+ from pymc .logprob .tensor import MeasurableDimShuffle
35
+ from pymc .logprob .utils import filter_measurable_variables
28
36
from pymc .util import UNSET
29
37
30
38
@@ -36,25 +44,98 @@ def dimshuffle_support_point(ds_op, _, rv):
36
44
return ds_op (support_point (rv ))
37
45
38
46
47
+ @_support_point .register (XTensorFromTensor )
48
+ def xtensor_from_tensor_support_point (xtensor_op , _ , rv ):
49
+ # We remove the xtensor_from_tensor operation, so initial_point doesn't have to do a further lowering
50
+ return xtensor_op (support_point (rv ))
51
+
52
+
53
+ class MeasurableXTensorFromTensor (MeasurableOp , XTensorFromTensor ):
54
+ __props__ = ("dims" , "core_dims" ) # type: ignore[assignment]
55
+
56
+ def __init__ (self , dims , core_dims ):
57
+ super ().__init__ (dims = dims )
58
+ self .core_dims = tuple (core_dims ) if core_dims is not None else None
59
+
60
+
61
+ @node_rewriter ([XTensorFromTensor ])
62
+ def find_measurable_xtensor_from_tensor (fgraph , node ) -> list [XTensorVariable ] | None :
63
+ if isinstance (node .op , MeasurableXTensorFromTensor ):
64
+ return None
65
+
66
+ xs = filter_measurable_variables (node .inputs )
67
+
68
+ if not xs :
69
+ # Check if we have a transposition instead
70
+ # The rewrite that introduces measurable tranpsoses refuses to apply to multivariate RVs
71
+ # So we have a chance of inferring the core dims!
72
+ [ds ] = node .inputs
73
+ ds_node = ds .owner
74
+ if not (
75
+ ds_node is not None
76
+ and isinstance (ds_node .op , DimShuffle )
77
+ and ds_node .op .is_transpose
78
+ and filter_measurable_variables (ds_node .inputs )
79
+ ):
80
+ return None
81
+ [x ] = ds_node .inputs
82
+ if not (
83
+ x .owner is not None and isinstance (x .owner .op , RandomVariable | SymbolicRandomVariable )
84
+ ):
85
+ return None
86
+
87
+ measurable_x = MeasurableDimShuffle (** ds_node .op ._props_dict ())(x ) # type: ignore[attr-defined]
88
+
89
+ ndim_supp = x .owner .op .ndim_supp
90
+ if ndim_supp :
91
+ inverse_transpose = np .argsort (ds_node .op .shuffle )
92
+ dims = node .op .dims
93
+ dims_before_transpose = tuple (dims [i ] for i in inverse_transpose )
94
+ core_dims = dims_before_transpose [- ndim_supp :]
95
+ else :
96
+ core_dims = ()
97
+
98
+ new_out = MeasurableXTensorFromTensor (dims = node .op .dims , core_dims = core_dims )(measurable_x )
99
+ else :
100
+ # If this happens we know there's no measurable transpose in between and we can
101
+ # safely infer the core_dims positionally when the inner logp is returned
102
+ new_out = MeasurableXTensorFromTensor (dims = node .op .dims , core_dims = None )(* node .inputs )
103
+ return [cast (XTensorVariable , new_out )]
104
+
105
+
106
+ @_logprob .register (MeasurableXTensorFromTensor )
107
+ def measurable_xtensor_from_tensor (op , values , rv , ** kwargs ):
108
+ rv_logp = _logprob (rv .owner .op , tuple (v .values for v in values ), * rv .owner .inputs , ** kwargs )
109
+ if op .core_dims is None :
110
+ # The core_dims of the inner rv are on the right
111
+ dims = op .dims [: rv_logp .ndim ]
112
+ else :
113
+ # We inferred where the core_dims are!
114
+ dims = [d for d in op .dims if d not in op .core_dims ]
115
+ return xtensor_from_tensor (rv_logp , dims = dims )
116
+
117
+
118
+ measurable_ir_rewrites_db .register (
119
+ "measurable_xtensor_from_tensor" , find_measurable_xtensor_from_tensor , "basic" , "xtensor"
120
+ )
121
+
122
+
39
123
class DimDistribution :
40
124
"""Base class for PyMC distribution that wrap pytensor.xtensor.random operations, and follow xarray-like semantics."""
41
125
42
126
xrv_op : Callable
43
- default_transform : Transform | None = None
127
+ default_transform : DimTransform | None = None
44
128
45
129
@staticmethod
46
130
def _as_xtensor (x ):
47
131
try :
48
132
return as_xtensor (x )
49
133
except TypeError :
50
- try :
51
- return with_dims (x )
52
- except ValueError :
53
- raise ValueError (
54
- f"Variable { x } must have dims associated with it.\n "
55
- "To avoid subtle bugs, PyMC does not make any assumptions about the dims of parameters.\n "
56
- "Use `as_xtensor` with the `dims` keyword argument to specify the dims explicitly."
57
- )
134
+ raise ValueError (
135
+ f"Variable { x } must have dims associated with it.\n "
136
+ "To avoid subtle bugs, PyMC does not make any assumptions about the dims of parameters.\n "
137
+ "Use `as_xtensor` with the `dims` keyword argument to specify the dims explicitly."
138
+ )
58
139
59
140
def __new__ (
60
141
cls ,
@@ -119,10 +200,22 @@ def __new__(
119
200
else :
120
201
# Align observed dims with those of the RV
121
202
# TODO: If this fails give a more informative error message
122
- observed = observed .transpose (* rv_dims ).values
203
+ observed = observed .transpose (* rv_dims )
204
+
205
+ # Check user didn't pass regular transforms
206
+ if transform not in (UNSET , None ):
207
+ if not isinstance (transform , DimTransform ):
208
+ raise TypeError (
209
+ f"Transform must be a DimTransform, form pymc.dims.transforms, but got { type (transform )} ."
210
+ )
211
+ if default_transform not in (UNSET , None ):
212
+ if not isinstance (default_transform , DimTransform ):
213
+ raise TypeError (
214
+ f"default_transform must be a DimTransform, from pymc.dims.transforms, but got { type (default_transform )} ."
215
+ )
123
216
124
217
rv = model .register_rv (
125
- rv . values ,
218
+ rv ,
126
219
name = name ,
127
220
observed = observed ,
128
221
total_size = total_size ,
@@ -182,10 +275,10 @@ def dist(self, *args, core_dims: str | Sequence[str] | None = None, **kwargs):
182
275
class PositiveDimDistribution (DimDistribution ):
183
276
"""Base class for positive continuous distributions."""
184
277
185
- default_transform = transforms . log
278
+ default_transform = log_transform
186
279
187
280
188
281
class UnitDimDistribution (DimDistribution ):
189
282
"""Base class for unit-valued distributions."""
190
283
191
- default_transform = transforms . logodds
284
+ default_transform = log_odds_transform
0 commit comments