28
28
is_allowed_extension_array_dtype ,
29
29
is_duck_array ,
30
30
is_duck_dask_array ,
31
+ is_full_slice ,
31
32
is_scalar ,
32
33
is_valid_numpy_dtype ,
33
34
to_0d_array ,
43
44
from xarray .namedarray ._typing import _Shape , duckarray
44
45
from xarray .namedarray .parallelcompat import ChunkManagerEntrypoint
45
46
47
+ BasicIndexerType = int | np .integer | slice
48
+ OuterIndexerType = BasicIndexerType | np .ndarray [Any , np .dtype [np .integer ]]
49
+
46
50
47
51
@dataclass
48
52
class IndexSelResult :
@@ -300,19 +304,83 @@ def slice_slice(old_slice: slice, applied_slice: slice, size: int) -> slice:
300
304
return slice (start , stop , step )
301
305
302
306
303
- def _index_indexer_1d (old_indexer , applied_indexer , size : int ):
304
- if isinstance (applied_indexer , slice ) and applied_indexer == slice (None ):
307
+ def normalize_array (
308
+ array : np .ndarray [Any , np .dtype [np .integer ]], size : int
309
+ ) -> np .ndarray [Any , np .dtype [np .integer ]]:
310
+ """
311
+ Ensure that the given array only contains positive values.
312
+
313
+ Examples
314
+ --------
315
+ >>> normalize_array(np.array([-1, -2, -3, -4]), 10)
316
+ array([9, 8, 7, 6])
317
+ >>> normalize_array(np.array([-5, 3, 5, -1, 8]), 12)
318
+ array([ 7, 3, 5, 11, 8])
319
+ """
320
+ if np .issubdtype (array .dtype , np .unsignedinteger ):
321
+ return array
322
+
323
+ return np .where (array >= 0 , array , array + size )
324
+
325
+
326
+ def slice_slice_by_array (
327
+ old_slice : slice ,
328
+ array : np .ndarray [Any , np .dtype [np .integer ]],
329
+ size : int ,
330
+ ) -> np .ndarray [Any , np .dtype [np .integer ]]:
331
+ """Given a slice and the size of the dimension to which it will be applied,
332
+ index it with an array to return a new array equivalent to applying
333
+ the slices sequentially
334
+
335
+ Examples
336
+ --------
337
+ >>> slice_slice_by_array(slice(2, 10), np.array([1, 3, 5]), 12)
338
+ array([3, 5, 7])
339
+ >>> slice_slice_by_array(slice(1, None, 2), np.array([1, 3, 7, 8]), 20)
340
+ array([ 3, 7, 15, 17])
341
+ >>> slice_slice_by_array(slice(None, None, -1), np.array([2, 4, 7]), 20)
342
+ array([17, 15, 12])
343
+ """
344
+ # to get a concrete slice, limited to the size of the array
345
+ normalized_slice = normalize_slice (old_slice , size )
346
+
347
+ size_after_slice = len (range (* normalized_slice .indices (size )))
348
+ normalized_array = normalize_array (array , size_after_slice )
349
+
350
+ new_indexer = normalized_array * normalized_slice .step + normalized_slice .start
351
+
352
+ if np .any (new_indexer >= size ):
353
+ raise IndexError ("indices out of bounds" ) # TODO: more helpful error message
354
+
355
+ return new_indexer
356
+
357
+
358
+ def _index_indexer_1d (
359
+ old_indexer : OuterIndexerType ,
360
+ applied_indexer : OuterIndexerType ,
361
+ size : int ,
362
+ ) -> OuterIndexerType :
363
+ if is_full_slice (applied_indexer ):
305
364
# shortcut for the usual case
306
365
return old_indexer
366
+ if is_full_slice (old_indexer ):
367
+ # shortcut for full slices
368
+ return applied_indexer
369
+
370
+ indexer : OuterIndexerType
307
371
if isinstance (old_indexer , slice ):
308
372
if isinstance (applied_indexer , slice ):
309
373
indexer = slice_slice (old_indexer , applied_indexer , size )
310
374
elif isinstance (applied_indexer , integer_types ):
311
- indexer = range (* old_indexer .indices (size ))[applied_indexer ] # type: ignore[assignment]
375
+ indexer = range (* old_indexer .indices (size ))[applied_indexer ]
312
376
else :
313
- indexer = _expand_slice (old_indexer , size )[ applied_indexer ]
314
- else :
377
+ indexer = slice_slice_by_array (old_indexer , applied_indexer , size )
378
+ elif isinstance ( old_indexer , np . ndarray ) :
315
379
indexer = old_indexer [applied_indexer ]
380
+ else :
381
+ # should be unreachable
382
+ raise ValueError ("cannot index integers. Please open an issuec-" )
383
+
316
384
return indexer
317
385
318
386
@@ -389,7 +457,7 @@ class BasicIndexer(ExplicitIndexer):
389
457
390
458
__slots__ = ()
391
459
392
- def __init__ (self , key : tuple [int | np . integer | slice , ...]):
460
+ def __init__ (self , key : tuple [BasicIndexerType , ...]):
393
461
if not isinstance (key , tuple ):
394
462
raise TypeError (f"key must be a tuple: { key !r} " )
395
463
@@ -421,9 +489,7 @@ class OuterIndexer(ExplicitIndexer):
421
489
422
490
def __init__ (
423
491
self ,
424
- key : tuple [
425
- int | np .integer | slice | np .ndarray [Any , np .dtype [np .generic ]], ...
426
- ],
492
+ key : tuple [BasicIndexerType | np .ndarray [Any , np .dtype [np .generic ]], ...],
427
493
):
428
494
if not isinstance (key , tuple ):
429
495
raise TypeError (f"key must be a tuple: { key !r} " )
@@ -629,7 +695,8 @@ def __init__(self, array: Any, key: ExplicitIndexer | None = None):
629
695
630
696
def _updated_key (self , new_key : ExplicitIndexer ) -> BasicIndexer | OuterIndexer :
631
697
iter_new_key = iter (expanded_indexer (new_key .tuple , self .ndim ))
632
- full_key = []
698
+
699
+ full_key : list [OuterIndexerType ] = []
633
700
for size , k in zip (self .array .shape , self .key .tuple , strict = True ):
634
701
if isinstance (k , integer_types ):
635
702
full_key .append (k )
@@ -638,7 +705,7 @@ def _updated_key(self, new_key: ExplicitIndexer) -> BasicIndexer | OuterIndexer:
638
705
full_key_tuple = tuple (full_key )
639
706
640
707
if all (isinstance (k , integer_types + (slice ,)) for k in full_key_tuple ):
641
- return BasicIndexer (full_key_tuple )
708
+ return BasicIndexer (cast ( tuple [ BasicIndexerType , ...], full_key_tuple ) )
642
709
return OuterIndexer (full_key_tuple )
643
710
644
711
@property
0 commit comments