@@ -221,6 +221,16 @@ def analyze(x):
221
221
step , is_step_constant = analyze (theslice .step )
222
222
length , is_length_constant = analyze (length )
223
223
224
+ if (
225
+ is_start_constant
226
+ and is_stop_constant
227
+ and is_step_constant
228
+ and is_length_constant
229
+ ):
230
+ _start , _stop , _step = slice (start , stop , step ).indices (length )
231
+ if _start <= _stop and _step >= 1 :
232
+ return slice (_start , _stop , _step ), 1
233
+
224
234
if step is None :
225
235
step = 1
226
236
is_step_constant = True
@@ -722,32 +732,51 @@ def make_node(self, x, *inputs):
722
732
f"Incompatible types for Subtensor template. Expected { input .type } , got { expected_type } ."
723
733
)
724
734
725
- # infer the broadcasting pattern
726
- padded = get_constant_idx (
727
- self . idx_list , ( None ,) + inputs , allow_partial = True
728
- ) + [ slice ( None , None , None )] * ( x . type . ndim - len ( idx_list ))
735
+ padded = [
736
+ * get_idx_list (( None ,) + inputs , self . idx_list ),
737
+ * [ slice ( None , None , None )] * ( x . type . ndim - len ( idx_list )),
738
+ ]
729
739
730
740
out_shape = []
731
- for i , (p , s ) in enumerate (zip (padded , x .type .shape )):
732
- if isinstance (p , slice ):
733
- if s == 1 :
734
- start = p .start
735
- try :
736
- start = get_underlying_scalar_constant_value (start )
737
- except NotScalarConstantError :
738
- pass
739
- if start is None or start == 0 :
740
- start = p .start
741
- if start is None :
742
- start = 0
743
- if p .stop is None or (
744
- isinstance (p .stop , (int , np .integer , np .ndarray ))
745
- and p .stop > start
746
- ):
747
- out_shape .append (1 )
748
- continue
749
741
742
+ def extract_const (value ):
743
+ if value is None :
744
+ return value , True
745
+ try :
746
+ value = get_underlying_scalar_constant_value (value )
747
+ return value , True
748
+ except NotScalarConstantError :
749
+ return value , False
750
+
751
+ for the_slice , length in zip (padded , x .type .shape ):
752
+ if not isinstance (the_slice , slice ):
753
+ continue
754
+
755
+ if length is None :
750
756
out_shape .append (None )
757
+ continue
758
+
759
+ start = the_slice .start
760
+ stop = the_slice .stop
761
+ step = the_slice .step
762
+
763
+ is_slice_const = True
764
+
765
+ start , is_const = extract_const (start )
766
+ is_slice_const = is_slice_const and is_const
767
+
768
+ stop , is_const = extract_const (stop )
769
+ is_slice_const = is_slice_const and is_const
770
+
771
+ step , is_const = extract_const (step )
772
+ is_slice_const = is_slice_const and is_const
773
+
774
+ if not is_slice_const :
775
+ out_shape .append (None )
776
+ continue
777
+
778
+ slice_length = len (range (* slice (start , stop , step ).indices (length )))
779
+ out_shape .append (slice_length )
751
780
752
781
return Apply (
753
782
self ,
0 commit comments