Skip to content

Commit 2165721

Browse files
committed
Specialize split as subtensor
1 parent f9dda00 commit 2165721

File tree

1 file changed

+34
-0
lines changed

1 file changed

+34
-0
lines changed

pytensor/tensor/rewriting/subtensor.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
Join,
2222
MakeVector,
2323
ScalarFromTensor,
24+
Split,
2425
TensorFromScalar,
2526
alloc,
2627
as_tensor,
@@ -616,6 +617,39 @@ def local_subtensor_remove_broadcastable_index(fgraph, node):
616617
return [node.inputs[0].dimshuffle(tuple(remain_dim))]
617618

618619

620+
@register_specialize("shape_unsafe")
621+
@node_rewriter(tracks=[Split])
622+
def split_to_subtensor(fgraph, node):
623+
"""Rewrite split(2)(x, 0) -> (x[:split_sizes[0]], x[split_sizes[0]:).
624+
625+
This allows lifting the underlying split close to the inputs, and increases fusion opportunities.
626+
It automatically handles unused split outputs as well
627+
628+
It only works for constant axis
629+
"""
630+
x, axis, split_sizes = node.inputs
631+
632+
n_splits = node.op.len_splits
633+
if n_splits <= 1:
634+
return [x]
635+
636+
if not isinstance(axis, Constant):
637+
return None
638+
639+
empty_slices = (slice(None),) * int(axis.data)
640+
ys = []
641+
642+
end = split_sizes[0]
643+
ys.append(x[(*empty_slices, slice(None, end))])
644+
prev_start = end
645+
for i in range(1, n_splits - 1):
646+
end = prev_start + split_sizes[i]
647+
ys.append(x[(*empty_slices, slice(prev_start, end))])
648+
prev_start = end
649+
ys.append(x[(*empty_slices, slice(prev_start, None))])
650+
return ys
651+
652+
619653
@register_infer_shape
620654
@register_useless
621655
@register_canonicalize

0 commit comments

Comments
 (0)