 from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import Elemwise
 from pytensor.tensor.random.utils import compute_batch_shape
-from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
-from pytensor.xtensor.rewriting.utils import register_lower_xtensor
+from pytensor.xtensor.basic import xtensor_from_tensor
+from pytensor.xtensor.rewriting.utils import lower_aligned, register_lower_xtensor
 from pytensor.xtensor.vectorization import XRV, XBlockwise, XElemwise
@@ -13,15 +13,7 @@ def lower_elemwise(fgraph, node):
     out_dims = node.outputs[0].type.dims

     # Convert input XTensors to Tensors and align batch dimensions
-    tensor_inputs = []
-    for inp in node.inputs:
-        inp_dims = inp.type.dims
-        order = [
-            inp_dims.index(out_dim) if out_dim in inp_dims else "x"
-            for out_dim in out_dims
-        ]
-        tensor_inp = tensor_from_xtensor(inp).dimshuffle(order)
-        tensor_inputs.append(tensor_inp)
+    tensor_inputs = [lower_aligned(inp, out_dims) for inp in node.inputs]

     tensor_outs = Elemwise(scalar_op=node.op.scalar_op)(
         *tensor_inputs, return_list=True
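Note: `lower_aligned` itself is not shown in this diff. Below is a minimal sketch of its presumed behavior, reconstructed from the loop it replaces; the real helper lives in `pytensor.xtensor.rewriting.utils` and may differ in details.

```python
from pytensor.xtensor.basic import tensor_from_xtensor


def lower_aligned(x, out_dims):
    """Convert an XTensor variable to a plain tensor whose axes follow out_dims.

    Sketch inferred from the removed code: dims of x present in out_dims are
    transposed into place; dims missing from x are inserted as broadcastable
    ("x") axes via dimshuffle.
    """
    inp_dims = x.type.dims
    order = [
        inp_dims.index(out_dim) if out_dim in inp_dims else "x"
        for out_dim in out_dims
    ]
    return tensor_from_xtensor(x).dimshuffle(order)
```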
@@ -42,17 +34,10 @@ def lower_blockwise(fgraph, node):
     batch_dims = node.outputs[0].type.dims[:batch_ndim]

     # Convert input XTensors to Tensors, align batch dimensions and place core dimensions at the end
-    tensor_inputs = []
-    for inp, core_dims in zip(node.inputs, op.core_dims[0]):
-        inp_dims = inp.type.dims
-        # Align the batch dims of the input, and place the core dims on the right
-        batch_order = [
-            inp_dims.index(batch_dim) if batch_dim in inp_dims else "x"
-            for batch_dim in batch_dims
-        ]
-        core_order = [inp_dims.index(core_dim) for core_dim in core_dims]
-        tensor_inp = tensor_from_xtensor(inp).dimshuffle(batch_order + core_order)
-        tensor_inputs.append(tensor_inp)
+    tensor_inputs = [
+        lower_aligned(inp, batch_dims + core_dims)
+        for inp, core_dims in zip(node.inputs, op.core_dims[0], strict=True)
+    ]

     signature = op.signature or getattr(op.core_op, "gufunc_signature", None)
     if signature is None:
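For intuition, with hypothetical dim names (not taken from this diff): if `inp.type.dims == ("time", "feature")`, `batch_dims == ("batch", "time")`, and `core_dims == ("feature",)`, the alignment yields the dimshuffle order `["x", 0, 1]`: a broadcastable axis is inserted for the missing "batch" dim, "time" stays in batch position, and the core "feature" axis lands on the right. The `strict=True` in `zip` additionally guards against a mismatch between the number of inputs and the number of core-dim specs.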
@@ -92,17 +77,10 @@ def lower_rv(fgraph, node):
     param_batch_dims = old_out.type.dims[len(op.extra_dims) : batch_ndim]

     # Convert param XTensors to Tensors, align batch dimensions and place core dimensions at the end
-    tensor_params = []
-    for inp, core_dims in zip(params, op.core_dims[0]):
-        inp_dims = inp.type.dims
-        # Align the batch dims of the input, and place the core dims on the right
-        batch_order = [
-            inp_dims.index(batch_dim) if batch_dim in inp_dims else "x"
-            for batch_dim in param_batch_dims
-        ]
-        core_order = [inp_dims.index(core_dim) for core_dim in core_dims]
-        tensor_inp = tensor_from_xtensor(inp).dimshuffle(batch_order + core_order)
-        tensor_params.append(tensor_inp)
+    tensor_params = [
+        lower_aligned(inp, param_batch_dims + core_dims)
+        for inp, core_dims in zip(params, op.core_dims[0], strict=True)
+    ]

     size = None
     if op.extra_dims: