1414from collections .abc import Callable , Sequence
1515from itertools import chain
1616
17+ import numpy as np
18+
1719from pytensor .graph import node_rewriter
1820from pytensor .tensor .elemwise import DimShuffle
21+ from pytensor .tensor .random .op import RandomVariable
1922from pytensor .xtensor import as_xtensor
2023from pytensor .xtensor .basic import XTensorFromTensor , xtensor_from_tensor
2124from pytensor .xtensor .type import XTensorVariable
2225
23- from pymc import modelcontext
26+ from pymc import SymbolicRandomVariable , modelcontext
2427from pymc .dims .model import with_dims
25- from pymc .dims .transforms import log_odds_transform , log_transform
28+ from pymc .dims .transforms import DimTransform , log_odds_transform , log_transform
2629from pymc .distributions .distribution import _support_point , support_point
2730from pymc .distributions .shape_utils import DimsWithEllipsis , convert_dims
2831from pymc .logprob .abstract import MeasurableOp , _logprob
2932from pymc .logprob .rewriting import measurable_ir_rewrites_db
33+ from pymc .logprob .tensor import MeasurableDimShuffle
3034from pymc .logprob .utils import filter_measurable_variables
3135from pymc .util import UNSET
3236
@@ -46,24 +50,67 @@ def xtensor_from_tensor_support_point(xtensor_op, _, rv):
4650
4751
4852class MeasurableXTensorFromTensor (MeasurableOp , XTensorFromTensor ):
49- pass
53+ __props__ = ("dims" , "core_dims" )
54+
55+ def __init__ (self , dims , core_dims ):
56+ super ().__init__ (dims = dims )
57+ self .core_dims = tuple (core_dims ) if core_dims is not None else None
5058
5159
5260@node_rewriter ([XTensorFromTensor ])
5361def find_measurable_xtensor_from_tensor (fgraph , node ) -> list [XTensorVariable ] | None :
5462 if isinstance (node .op , MeasurableXTensorFromTensor ):
5563 return None
5664
57- if not filter_measurable_variables (node .inputs ):
58- return None
65+ xs = filter_measurable_variables (node .inputs )
66+
67+ if not xs :
68+ # Check if we have a transposition instead
69+ # The rewrite that introduces measurable tranpsoses refuses to apply to multivariate RVs
70+ # So we have a chance of inferring the core dims!
71+ [ds ] = node .inputs
72+ ds_node = ds .owner
73+ if not (
74+ ds_node is not None
75+ and isinstance (ds_node .op , DimShuffle )
76+ and ds_node .op .is_transpose
77+ and filter_measurable_variables (ds_node .inputs )
78+ ):
79+ return None
80+ [x ] = ds_node .inputs
81+ if not (
82+ x .owner is not None and isinstance (x .owner .op , RandomVariable | SymbolicRandomVariable )
83+ ):
84+ return None
85+
86+ measurable_x = MeasurableDimShuffle (** ds_node .op ._props_dict ())(x )
87+
88+ ndim_supp = x .owner .op .ndim_supp
89+ if ndim_supp :
90+ inverse_transpose = np .argsort (ds_node .op .shuffle )
91+ dims = node .op .dims
92+ dims_before_transpose = [dims [i ] for i in inverse_transpose ]
93+ core_dims = dims_before_transpose [- ndim_supp :]
94+ else :
95+ core_dims = ()
5996
60- return [MeasurableXTensorFromTensor (dims = node .op .dims )(* node .inputs )]
97+ return [MeasurableXTensorFromTensor (dims = node .op .dims , core_dims = core_dims )(measurable_x )]
98+ else :
99+ # If this happens we know there's no measurable transpose in between and we can
100+ # safely infer the core_dims positionally when the inner logp is returned
101+ return [MeasurableXTensorFromTensor (dims = node .op .dims , core_dims = None )(* node .inputs )]
61102
62103
63104@_logprob .register (MeasurableXTensorFromTensor )
64105def measurable_xtensor_from_tensor (op , values , rv , ** kwargs ):
65106 rv_logp = _logprob (rv .owner .op , tuple (v .values for v in values ), * rv .owner .inputs , ** kwargs )
66- return xtensor_from_tensor (rv_logp , dims = op .dims )
107+ if op .core_dims is None :
108+ # The core_dims of the inner rv are on the right
109+ dims = op .dims [: rv_logp .ndim ]
110+ else :
111+ # We inferred where the core_dims are!
112+ dims = [d for d in op .dims if d not in op .core_dims ]
113+ return xtensor_from_tensor (rv_logp , dims = dims )
67114
68115
69116measurable_ir_rewrites_db .register (
@@ -75,7 +122,7 @@ class DimDistribution:
75122 """Base class for PyMC distribution that wrap pytensor.xtensor.random operations, and follow xarray-like semantics."""
76123
77124 xrv_op : Callable
78- default_transform : Callable | None = None
125+ default_transform : DimTransform | None = None
79126
80127 @staticmethod
81128 def _as_xtensor (x ):
@@ -156,6 +203,18 @@ def __new__(
156203 # TODO: If this fails give a more informative error message
157204 observed = observed .transpose (* rv_dims )
158205
206+ # Check user didn't pass regular transforms
207+ if transform not in (UNSET , None ):
208+ if not isinstance (transform , DimTransform ):
209+ raise TypeError (
210+ f"Transform must be a DimTransform, form pymc.dims.transforms, but got { type (transform )} ."
211+ )
212+ if default_transform not in (UNSET , None ):
213+ if not isinstance (default_transform , DimTransform ):
214+ raise TypeError (
215+ f"default_transform must be a DimTransform, from pymc.dims.transforms, but got { type (default_transform )} ."
216+ )
217+
159218 rv = model .register_rv (
160219 rv ,
161220 name = name ,
@@ -188,13 +247,16 @@ def dist(
188247 if dims_dict is None :
189248 extra_dims = None
190249 else :
191- parameter_implied_dims = set (
192- chain .from_iterable (param .type .dims for param in dist_params )
193- )
250+ # Exclude dims that are implied by the parameters or core_dims
251+ implied_dims = set (chain .from_iterable (param .type .dims for param in dist_params ))
252+ if core_dims is not None :
253+ if isinstance (core_dims , str ):
254+ implied_dims .add (core_dims )
255+ else :
256+ implied_dims .update (core_dims )
257+
194258 extra_dims = {
195- dim : length
196- for dim , length in dims_dict .items ()
197- if dim not in parameter_implied_dims
259+ dim : length for dim , length in dims_dict .items () if dim not in implied_dims
198260 }
199261 return cls .xrv_op (* dist_params , extra_dims = extra_dims , core_dims = core_dims , ** kwargs )
200262
0 commit comments