@@ -2,8 +2,8 @@
 from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import Elemwise
 from pytensor.tensor.random.utils import compute_batch_shape
-from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
-from pytensor.xtensor.rewriting.utils import register_lower_xtensor
+from pytensor.xtensor.basic import xtensor_from_tensor
+from pytensor.xtensor.rewriting.utils import lower_aligned, register_lower_xtensor
 from pytensor.xtensor.vectorization import XRV, XBlockwise, XElemwise
 
 
@@ -13,15 +13,7 @@ def lower_elemwise(fgraph, node):
     out_dims = node.outputs[0].type.dims
 
     # Convert input XTensors to Tensors and align batch dimensions
-    tensor_inputs = []
-    for inp in node.inputs:
-        inp_dims = inp.type.dims
-        order = [
-            inp_dims.index(out_dim) if out_dim in inp_dims else "x"
-            for out_dim in out_dims
-        ]
-        tensor_inp = tensor_from_xtensor(inp).dimshuffle(order)
-        tensor_inputs.append(tensor_inp)
+    tensor_inputs = [lower_aligned(inp, out_dims) for inp in node.inputs]
 
     tensor_outs = Elemwise(scalar_op=node.op.scalar_op)(
         *tensor_inputs, return_list=True
@@ -42,17 +34,10 @@ def lower_blockwise(fgraph, node):
     batch_dims = node.outputs[0].type.dims[:batch_ndim]
 
     # Convert input XTensors to Tensors, align batch dimensions and place core dimensions at the end
-    tensor_inputs = []
-    for inp, core_dims in zip(node.inputs, op.core_dims[0]):
-        inp_dims = inp.type.dims
-        # Align the batch dims of the input, and place the core dims on the right
-        batch_order = [
-            inp_dims.index(batch_dim) if batch_dim in inp_dims else "x"
-            for batch_dim in batch_dims
-        ]
-        core_order = [inp_dims.index(core_dim) for core_dim in core_dims]
-        tensor_inp = tensor_from_xtensor(inp).dimshuffle(batch_order + core_order)
-        tensor_inputs.append(tensor_inp)
+    tensor_inputs = [
+        lower_aligned(inp, batch_dims + core_dims)
+        for inp, core_dims in zip(node.inputs, op.core_dims[0], strict=True)
+    ]
 
     signature = op.signature or getattr(op.core_op, "gufunc_signature", None)
     if signature is None:
@@ -92,17 +77,10 @@ def lower_rv(fgraph, node):
     param_batch_dims = old_out.type.dims[len(op.extra_dims) : batch_ndim]
 
     # Convert params XTensors to Tensors, align batch dimensions and place core dimensions at the end
-    tensor_params = []
-    for inp, core_dims in zip(params, op.core_dims[0]):
-        inp_dims = inp.type.dims
-        # Align the batch dims of the input, and place the core dims on the right
-        batch_order = [
-            inp_dims.index(batch_dim) if batch_dim in inp_dims else "x"
-            for batch_dim in param_batch_dims
-        ]
-        core_order = [inp_dims.index(core_dim) for core_dim in core_dims]
-        tensor_inp = tensor_from_xtensor(inp).dimshuffle(batch_order + core_order)
-        tensor_params.append(tensor_inp)
+    tensor_params = [
+        lower_aligned(inp, param_batch_dims + core_dims)
+        for inp, core_dims in zip(params, op.core_dims[0], strict=True)
+    ]
 
     size = None
     if op.extra_dims:
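
All three rewrites now delegate the same alignment boilerplate to one helper, and `strict=True` in `zip` turns a silent length mismatch between the inputs and `op.core_dims[0]` into an error. The body of `lower_aligned` is not part of this diff; judging purely from the loops it replaces, it presumably factors out the dimshuffle logic roughly as in this sketch (the signature and docstring are assumptions, not the actual `pytensor.xtensor.rewriting.utils` source):

# Hypothetical reconstruction of lower_aligned from the removed loops above;
# the real helper may differ in details.
from pytensor.xtensor.basic import tensor_from_xtensor


def lower_aligned(inp, out_dims):
    """Lower an XTensor to a Tensor whose axes are ordered as `out_dims`.

    Dims the input lacks are inserted as broadcastable ("x") axes, matching
    the inlined dimshuffle logic this diff removes.
    """
    inp_dims = inp.type.dims
    order = [
        inp_dims.index(out_dim) if out_dim in inp_dims else "x"
        for out_dim in out_dims
    ]
    return tensor_from_xtensor(inp).dimshuffle(order)

Because the elemwise case passes `out_dims` while the blockwise and RV cases pass `batch_dims + core_dims`, the one helper covers both "align to output dims" and "align batch dims, core dims last".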