3737from typing import (
3838 TYPE_CHECKING ,
3939 Any ,
40- Iterator ,
41- List ,
4240 Literal ,
43- Protocol ,
4441 Sequence ,
45- Tuple ,
4642 TypeVar ,
4743 Union ,
4844 cast ,
4945 overload ,
50- runtime_checkable ,
5146)
5247
5348import numpy as np
5449import sklearn .utils
5550from numpy .typing import ArrayLike , NDArray
56- from typing_extensions import Self
5751
5852__all__ = [
5953 "is_tensor" ,
7064 "check_X_y" ,
7165 "check_X_y_torch" ,
7266 "try_torch_import" ,
67+ "ArrayT" ,
68+ "ArrayRetT" ,
7369]
7470
7571DT = TypeVar ("DT" , bound = np .generic )
@@ -110,67 +106,7 @@ def is_numpy(array: Any) -> bool:
110106 return isinstance (array , np .ndarray )
111107
112108
113- @runtime_checkable
114- class Array (Protocol [DT ]):
115- """Protocol defining a common interface for NumPy arrays and PyTorch tensors.
116-
117- This protocol defines the essential methods and properties required for array-like
118- operations in PyDVL. It serves as a structural type for both numpy.ndarray
119- and torch.Tensor, enabling type-safe generic functions that work with either type.
120-
121- The generic parameter DT represents the data type of the array elements.
122-
123- !!! note "Type Preservation"
124- Functions that accept Array types will generally preserve the input type
125- in their outputs. For example, if you pass a torch.Tensor, you'll get a
126- torch.Tensor back; if you pass a numpy.ndarray, you'll get a numpy.ndarray back.
127-
128- !!! warning
129- This is a "best-effort" implementation that covers the methods and properties
130- needed by PyDVL, but it is not a complete representation of all functionality
131- in NumPy and PyTorch arrays.
132- """
133-
134- @property
135- def nbytes (self ) -> int : ...
136-
137- @property
138- def shape (self ) -> tuple [int , ...]: ...
139-
140- @property
141- def ndim (self ) -> int : ...
142-
143- @property
144- def dtype (self ) -> Any : ...
145-
146- def __len__ (self ) -> int : ...
147-
148- def __getitem__ (self , key : Any ) -> Self : ...
149-
150- def __iter__ (self ) -> Iterator : ...
151-
152- def __add__ (self , other : Array ) -> Array : ...
153-
154- def __sub__ (self , other ) -> Array : ...
155-
156- def __mul__ (self , other ) -> Array : ...
157-
158- def __matmul__ (self , other ) -> Array : ...
159-
160- def __array__ (self , dtype : DT | None = None ) -> NDArray [DT ]: ...
161-
162- def flatten (self , * args , ** kwargs ) -> Self : ...
163-
164- def reshape (self , * args : Any , ** kwargs : Any ) -> Self : ...
165-
166- def tolist (self ) -> list : ...
167-
168- def item (self ) -> DT : ...
169-
170- def sum (self , * args : Any , ** kwargs : Any ) -> Self : ...
171-
172-
173- def to_tensor (array : Array | ArrayLike ) -> Tensor :
109+ def to_tensor (array : NDArray | Tensor | ArrayLike ) -> Tensor :
174110 """
175111 Convert array to torch.Tensor if it's not already.
176112
@@ -190,7 +126,7 @@ def to_tensor(array: Array | ArrayLike) -> Tensor:
190126 return cast (Tensor , torch .as_tensor (array ))
191127
192128
193- def to_numpy (array : Array | ArrayLike ) -> NDArray :
129+ def to_numpy (array : NDArray | Tensor | ArrayLike ) -> NDArray :
194130 """
195131 Convert array to a numpy.ndarray if it's not already.
196132
@@ -207,7 +143,7 @@ def to_numpy(array: Array | ArrayLike) -> NDArray:
207143 return cast (NDArray , np .asarray (array ))
208144
209145
210- ShapeType = Union [int , Tuple [int , ...], List [int ]]
146+ ShapeType = Union [int , tuple [int , ...], list [int ]]  # a single int, or a tuple/list of ints; builtin generics replace typing.Tuple/List
211147
212148
213149@overload
@@ -219,7 +155,7 @@ def array_unique(
219155@overload
220156def array_unique (
221157 array : NDArray , return_index : Literal [True ], ** kwargs : Any
222- ) -> Tuple [NDArray , NDArray ]: ...
158+ ) -> tuple [NDArray , NDArray ]: ...
223159
224160
225161@overload
@@ -231,12 +167,12 @@ def array_unique(
231167@overload
232168def array_unique (
233169 array : Tensor , return_index : Literal [True ], ** kwargs : Any
234- ) -> Tuple [Tensor , NDArray ]: ...
170+ ) -> tuple [Tensor , NDArray ]: ...
235171
236172
237173def array_unique (
238174 array : NDArray | Tensor , return_index : bool = False , ** kwargs : Any
239- ) -> Union [ NDArray | Tensor , Tuple [ NDArray | Tensor , NDArray ] ]:
175+ ) -> NDArray | tuple [ NDArray , NDArray ] | Tensor | tuple [ Tensor , NDArray ]:
240176 """
241177 Return the unique elements in an array, optionally with indices of their first
242178 occurrences.
@@ -265,9 +201,9 @@ def array_unique(
265201 indices_tensor = torch .tensor (
266202 indices , dtype = torch .long , device = tensor_array .device
267203 )
268- return cast (Tuple [Tensor , NDArray ], (result , indices_tensor .cpu ().numpy ()))
204+ return cast (tuple [Tensor , NDArray ], (result , indices_tensor .cpu ().numpy ()))
269205 return cast (Tensor , result )
270- else : # Fallback to numpy approach.
206+ else :
271207 numpy_array = to_numpy (array )
272208 if return_index :
273209 # np.unique returns a tuple when return_index=True
@@ -276,7 +212,7 @@ def array_unique(
276212 return_index = True ,
277213 ** {k : v for k , v in kwargs .items () if k != "return_index" },
278214 )
279- return cast (Tuple [NDArray , NDArray ], (unique_vals , indices ))
215+ return cast (tuple [NDArray , NDArray ], (unique_vals , indices ))
280216 else :
281217 # Simple case - just unique values
282218 result = np .unique (
@@ -296,7 +232,7 @@ def array_concatenate(arrays: Sequence[Tensor], axis: int = 0) -> Tensor: ...
296232
297233
298234def array_concatenate (
299- arrays : Sequence [NDArray | Tensor ], axis : int = 0
235+ arrays : Sequence [NDArray ] | Sequence [ Tensor ], axis : int = 0
300236) -> NDArray | Tensor :
301237 """
302238 Join a sequence of arrays along an existing axis.
@@ -333,13 +269,25 @@ def array_concatenate(
333269 return cast (NDArray , np .concatenate (numpy_arrays , axis = axis ))
334270
335271
336- ArrayT = TypeVar ("ArrayT" , bound = Array , contravariant = True )
337- ArrayRetT = TypeVar ("ArrayRetT" , bound = Array , covariant = True )
# The TypeVars below are constrained to exactly NDArray or Tensor (previously they
# were bound to the Array protocol). Both names are also listed in __all__.
# NOTE(review): these calls evaluate `Tensor` at import time; if `Tensor` is only
# bound under TYPE_CHECKING or when torch is installed, importing this module
# without torch would raise NameError -- confirm how `Tensor` is defined.
272+ ArrayT = TypeVar ("ArrayT" , NDArray , Tensor , contravariant = True )  # declared contravariant
273+ ArrayRetT = TypeVar ("ArrayRetT" , NDArray , Tensor , covariant = True )  # declared covariant
338274
339275
276+ @overload  # NumPy overload: NDArray labels in, (NDArray, NDArray) pair out
340277def stratified_split_indices (
341- y : ArrayT , train_size : float | int = 0.8 , random_state : int | None = None
342- ) -> Tuple [ArrayT , ArrayT ]:
278+ y : NDArray , train_size : float | int = 0.8 , random_state : int | None = None
279+ ) -> tuple [NDArray , NDArray ]: ...
280+
281+
282+ @overload  # Torch overload: Tensor labels in, (Tensor, Tensor) pair out
283+ def stratified_split_indices (
284+ y : Tensor , train_size : float | int = 0.8 , random_state : int | None = None
285+ ) -> tuple [Tensor , Tensor ]: ...
286+
287+
288+ def stratified_split_indices (
289+ y : NDArray | Tensor , train_size : float | int = 0.8 , random_state : int | None = None
290+ ) -> tuple [Tensor , Tensor ] | tuple [NDArray , NDArray ]:
343291 """
344292 Compute stratified train/test split indices based on labels.
345293
@@ -371,7 +319,7 @@ def stratified_split_indices(
371319 train_cat = torch .cat (train_indices )
372320 test_cat = torch .cat (test_indices )
373321 return cast (
374- Tuple [ ArrayT , ArrayT ],
322+ tuple [ Tensor , Tensor ],
375323 (
376324 train_cat [torch .randperm (len (train_cat ))],
377325 test_cat [torch .randperm (len (test_cat ))],
@@ -386,7 +334,7 @@ def stratified_split_indices(
386334 indices , train_size = train_size , stratify = y_np , random_state = random_state
387335 )
388336
389- return cast (Tuple [ ArrayT , ArrayT ], (train_indices , test_indices ))
337+ return cast (tuple [ NDArray , NDArray ], (train_indices , test_indices ))
390338
391339
392340@overload
@@ -430,7 +378,7 @@ def check_X_y(
430378 multi_output : bool = False ,
431379 estimator : str | object | None = None ,
432380 copy : bool = False ,
433- ) -> Tuple [NDArray , NDArray ]: ...
381+ ) -> tuple [NDArray , NDArray ]: ...
434382
435383
436384@overload
@@ -441,7 +389,7 @@ def check_X_y(
441389 multi_output : bool = False ,
442390 estimator : str | object | None = None ,
443391 copy : bool = False ,
444- ) -> Tuple [Tensor , Tensor ]: ...
392+ ) -> tuple [Tensor , Tensor ]: ...
445393
446394
447395def check_X_y (
@@ -451,7 +399,7 @@ def check_X_y(
451399 multi_output : bool = False ,
452400 estimator : str | object | None = None ,
453401 copy : bool = False ,
454- ) -> Tuple [NDArray | Tensor , NDArray | Tensor ]:
402+ ) -> tuple [NDArray , NDArray ] | tuple [ Tensor , Tensor ]:
455403 """
456404 Validate X and y mimicking the functionality of sklearn's check_X_y.
457405
@@ -483,7 +431,7 @@ def check_X_y(
483431 ),
484432 )
485433 return cast (
486- Tuple [NDArray , NDArray ],
434+ tuple [NDArray , NDArray ],
487435 sklearn .utils .check_X_y (
488436 X , y , multi_output = multi_output , estimator = estimator , copy = copy
489437 ),
@@ -578,9 +526,7 @@ def check_X_y_torch(
578526 return X , y
579527
580528
581- def array_count_nonzero (
582- x : Array ,
583- ) -> int :
529+ def array_count_nonzero (x : NDArray | Tensor ) -> int :
584530 """
585531 Count the number of non-zero elements in the array.
586532
@@ -594,23 +540,21 @@ def array_count_nonzero(
594540 assert torch is not None
595541 tensor_array = cast (Tensor , x )
596542 return int (torch .count_nonzero (tensor_array ).item ())
597- else : # Fallback to numpy approach
543+ else :
598544 numpy_array = to_numpy (x )
599545 return int (np .count_nonzero (numpy_array ))
600546
601547
602- def array_nonzero (
603- x : ArrayT ,
604- ) -> tuple [NDArray [np .int_ ], ...]:
548+ def array_nonzero (x : NDArray | Tensor ) -> tuple [NDArray [np .int_ ], ...]:
605549 """
606550 Find the indices of non-zero elements.
607551
608552 Args:
609553 x: Input array.
610554
611555 Returns:
612- Tuple of arrays, one for each dimension of x,
613- containing the indices of the non-zero elements in that dimension.
556+ Tuple of arrays, one for each dimension of x, containing the indices of the
557+ non-zero elements in that dimension.
614558 """
615559 if is_tensor (x ):
616560 assert torch is not None
@@ -619,7 +563,7 @@ def array_nonzero(
619563 return cast (tuple [NDArray , ...], np .nonzero (to_numpy (x )))
620564
621565
622- def is_categorical (x : Array [ Any ] ) -> bool :
566+ def is_categorical (x : NDArray | Tensor ) -> bool :
623567 """
624568 Check if an array contains categorical data (suitable for unique labels).
625569
0 commit comments