@@ -336,35 +336,46 @@ def local_subtensor_of_dot(fgraph, node):
336
336
@node_rewriter ([Subtensor ])
337
337
def local_useless_slice (fgraph , node ):
338
338
"""
339
- Remove Subtensor of the form X[0, :] -> X[0]
339
+ Remove Subtensor of the form:
340
+ 1. X[0, :] -> X[0]
341
+ 2. X[:] -> X
342
+
340
343
"""
341
- if isinstance (node .op , Subtensor ):
342
- slices = get_idx_list (node .inputs , node .op .idx_list )
343
- last_slice = len (slices )
344
- for s in slices [::- 1 ]:
345
- # check if slice and then check slice indices
346
- if (
347
- isinstance (s , slice )
348
- and s .start is None
349
- and s .stop is None
350
- and (
351
- s .step is None
352
- or extract_constant (s .step , only_process_constants = True ) == 1
353
- )
354
- ):
355
- last_slice -= 1
356
- else :
357
- break
358
- # check if we removed something
359
- if last_slice < len (slices ):
360
- subtens = Subtensor (slices [:last_slice ])
361
- sl_ins = get_slice_elements (
362
- slices [:last_slice ], lambda x : isinstance (x , Variable )
344
+ idxs = get_idx_list (node .inputs , node .op .idx_list )
345
+
346
+ if not idxs :
347
+ return [node .inputs [0 ]]
348
+
349
+ last_useless_slice = len (idxs )
350
+ for s in idxs [::- 1 ]:
351
+ # check if slice and then check slice indices
352
+ if (
353
+ isinstance (s , slice )
354
+ and s .start is None
355
+ and s .stop is None
356
+ and (
357
+ s .step is None
358
+ or extract_constant (s .step , only_process_constants = True ) == 1
359
+ )
360
+ ):
361
+ last_useless_slice -= 1
362
+ else :
363
+ break
364
+ # check if we removed something
365
+ if last_useless_slice < len (idxs ):
366
+ new_idxs = idxs [:last_useless_slice ]
367
+ if new_idxs :
368
+ new_subtensor = Subtensor (new_idxs )
369
+ new_subtensor_inputs = get_slice_elements (
370
+ new_idxs , lambda x : isinstance (x , Variable )
363
371
)
364
- out = subtens (node .inputs [0 ], * sl_ins )
372
+ out = new_subtensor (node .inputs [0 ], * new_subtensor_inputs )
365
373
# Copy over previous output stacktrace
366
374
copy_stack_trace (node .outputs , out )
367
375
return [out ]
376
+ else :
377
+ # Subtensor is not needed at all
378
+ return [node .inputs [0 ]]
368
379
369
380
370
381
# fast_compile to allow opt subtensor(cast{float32}(make_vector))
@@ -747,7 +758,13 @@ def local_subtensor_make_vector(fgraph, node):
747
758
make_vector_op = x .owner .op
748
759
749
760
if isinstance (node .op , Subtensor ):
750
- (idx ,) = node .op .idx_list
761
+ idxs = node .op .idx_list
762
+
763
+ # Subtensor has no indexes, return make_vector
764
+ if not idxs :
765
+ return [x ]
766
+
767
+ (idx ,) = idxs
751
768
752
769
if isinstance (idx , (aes .ScalarType , TensorType )):
753
770
old_idx , idx = idx , node .inputs [1 ]
@@ -903,7 +920,11 @@ def local_set_to_inc_subtensor(fgraph, node):
903
920
@node_rewriter ([Subtensor ])
904
921
def local_useless_subtensor (fgraph , node ):
905
922
"""Remove `Subtensor` if it takes the full input."""
906
- # This optimization needs ShapeOpt and fgraph.shape_feature
923
+
924
+ if not node .op .idx_list :
925
+ return [node .inputs [0 ]]
926
+
927
+ # The more elaborate optimization needs ShapeOpt and fgraph.shape_feature
907
928
if not hasattr (fgraph , "shape_feature" ):
908
929
return
909
930
0 commit comments