Skip to content

Commit 688a70b

Browse files
committed
Avoid canonicalization of slices when merging non-overlapping slices in local_subtensor_merge
1 parent 44be813 commit 688a70b

File tree

1 file changed

+68
-63
lines changed

1 file changed

+68
-63
lines changed

pytensor/tensor/rewriting/subtensor.py

Lines changed: 68 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -370,74 +370,73 @@ def local_subtensor_merge(fgraph, node):
370370
"""
371371
from pytensor.scan.op import Scan
372372

373-
if isinstance(node.op, Subtensor):
374-
u = node.inputs[0]
375-
if u.owner and isinstance(u.owner.op, Subtensor):
376-
# We can merge :)
377-
# x actual tensor on which we are picking slices
378-
x = u.owner.inputs[0]
379-
# slices of the first applied subtensor
380-
slices1 = get_idx_list(u.owner.inputs, u.owner.op.idx_list)
381-
slices2 = get_idx_list(node.inputs, node.op.idx_list)
382-
383-
# Don't try to do the optimization on do-while scan outputs,
384-
# as it will create a dependency on the shape of the outputs
385-
if (
386-
x.owner is not None
387-
and isinstance(x.owner.op, Scan)
388-
and x.owner.op.info.as_while
389-
):
390-
return None
373+
u = node.inputs[0]
374+
if not (u.owner is not None and isinstance(u.owner.op, Subtensor)):
375+
return None
391376

392-
# Get the shapes of the vectors !
393-
try:
394-
# try not to introduce new shape into the graph
395-
xshape = fgraph.shape_feature.shape_of[x]
396-
ushape = fgraph.shape_feature.shape_of[u]
397-
except AttributeError:
398-
# Following the suggested use of shape_feature which should
399-
# consider the case when the compilation mode doesn't
400-
# include the ShapeFeature
401-
xshape = x.shape
402-
ushape = u.shape
403-
404-
merged_slices = []
405-
pos_2 = 0
406-
pos_1 = 0
407-
while (pos_1 < len(slices1)) and (pos_2 < len(slices2)):
408-
slice1 = slices1[pos_1]
409-
if isinstance(slice1, slice):
410-
merged_slices.append(
411-
merge_two_slices(
412-
fgraph, slice1, xshape[pos_1], slices2[pos_2], ushape[pos_2]
413-
)
414-
)
415-
pos_2 += 1
416-
else:
417-
merged_slices.append(slice1)
418-
pos_1 += 1
419-
420-
if pos_2 < len(slices2):
421-
merged_slices += slices2[pos_2:]
422-
else:
423-
merged_slices += slices1[pos_1:]
377+
# We can merge :)
378+
# x actual tensor on which we are picking slices
379+
x = u.owner.inputs[0]
380+
# slices of the first applied subtensor
381+
slices1 = get_idx_list(u.owner.inputs, u.owner.op.idx_list)
382+
slices2 = get_idx_list(node.inputs, node.op.idx_list)
424383

425-
merged_slices = tuple(as_index_constant(s) for s in merged_slices)
426-
subtens = Subtensor(merged_slices)
384+
# Don't try to do the optimization on do-while scan outputs,
385+
# as it will create a dependency on the shape of the outputs
386+
if (
387+
x.owner is not None
388+
and isinstance(x.owner.op, Scan)
389+
and x.owner.op.info.as_while
390+
):
391+
return None
427392

428-
sl_ins = get_slice_elements(
429-
merged_slices, lambda x: isinstance(x, Variable)
393+
# Get the shapes of the vectors !
394+
try:
395+
# try not to introduce new shape into the graph
396+
xshape = fgraph.shape_feature.shape_of[x]
397+
ushape = fgraph.shape_feature.shape_of[u]
398+
except AttributeError:
399+
# Following the suggested use of shape_feature which should
400+
# consider the case when the compilation mode doesn't
401+
# include the ShapeFeature
402+
xshape = x.shape
403+
ushape = u.shape
404+
405+
merged_slices = []
406+
pos_2 = 0
407+
pos_1 = 0
408+
while (pos_1 < len(slices1)) and (pos_2 < len(slices2)):
409+
slice1 = slices1[pos_1]
410+
if isinstance(slice1, slice):
411+
merged_slices.append(
412+
merge_two_slices(
413+
fgraph, slice1, xshape[pos_1], slices2[pos_2], ushape[pos_2]
414+
)
430415
)
431-
# Do not call make_node for test_value
432-
out = subtens(x, *sl_ins)
416+
pos_2 += 1
417+
else:
418+
merged_slices.append(slice1)
419+
pos_1 += 1
433420

434-
# Copy over previous output stacktrace
435-
# and stacktrace from previous slicing operation.
436-
# Why? Because, the merged slicing operation could have failed
437-
# because of either of the two original slicing operations
438-
orig_out = node.outputs[0]
439-
copy_stack_trace([orig_out, node.inputs[0]], out)
440-
return [out]
421+
if pos_2 < len(slices2):
422+
merged_slices += slices2[pos_2:]
423+
else:
424+
merged_slices += slices1[pos_1:]
425+
426+
merged_slices = tuple(as_index_constant(s) for s in merged_slices)
427+
subtens = Subtensor(merged_slices)
428+
429+
sl_ins = get_slice_elements(merged_slices, lambda x: isinstance(x, Variable))
430+
# Do not call make_node for test_value
431+
out = subtens(x, *sl_ins)
432+
433+
# Copy over previous output stacktrace
434+
# and stacktrace from previous slicing operation.
435+
# Why? Because, the merged slicing operation could have failed
436+
# because of either of the two original slicing operations
437+
orig_out = node.outputs[0]
438+
copy_stack_trace([orig_out, node.inputs[0]], out)
439+
return [out]
441440

442441

443442
@register_specialize
@@ -788,6 +787,12 @@ def merge_two_slices(fgraph, slice1, len1, slice2, len2):
788787
if not isinstance(slice1, slice):
789788
raise ValueError("slice1 should be of type `slice`")
790789

790+
# Simple case where one of the slices is useless
791+
if is_full_slice(slice1):
792+
return slice2
793+
elif is_full_slice(slice2):
794+
return slice1
795+
791796
sl1, reverse1 = get_canonical_form_slice(slice1, len1)
792797
sl2, reverse2 = get_canonical_form_slice(slice2, len2)
793798

0 commit comments

Comments
 (0)