Commit e1ce1c3

Refactor lower_aligned helper

1 parent 45a33ad · commit e1ce1c3

File tree

pytensor/xtensor/rewriting/shape.py
pytensor/xtensor/rewriting/utils.py
pytensor/xtensor/rewriting/vectorization.py

3 files changed: +25 −42 lines changed

pytensor/xtensor/rewriting/shape.py

Lines changed: 2 additions & 9 deletions

@@ -9,6 +9,7 @@
 )
 from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
 from pytensor.xtensor.rewriting.basic import register_lower_xtensor
+from pytensor.xtensor.rewriting.utils import lower_aligned
 from pytensor.xtensor.shape import (
     Concat,
     ExpandDims,
@@ -70,15 +71,7 @@ def lower_concat(fgraph, node):
     concat_axis = out_dims.index(concat_dim)

     # Convert input XTensors to Tensors and align batch dimensions
-    tensor_inputs = []
-    for inp in node.inputs:
-        inp_dims = inp.type.dims
-        order = [
-            inp_dims.index(out_dim) if out_dim in inp_dims else "x"
-            for out_dim in out_dims
-        ]
-        tensor_inp = tensor_from_xtensor(inp).dimshuffle(order)
-        tensor_inputs.append(tensor_inp)
+    tensor_inputs = [lower_aligned(inp, out_dims) for inp in node.inputs]

     # Broadcast non-concatenated dimensions of each input
     non_concat_shape = [None] * len(out_dims)
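
For context, the removed loop builds a dimshuffle pattern that places each input dimension where it appears in out_dims and inserts a broadcastable axis ("x") for any dimension the input lacks. A minimal standalone sketch of that pattern, using made-up dimension names rather than anything from this commit:

    import pytensor.tensor as pt

    out_dims = ("a", "b", "c")
    inp_dims = ("c", "a")  # hypothetical input missing "b"
    order = [inp_dims.index(d) if d in inp_dims else "x" for d in out_dims]
    # order == [1, "x", 0]: transpose "c"/"a" and insert a broadcastable axis for "b"

    x = pt.matrix("x")           # stands in for tensor_from_xtensor(inp)
    y = x.dimshuffle(order)
    print(y.type.broadcastable)  # (False, True, False)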

pytensor/xtensor/rewriting/utils.py

Lines changed: 12 additions & 0 deletions

@@ -1,7 +1,12 @@
+import typing
+from collections.abc import Sequence
+
 from pytensor.compile import optdb
 from pytensor.graph.rewriting.basic import NodeRewriter, in2out
 from pytensor.graph.rewriting.db import EquilibriumDB, RewriteDatabase
 from pytensor.tensor.rewriting.ofg import inline_ofg_expansion
+from pytensor.tensor.variable import TensorVariable
+from pytensor.xtensor.type import XTensorVariable


 lower_xtensor_db = EquilibriumDB(ignore_newtrees=False)
@@ -49,3 +54,10 @@ def register(inner_rewriter: RewriteDatabase | NodeRewriter):
             **kwargs,
         )
         return node_rewriter
+
+
+def lower_aligned(x: XTensorVariable, out_dims: Sequence[str]) -> TensorVariable:
+    """Lower an XTensorVariable to a TensorVariable so that its dimensions are aligned with `out_dims`."""
+    inp_dims = {d: i for i, d in enumerate(x.type.dims)}
+    ds_order = tuple(inp_dims.get(dim, "x") for dim in out_dims)
+    return typing.cast(TensorVariable, x.values.dimshuffle(ds_order))
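
The helper replaces the repeated list.index lookups of the old loops with a single dict built up front. A quick standalone check (made-up dimension names, not taken from the commit) that the two formulations produce the same dimshuffle pattern:

    out_dims = ("a", "b", "c")
    inp_dims = ("c", "a")

    # old formulation: repeated list.index calls
    old = tuple(inp_dims.index(d) if d in inp_dims else "x" for d in out_dims)

    # new formulation, as in lower_aligned: one position dict, O(1) lookups
    pos = {d: i for i, d in enumerate(inp_dims)}
    new = tuple(pos.get(d, "x") for d in out_dims)

    assert old == new == (1, "x", 0)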

pytensor/xtensor/rewriting/vectorization.py

Lines changed: 11 additions & 33 deletions

@@ -2,8 +2,8 @@
 from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import Elemwise
 from pytensor.tensor.random.utils import compute_batch_shape
-from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
-from pytensor.xtensor.rewriting.utils import register_lower_xtensor
+from pytensor.xtensor.basic import xtensor_from_tensor
+from pytensor.xtensor.rewriting.utils import lower_aligned, register_lower_xtensor
 from pytensor.xtensor.vectorization import XRV, XBlockwise, XElemwise


@@ -13,15 +13,7 @@ def lower_elemwise(fgraph, node):
     out_dims = node.outputs[0].type.dims

     # Convert input XTensors to Tensors and align batch dimensions
-    tensor_inputs = []
-    for inp in node.inputs:
-        inp_dims = inp.type.dims
-        order = [
-            inp_dims.index(out_dim) if out_dim in inp_dims else "x"
-            for out_dim in out_dims
-        ]
-        tensor_inp = tensor_from_xtensor(inp).dimshuffle(order)
-        tensor_inputs.append(tensor_inp)
+    tensor_inputs = [lower_aligned(inp, out_dims) for inp in node.inputs]

     tensor_outs = Elemwise(scalar_op=node.op.scalar_op)(
         *tensor_inputs, return_list=True
@@ -42,17 +34,10 @@ def lower_blockwise(fgraph, node):
     batch_dims = node.outputs[0].type.dims[:batch_ndim]

     # Convert input XTensors to Tensors, align batch dimensions and place core dimensions at the end
-    tensor_inputs = []
-    for inp, core_dims in zip(node.inputs, op.core_dims[0]):
-        inp_dims = inp.type.dims
-        # Align the batch dims of the input, and place the core dims on the right
-        batch_order = [
-            inp_dims.index(batch_dim) if batch_dim in inp_dims else "x"
-            for batch_dim in batch_dims
-        ]
-        core_order = [inp_dims.index(core_dim) for core_dim in core_dims]
-        tensor_inp = tensor_from_xtensor(inp).dimshuffle(batch_order + core_order)
-        tensor_inputs.append(tensor_inp)
+    tensor_inputs = [
+        lower_aligned(inp, batch_dims + core_dims)
+        for inp, core_dims in zip(node.inputs, op.core_dims[0], strict=True)
+    ]

     signature = op.signature or getattr(op.core_op, "gufunc_signature", None)
     if signature is None:
@@ -92,17 +77,10 @@ def lower_rv(fgraph, node):
     param_batch_dims = old_out.type.dims[len(op.extra_dims) : batch_ndim]

     # Convert params XTensors to Tensors, align batch dimensions and place core dimensions at the end
-    tensor_params = []
-    for inp, core_dims in zip(params, op.core_dims[0]):
-        inp_dims = inp.type.dims
-        # Align the batch dims of the input, and place the core dims on the right
-        batch_order = [
-            inp_dims.index(batch_dim) if batch_dim in inp_dims else "x"
-            for batch_dim in param_batch_dims
-        ]
-        core_order = [inp_dims.index(core_dim) for core_dim in core_dims]
-        tensor_inp = tensor_from_xtensor(inp).dimshuffle(batch_order + core_order)
-        tensor_params.append(tensor_inp)
+    tensor_params = [
+        lower_aligned(inp, param_batch_dims + core_dims)
+        for inp, core_dims in zip(params, op.core_dims[0], strict=True)
+    ]

     size = None
     if op.extra_dims:
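
In the blockwise and RV rewrites the same helper is reused with batch_dims + core_dims (or param_batch_dims + core_dims), so batch dimensions missing from an input become broadcastable axes while core dimensions end up on the right. A standalone sketch with hypothetical dimension names, not taken from the commit:

    batch_dims = ("b1", "b2")
    core_dims = ("j", "i")
    inp_dims = ("b2", "i", "j")  # hypothetical input lacking batch dim "b1"

    pos = {d: k for k, d in enumerate(inp_dims)}
    order = tuple(pos.get(d, "x") for d in batch_dims + core_dims)
    assert order == ("x", 0, 2, 1)  # broadcastable "b1", then "b2", core dims last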
