@@ -12116,7 +12116,7 @@ def _index_to_gather(x_shape: Sequence[int], idx: Sequence[Any],
1211612116 "arrays within JIT compiled functions)." )
1211712117 raise IndexError (msg )
1211812118
12119- start , step , slice_size = _preprocess_slice (i , x_shape [x_axis ])
12119+ start , step , slice_size = core . canonicalize_slice (i , x_shape [x_axis ])
1212012120 slice_shape .append (slice_size )
1212112121
1212212122 if core .definitely_equal (step , 1 ):
@@ -12319,65 +12319,6 @@ def _canonicalize_tuple_index(arr_ndim, idx):
1231912319 idx = tuple (idx ) + colons
1232012320 return idx
1232112321
12322- def _preprocess_slice (
12323- s : slice ,
12324- axis_size : core .DimSize
12325- ) -> tuple [core .DimSize , core .DimSize , core .DimSize ]:
12326- """Computes the start index, step, and size of the slice `x[s]`."""
12327- # See https://numpy.org/doc/stable/user/basics.indexing.html#slicing-and-striding
12328- # "this is harder to get right than you may think"
12329- # (from https://github.com/python/cpython/blob/939fc6d6eab9b7ea8c244d513610dbdd556503a7/Objects/sliceobject.c#L275)
12330- def convert_to_index (d : DimSize ) -> DimSize :
12331- # Convert np.array and jax.Array to int, leave symbolic dimensions alone
12332- try :
12333- return operator .index (d )
12334- except :
12335- return d
12336-
12337- # Must resolve statically if step is {<0, ==0, >0}
12338- step = convert_to_index (s .step ) if s .step is not None else 1
12339- try :
12340- if step == 0 :
12341- raise ValueError ("slice step cannot be zero" )
12342- step_gt_0 = (step > 0 )
12343- except core .InconclusiveDimensionOperation as e :
12344- raise core .InconclusiveDimensionOperation (
12345- f"In slice with non-constant elements the step ({ step } ) must " +
12346- f"be resolved statically if it is > 0 or < 0.\n Details: { e } " )
12347-
12348- def clamp_index (i : DimSize , which : str ):
12349- try :
12350- i_ge_0 = (i >= 0 )
12351- except core .InconclusiveDimensionOperation as e :
12352- raise core .InconclusiveDimensionOperation (
12353- f"In slice with non-constant elements the { which } ({ i } ) must " +
12354- f"be resolved statically if it is >= 0.\n Details: { e } " )
12355- if i_ge_0 :
12356- if step_gt_0 :
12357- return core .min_dim (axis_size , i )
12358- else :
12359- return core .min_dim (axis_size - 1 , i )
12360- else :
12361- if step_gt_0 :
12362- return core .max_dim (0 , axis_size + i )
12363- else :
12364- return core .max_dim (- 1 , axis_size + i )
12365-
12366- if s .start is None :
12367- start = 0 if step_gt_0 else axis_size - 1
12368- else :
12369- start = clamp_index (convert_to_index (s .start ), "start" )
12370-
12371- if s .stop is None :
12372- stop = axis_size if step_gt_0 else - 1
12373- else :
12374- stop = clamp_index (convert_to_index (s .stop ), "stop" )
12375-
12376- gap = step if step_gt_0 else - step
12377- distance = (stop - start ) if step_gt_0 else (start - stop )
12378- slice_size = core .max_dim (0 , distance + gap - 1 ) // gap
12379- return start , step , slice_size
12380-
1238112322
1238212323@export
1238312324def blackman (M : int ) -> Array :
0 commit comments