33It provides a consistent interface for operations on both NumPy arrays and PyTorch tensors.
44
55The functions in this module are designed to:
6+
671. Detect array types automatically (numpy.ndarray or torch.Tensor)
782. Perform operations using the appropriate library
8- 3. Preserve the input type in the output
9+ 3. Preserve the input type in the output, except for functions intended to operate on
10+ indices, which always return NDArrays for convenience.
9114. Minimize unnecessary type conversions
1012
11- Usage examples:
13+ ??? example "Some examples"
1214
13- ```python
14- import numpy as np
15- import torch
16- from pydvl.utils.array import array_zeros, array_concatenate, is_tensor
15+ ```python
16+ import numpy as np
17+ import torch
18+ from pydvl.utils.array import array_zeros, array_concatenate, is_tensor
1719
18- # Works with NumPy arrays
19- x_np = np.array([1, 2, 3])
20- zeros_np = array_zeros((3,), like=x_np) # Returns numpy.ndarray
20+ # Works with NumPy arrays
21+ x_np = np.array([1, 2, 3])
22+ zeros_np = array_zeros((3, ), like=x_np) # Returns numpy.ndarray
2123
22- # Works with PyTorch tensors
23- x_torch = torch.tensor([1, 2, 3])
24- zeros_torch = array_zeros((3,), like=x_torch) # Returns torch.Tensor
24+ # Works with PyTorch tensors
25+ x_torch = torch.tensor([1, 2, 3])
26+ zeros_torch = array_zeros((3, ), like=x_torch) # Returns torch.Tensor
2527
26- # Type checking
27- is_tensor(x_torch) # Returns True
28- is_tensor(x_np) # Returns False
28+ # Type checking
29+ is_tensor(x_torch) # Returns True
30+ is_tensor(x_np) # Returns False
2931
30- # Operations preserve types
31- result = array_concatenate([x_np, zeros_np]) # Returns numpy.ndarray
32- result = array_concatenate([x_torch, zeros_torch]) # Returns torch.Tensor
33- ```
32+ # Operations preserve types
33+ result = array_concatenate([x_np, zeros_np]) # Returns numpy.ndarray
34+ result = array_concatenate([x_torch, zeros_torch]) # Returns torch.Tensor
35+ ```
3436
3537The module uses a TypeVar `ArrayT` to ensure type preservation across functions,
3638allowing for proper static type checking with both array types.
3941from __future__ import annotations
4042
4143import warnings
44+ from numbers import Number
4245from types import ModuleType
4346from typing import (
4447 TYPE_CHECKING ,
7073 "array_ones" ,
7174 "array_zeros_like" ,
7275 "array_ones_like" ,
73- "array_where" ,
7476 "array_unique" ,
7577 "array_concatenate" ,
7678 "array_equal" ,
7779 "array_arange" ,
78- "array_slice" ,
7980 "array_index" ,
8081 "array_exp" ,
8182 "array_count_nonzero" ,
@@ -148,6 +149,8 @@ class Array(Protocol[DT]):
148149 in NumPy and PyTorch arrays.
149150 """
150151
152+ nbytes : int
153+
151154 @property
152155 def shape (self ) -> tuple [int , ...]: ...
153156
@@ -287,98 +290,65 @@ def array_ones(
287290
288291
289292@overload
290- def array_zeros_like (array : NDArray , dtype : Any | None = None ) -> NDArray : ...
293+ def array_zeros_like (a : NDArray , dtype : Any | None = None ) -> NDArray : ...
291294
292295
293296@overload
294- def array_zeros_like (array : Tensor , dtype : Any | None = None ) -> Tensor : ...
297+ def array_zeros_like (a : Tensor , dtype : Any | None = None ) -> Tensor : ...
295298
296299
297- def array_zeros_like (array : Array , dtype : Any | None = None ) -> Array :
300+ def array_zeros_like (a : NDArray | Tensor , dtype : Any | None = None ) -> NDArray | Tensor :
298301 """
299302 Create a zero-filled array with the same shape and type as `a`.
300303
301304 Args:
302- array : Reference array (numpy array or torch tensor).
305+ a : Reference array (numpy array or torch tensor).
303306 dtype: Data type (optional).
304307
305308 Returns:
306309 An array of zeros matching the input type.
310+
311+ Raises:
312+ TypeError:
307313 """
308- if is_tensor (array ):
314+ if is_tensor (a ):
309315 assert torch is not None # Keep mypy happy
310- tensor_array = cast (Tensor , array )
311- return cast (Array , torch .zeros_like (tensor_array , dtype = dtype ))
312- return cast (Array , np .zeros_like (array , dtype = dtype ))
316+ tensor_array = cast (Tensor , a )
317+ return cast (Tensor , torch .zeros_like (tensor_array , dtype = dtype ))
318+ elif is_numpy (a ):
319+ return cast (NDArray , np .zeros_like (a , dtype = dtype ))
320+ raise TypeError (f"Unsupported array type: { type (a ).__name__ } " )
313321
314322
315323@overload
316- def array_ones_like (array : NDArray , dtype : Any | None = None ) -> NDArray : ...
324+ def array_ones_like (a : NDArray , dtype : Any | None = None ) -> NDArray : ...
317325
318326
319327@overload
320- def array_ones_like (array : Tensor , dtype : Any | None = None ) -> Tensor : ...
328+ def array_ones_like (a : Tensor , dtype : Any | None = None ) -> Tensor : ...
321329
322330
323- def array_ones_like (array : Array , dtype : Any | None = None ) -> Array :
331+ def array_ones_like (a : NDArray | Tensor , dtype : Any | None = None ) -> NDArray | Tensor :
324332 """
325333 Create a one-filled array with the same shape and type as `a`.
326334
327335 Args:
328- array : Reference array (numpy array or torch tensor).
336+ a : Reference array (numpy array or torch tensor).
329337 dtype: Data type (optional).
330338
331339 Returns:
332340 An array of ones matching the input type.
333- """
334- if is_numpy (array ):
335- return cast (Array , np .ones_like (array , dtype = dtype ))
336- elif is_tensor (array ):
337- assert torch is not None
338- tensor_array = cast (Tensor , array )
339- return cast (Array , torch .ones_like (tensor_array , dtype = dtype ))
340- raise TypeError (f"Unsupported array type: { type (array ).__name__ } " )
341-
342-
343- @overload
344- def array_where (condition : NDArray , x : NDArray , y : NDArray ) -> NDArray : ...
345-
346-
347- @overload
348- def array_where (condition : Tensor , x : Tensor , y : Tensor ) -> Tensor : ...
349-
350-
351- def array_where (condition : Array , x : Array , y : Array ) -> Array :
352- """
353- Return elements chosen from x or y depending on condition.
354-
355- Args:
356- condition: Boolean mask.
357- x: Values selected where condition is True.
358- y: Values selected where condition is False.
359341
360- Returns :
361- An array with elements from x or y, following the input type.
342+ Raises:
343+ TypeError: If `a` is neither a NumPy array nor a PyTorch tensor.
362344 """
363- # If any of the inputs is a tensor and torch is available, work with torch.
364- if any (is_tensor (a ) for a in (condition , x , y )):
345+ if is_numpy (a ):
346+ return cast (NDArray , np .ones_like (a , dtype = dtype ))
347+ elif is_tensor (a ):
365348 assert torch is not None
366- device = None
367- for a in (condition , x , y ):
368- if is_tensor (a ):
369- device = cast (Tensor , a ).device
370- break
371-
372- condition_tensor = (
373- condition
374- if is_tensor (condition )
375- else torch .as_tensor (to_numpy (condition ), device = device )
376- )
377- x_tensor = x if is_tensor (x ) else torch .as_tensor (to_numpy (x ), device = device )
378- y_tensor = y if is_tensor (y ) else torch .as_tensor (to_numpy (y ), device = device )
379- return cast (Array , torch .where (condition_tensor , x_tensor , y_tensor ))
380- else :
381- return cast (Array , np .where (condition , x , y ))
349+ tensor_array = cast (Tensor , a )
350+ return cast (Array , torch .ones_like (tensor_array , dtype = dtype ))
351+ raise TypeError (f"Unsupported array type: { type (a ).__name__ } " )
382352
383353
384354@overload
@@ -402,12 +372,12 @@ def array_unique(
402372@overload
403373def array_unique (
404374 array : Tensor , return_index : Literal [True ], ** kwargs : Any
405- ) -> Tuple [Tensor , Tensor ]: ...
375+ ) -> Tuple [Tensor , NDArray ]: ...
406376
407377
408378def array_unique (
409- array : ArrayT , return_index : bool = False , ** kwargs : Any
410- ) -> Union [ArrayT , Tuple [ArrayT , NDArray ]]:
379+ array : NDArray | Tensor , return_index : bool = False , ** kwargs : Any
380+ ) -> Union [NDArray | Tensor , Tuple [NDArray | Tensor , NDArray ]]:
411381 """
412382 Return the unique elements in an array, optionally with indices of their first
413383 occurrences.
@@ -436,8 +406,8 @@ def array_unique(
436406 indices_tensor = torch .tensor (
437407 indices , dtype = torch .long , device = tensor_array .device
438408 )
439- return cast (Tuple [ArrayT , NDArray ], (result , indices_tensor .cpu ().numpy ()))
440- return cast (ArrayT , result )
409+ return cast (Tuple [Tensor , NDArray ], (result , indices_tensor .cpu ().numpy ()))
410+ return cast (Tensor , result )
441411 else : # Fallback to numpy approach.
442412 numpy_array = to_numpy (array )
443413 if return_index :
@@ -447,15 +417,15 @@ def array_unique(
447417 return_index = True ,
448418 ** {k : v for k , v in kwargs .items () if k != "return_index" },
449419 )
450- return cast (Tuple [ArrayT , NDArray ], (unique_vals , indices ))
420+ return cast (Tuple [NDArray , NDArray ], (unique_vals , indices ))
451421 else :
452422 # Simple case - just unique values
453423 result = np .unique (
454424 numpy_array ,
455425 return_index = False ,
456426 ** {k : v for k , v in kwargs .items () if k != "return_index" },
457427 )
458- return cast (ArrayT , result )
428+ return cast (NDArray , result )
459429
460430
461431@overload
@@ -466,7 +436,9 @@ def array_concatenate(arrays: Sequence[NDArray], axis: int = 0) -> NDArray: ...
466436def array_concatenate (arrays : Sequence [Tensor ], axis : int = 0 ) -> Tensor : ...
467437
468438
469- def array_concatenate (arrays : Sequence [Array ], axis : int = 0 ) -> Array :
439+ def array_concatenate (
440+ arrays : Sequence [NDArray | Tensor ], axis : int = 0
441+ ) -> NDArray | Tensor :
470442 """
471443 Join a sequence of arrays along an existing axis.
472444
@@ -496,10 +468,10 @@ def array_concatenate(arrays: Sequence[Array], axis: int = 0) -> Array:
496468 tensor_arrays .append (a )
497469 else :
498470 tensor_arrays .append (torch .as_tensor (to_numpy (a ), device = device ))
499- return cast (Array , torch .cat (tensor_arrays , dim = axis ))
471+ return cast (Tensor , torch .cat (tensor_arrays , dim = axis ))
500472 # Otherwise, convert all arrays to numpy arrays.
501473 numpy_arrays = [to_numpy (a ) for a in arrays ]
502- return cast (Array , np .concatenate (numpy_arrays , axis = axis ))
474+ return cast (NDArray , np .concatenate (numpy_arrays , axis = axis ))
503475
504476
505477def array_equal (array1 : Array [Any ], array2 : Array [Any ]) -> bool :
@@ -569,35 +541,6 @@ def array_arange(
569541 return cast (Array , np .arange (start , stop , step , dtype = dtype )) # type: ignore
570542
571543
572- @overload
573- def array_slice (array : NDArray , indices : Any ) -> NDArray : ...
574-
575-
576- @overload
577- def array_slice (array : Tensor , indices : Any ) -> Tensor : ...
578-
579-
580- def array_slice (array : Array , indices : Any ) -> Array :
581- """
582- Slice an array in a type-agnostic way.
583-
584- Args:
585- array: The array to be sliced.
586- indices: The slicing indices (int, slice, list, etc.).
587-
588- Returns:
589- A sliced array with the same type as the input.
590-
591- Raises:
592- TypeError: If the input array does not support indexing.
593- """
594- if not hasattr (array , "__getitem__" ):
595- raise TypeError (
596- f"Provided object of type { type (array ).__name__ } is not indexable."
597- )
598- return array [indices ]
599-
600-
601544def array_index (array : Array , key : Array , dim : int = 0 ) -> Array :
602545 """
603546 Index into an array along the specified dimension.
@@ -703,6 +646,10 @@ def stratified_split_indices(
703646 return cast (Tuple [ArrayT , ArrayT ], (train_indices , test_indices ))
704647
705648
649+ @overload
650+ def atleast1d (a : Number ) -> NDArray : ...
651+
652+
706653@overload
707654def atleast1d (a : NDArray ) -> NDArray : ...
708655
@@ -711,25 +658,25 @@ def atleast1d(a: NDArray) -> NDArray: ...
711658def atleast1d (a : Tensor ) -> Tensor : ...
712659
713660
714- def atleast1d (a : Array ) -> Array :
715- """
716- Ensures that the array is at least 1D.
661+ def atleast1d (a : NDArray | Tensor | Number ) -> NDArray | Tensor :
662+ """Ensures that the input is at least 1D.
663+
664+ For scalar builtin types, the output is an NDArray. Scalar tensors are converted to
665+ 1D tensors.
717666
718667 Args:
719- a: Input array-like object.
668+ a: Input array-like object or a scalar.
720669
721670 Returns:
722- The array , as a 1D structure.
671+ The input, as a 1D structure.
723672 """
724- if is_numpy (a ):
725- return cast (Array , np .atleast_1d (a ))
673+ if is_numpy (a ) or np . isscalar ( a ) :
674+ return cast (NDArray , np .atleast_1d (a )) # type: ignore
726675 if is_tensor (a ):
727676 assert torch is not None
728- tensor_array = cast (Tensor , a )
729- return cast (
730- Array , tensor_array .unsqueeze (0 ) if tensor_array .ndim == 0 else tensor_array
731- )
732- return cast (Array , np .atleast_1d (np .array (a )))
677+ t = cast (Tensor , a )
678+ return cast (Tensor , t .unsqueeze (0 ) if t .ndim == 0 else t )
679+ raise TypeError (f"Unsupported array or scalar type: { type (a ).__name__ } " )
733680
734681
735682@overload
@@ -755,13 +702,13 @@ def check_X_y(
755702
756703
757704def check_X_y (
758- X : Array ,
759- y : Array ,
705+ X : NDArray | Tensor ,
706+ y : NDArray | Tensor ,
760707 * ,
761708 multi_output : bool = False ,
762709 estimator : str | object | None = None ,
763710 copy : bool = False ,
764- ) -> Tuple [Array , Array ]:
711+ ) -> Tuple [NDArray | Tensor , NDArray | Tensor ]:
765712 """
766713 Validate X and y mimicking the functionality of sklearn's check_X_y.
767714
0 commit comments