@@ -370,74 +370,73 @@ def local_subtensor_merge(fgraph, node):
370
370
"""
371
371
from pytensor .scan .op import Scan
372
372
373
- if isinstance (node .op , Subtensor ):
374
- u = node .inputs [0 ]
375
- if u .owner and isinstance (u .owner .op , Subtensor ):
376
- # We can merge :)
377
- # x actual tensor on which we are picking slices
378
- x = u .owner .inputs [0 ]
379
- # slices of the first applied subtensor
380
- slices1 = get_idx_list (u .owner .inputs , u .owner .op .idx_list )
381
- slices2 = get_idx_list (node .inputs , node .op .idx_list )
382
-
383
- # Don't try to do the optimization on do-while scan outputs,
384
- # as it will create a dependency on the shape of the outputs
385
- if (
386
- x .owner is not None
387
- and isinstance (x .owner .op , Scan )
388
- and x .owner .op .info .as_while
389
- ):
390
- return None
373
+ u = node .inputs [0 ]
374
+ if not (u .owner is not None and isinstance (u .owner .op , Subtensor )):
375
+ return None
391
376
392
- # Get the shapes of the vectors !
393
- try :
394
- # try not to introduce new shape into the graph
395
- xshape = fgraph .shape_feature .shape_of [x ]
396
- ushape = fgraph .shape_feature .shape_of [u ]
397
- except AttributeError :
398
- # Following the suggested use of shape_feature which should
399
- # consider the case when the compilation mode doesn't
400
- # include the ShapeFeature
401
- xshape = x .shape
402
- ushape = u .shape
403
-
404
- merged_slices = []
405
- pos_2 = 0
406
- pos_1 = 0
407
- while (pos_1 < len (slices1 )) and (pos_2 < len (slices2 )):
408
- slice1 = slices1 [pos_1 ]
409
- if isinstance (slice1 , slice ):
410
- merged_slices .append (
411
- merge_two_slices (
412
- fgraph , slice1 , xshape [pos_1 ], slices2 [pos_2 ], ushape [pos_2 ]
413
- )
414
- )
415
- pos_2 += 1
416
- else :
417
- merged_slices .append (slice1 )
418
- pos_1 += 1
419
-
420
- if pos_2 < len (slices2 ):
421
- merged_slices += slices2 [pos_2 :]
422
- else :
423
- merged_slices += slices1 [pos_1 :]
377
+ # We can merge :)
378
+ # x actual tensor on which we are picking slices
379
+ x = u .owner .inputs [0 ]
380
+ # slices of the first applied subtensor
381
+ slices1 = get_idx_list (u .owner .inputs , u .owner .op .idx_list )
382
+ slices2 = get_idx_list (node .inputs , node .op .idx_list )
424
383
425
- merged_slices = tuple (as_index_constant (s ) for s in merged_slices )
426
- subtens = Subtensor (merged_slices )
384
+ # Don't try to do the optimization on do-while scan outputs,
385
+ # as it will create a dependency on the shape of the outputs
386
+ if (
387
+ x .owner is not None
388
+ and isinstance (x .owner .op , Scan )
389
+ and x .owner .op .info .as_while
390
+ ):
391
+ return None
427
392
428
- sl_ins = get_slice_elements (
429
- merged_slices , lambda x : isinstance (x , Variable )
393
+ # Get the shapes of the vectors !
394
+ try :
395
+ # try not to introduce new shape into the graph
396
+ xshape = fgraph .shape_feature .shape_of [x ]
397
+ ushape = fgraph .shape_feature .shape_of [u ]
398
+ except AttributeError :
399
+ # Following the suggested use of shape_feature which should
400
+ # consider the case when the compilation mode doesn't
401
+ # include the ShapeFeature
402
+ xshape = x .shape
403
+ ushape = u .shape
404
+
405
+ merged_slices = []
406
+ pos_2 = 0
407
+ pos_1 = 0
408
+ while (pos_1 < len (slices1 )) and (pos_2 < len (slices2 )):
409
+ slice1 = slices1 [pos_1 ]
410
+ if isinstance (slice1 , slice ):
411
+ merged_slices .append (
412
+ merge_two_slices (
413
+ fgraph , slice1 , xshape [pos_1 ], slices2 [pos_2 ], ushape [pos_2 ]
414
+ )
430
415
)
431
- # Do not call make_node for test_value
432
- out = subtens (x , * sl_ins )
416
+ pos_2 += 1
417
+ else :
418
+ merged_slices .append (slice1 )
419
+ pos_1 += 1
433
420
434
- # Copy over previous output stacktrace
435
- # and stacktrace from previous slicing operation.
436
- # Why? Because, the merged slicing operation could have failed
437
- # because of either of the two original slicing operations
438
- orig_out = node .outputs [0 ]
439
- copy_stack_trace ([orig_out , node .inputs [0 ]], out )
440
- return [out ]
421
+ if pos_2 < len (slices2 ):
422
+ merged_slices += slices2 [pos_2 :]
423
+ else :
424
+ merged_slices += slices1 [pos_1 :]
425
+
426
+ merged_slices = tuple (as_index_constant (s ) for s in merged_slices )
427
+ subtens = Subtensor (merged_slices )
428
+
429
+ sl_ins = get_slice_elements (merged_slices , lambda x : isinstance (x , Variable ))
430
+ # Do not call make_node for test_value
431
+ out = subtens (x , * sl_ins )
432
+
433
+ # Copy over previous output stacktrace
434
+ # and stacktrace from previous slicing operation.
435
+ # Why? Because, the merged slicing operation could have failed
436
+ # because of either of the two original slicing operations
437
+ orig_out = node .outputs [0 ]
438
+ copy_stack_trace ([orig_out , node .inputs [0 ]], out )
439
+ return [out ]
441
440
442
441
443
442
@register_specialize
@@ -826,6 +825,12 @@ def merge_two_slices(fgraph, slice1, len1, slice2, len2):
826
825
if not isinstance (slice1 , slice ):
827
826
raise ValueError ("slice1 should be of type `slice`" )
828
827
828
+ # Simple case where one of the slices is useless
829
+ if is_full_slice (slice1 ):
830
+ return slice2
831
+ elif is_full_slice (slice2 ):
832
+ return slice1
833
+
829
834
sl1 , reverse1 = get_canonical_form_slice (slice1 , len1 )
830
835
sl2 , reverse2 = get_canonical_form_slice (slice2 , len2 )
831
836
0 commit comments