Commit 4b8113e

Delete unnecessary Array protocol. Fix a bunch of types
1 parent 94eaaf4 commit 4b8113e

File tree: 3 files changed (+49, -106 lines)

3 files changed

+49
-106
lines changed
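
In short, the structural `Array` protocol is replaced by explicit `NDArray | Tensor` unions, `@overload` signatures, and TypeVars constrained to the two concrete array types. A minimal, self-contained sketch of that pattern (the `scale` helper is hypothetical and assumes both NumPy and torch are installed; it is not part of the diff):

from typing import overload

import numpy as np
from numpy.typing import NDArray
from torch import Tensor


@overload
def scale(x: NDArray, factor: float) -> NDArray: ...
@overload
def scale(x: Tensor, factor: float) -> Tensor: ...


def scale(x: NDArray | Tensor, factor: float) -> NDArray | Tensor:
    # One runtime implementation; the overloads tell the type checker
    # "numpy in, numpy out" and "tensor in, tensor out" without needing
    # a runtime-checkable structural protocol.
    return x * factor


print(scale(np.arange(3), 2.0))  # [0. 2. 4.]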

src/pydvl/utils/array.py

Lines changed: 40 additions & 96 deletions

@@ -37,23 +37,17 @@
 from typing import (
     TYPE_CHECKING,
     Any,
-    Iterator,
-    List,
     Literal,
-    Protocol,
     Sequence,
-    Tuple,
     TypeVar,
     Union,
     cast,
     overload,
-    runtime_checkable,
 )

 import numpy as np
 import sklearn.utils
 from numpy.typing import ArrayLike, NDArray
-from typing_extensions import Self

 __all__ = [
     "is_tensor",
@@ -70,6 +64,8 @@
     "check_X_y",
     "check_X_y_torch",
     "try_torch_import",
+    "ArrayT",
+    "ArrayRetT",
 ]

 DT = TypeVar("DT", bound=np.generic)
@@ -110,67 +106,7 @@ def is_numpy(array: Any) -> bool:
     return isinstance(array, np.ndarray)


-@runtime_checkable
-class Array(Protocol[DT]):
-    """Protocol defining a common interface for NumPy arrays and PyTorch tensors.
-
-    This protocol defines the essential methods and properties required for array-like
-    operations in PyDVL. It serves as a structural type for both numpy.ndarray
-    and torch.Tensor, enabling type-safe generic functions that work with either type.
-
-    The generic parameter DT represents the data type of the array elements.
-
-    !!! note "Type Preservation"
-        Functions that accept Array types will generally preserve the input type
-        in their outputs. For example, if you pass a torch.Tensor, you'll get a
-        torch.Tensor back; if you pass a numpy.ndarray, you'll get a numpy.ndarray back.
-
-    !!! warning
-        This is a "best-effort" implementation that covers the methods and properties
-        needed by PyDVL, but it is not a complete representation of all functionality
-        in NumPy and PyTorch arrays.
-    """
-
-    @property
-    def nbytes(self) -> int: ...
-
-    @property
-    def shape(self) -> tuple[int, ...]: ...
-
-    @property
-    def ndim(self) -> int: ...
-
-    @property
-    def dtype(self) -> Any: ...
-
-    def __len__(self) -> int: ...
-
-    def __getitem__(self, key: Any) -> Self: ...
-
-    def __iter__(self) -> Iterator: ...
-
-    def __add__(self, other: Array) -> Array: ...
-
-    def __sub__(self, other) -> Array: ...
-
-    def __mul__(self, other) -> Array: ...
-
-    def __matmul__(self, other) -> Array: ...
-
-    def __array__(self, dtype: DT | None = None) -> NDArray[DT]: ...
-
-    def flatten(self, *args, **kwargs) -> Self: ...
-
-    def reshape(self, *args: Any, **kwargs: Any) -> Self: ...
-
-    def tolist(self) -> list: ...
-
-    def item(self) -> DT: ...
-
-    def sum(self, *args: Any, **kwargs: Any) -> Self: ...
-
-
-def to_tensor(array: Array | ArrayLike) -> Tensor:
+def to_tensor(array: NDArray | Tensor | ArrayLike) -> Tensor:
     """
     Convert array to torch.Tensor if it's not already.

@@ -190,7 +126,7 @@ def to_tensor(array: Array | ArrayLike) -> Tensor:
     return cast(Tensor, torch.as_tensor(array))


-def to_numpy(array: Array | ArrayLike) -> NDArray:
+def to_numpy(array: NDArray | Tensor | ArrayLike) -> NDArray:
     """
     Convert array to a numpy.ndarray if it's not already.

@@ -207,7 +143,7 @@ def to_numpy(array: Array | ArrayLike) -> NDArray:
     return cast(NDArray, np.asarray(array))


-ShapeType = Union[int, Tuple[int, ...], List[int]]
+ShapeType = Union[int, tuple[int, ...], list[int]]


 @overload
@@ -219,7 +155,7 @@ def array_unique(
 @overload
 def array_unique(
     array: NDArray, return_index: Literal[True], **kwargs: Any
-) -> Tuple[NDArray, NDArray]: ...
+) -> tuple[NDArray, NDArray]: ...


 @overload
@@ -231,12 +167,12 @@ def array_unique(
 @overload
 def array_unique(
     array: Tensor, return_index: Literal[True], **kwargs: Any
-) -> Tuple[Tensor, NDArray]: ...
+) -> tuple[Tensor, NDArray]: ...


 def array_unique(
     array: NDArray | Tensor, return_index: bool = False, **kwargs: Any
-) -> Union[NDArray | Tensor, Tuple[NDArray | Tensor, NDArray]]:
+) -> NDArray | tuple[NDArray, NDArray] | Tensor | tuple[Tensor, NDArray]:
     """
     Return the unique elements in an array, optionally with indices of their first
     occurrences.
@@ -265,9 +201,9 @@ def array_unique(
             indices_tensor = torch.tensor(
                 indices, dtype=torch.long, device=tensor_array.device
             )
-            return cast(Tuple[Tensor, NDArray], (result, indices_tensor.cpu().numpy()))
+            return cast(tuple[Tensor, NDArray], (result, indices_tensor.cpu().numpy()))
         return cast(Tensor, result)
-    else:  # Fallback to numpy approach.
+    else:
         numpy_array = to_numpy(array)
         if return_index:
             # np.unique returns a tuple when return_index=True
@@ -276,7 +212,7 @@ def array_unique(
                 return_index=True,
                 **{k: v for k, v in kwargs.items() if k != "return_index"},
             )
-            return cast(Tuple[NDArray, NDArray], (unique_vals, indices))
+            return cast(tuple[NDArray, NDArray], (unique_vals, indices))
         else:
             # Simple case - just unique values
             result = np.unique(
@@ -296,7 +232,7 @@ def array_concatenate(arrays: Sequence[Tensor], axis: int = 0) -> Tensor: ...


 def array_concatenate(
-    arrays: Sequence[NDArray | Tensor], axis: int = 0
+    arrays: Sequence[NDArray] | Sequence[Tensor], axis: int = 0
 ) -> NDArray | Tensor:
     """
     Join a sequence of arrays along an existing axis.
@@ -333,13 +269,25 @@ def array_concatenate(
     return cast(NDArray, np.concatenate(numpy_arrays, axis=axis))


-ArrayT = TypeVar("ArrayT", bound=Array, contravariant=True)
-ArrayRetT = TypeVar("ArrayRetT", bound=Array, covariant=True)
+ArrayT = TypeVar("ArrayT", NDArray, Tensor, contravariant=True)
+ArrayRetT = TypeVar("ArrayRetT", NDArray, Tensor, covariant=True)


+@overload
 def stratified_split_indices(
-    y: ArrayT, train_size: float | int = 0.8, random_state: int | None = None
-) -> Tuple[ArrayT, ArrayT]:
+    y: NDArray, train_size: float | int = 0.8, random_state: int | None = None
+) -> tuple[NDArray, NDArray]: ...
+
+
+@overload
+def stratified_split_indices(
+    y: Tensor, train_size: float | int = 0.8, random_state: int | None = None
+) -> tuple[Tensor, Tensor]: ...
+
+
+def stratified_split_indices(
+    y: NDArray | Tensor, train_size: float | int = 0.8, random_state: int | None = None
+) -> tuple[Tensor, Tensor] | tuple[NDArray, NDArray]:
     """
     Compute stratified train/test split indices based on labels.

@@ -371,7 +319,7 @@ def stratified_split_indices(
         train_cat = torch.cat(train_indices)
         test_cat = torch.cat(test_indices)
         return cast(
-            Tuple[ArrayT, ArrayT],
+            tuple[Tensor, Tensor],
             (
                 train_cat[torch.randperm(len(train_cat))],
                 test_cat[torch.randperm(len(test_cat))],
@@ -386,7 +334,7 @@ def stratified_split_indices(
         indices, train_size=train_size, stratify=y_np, random_state=random_state
     )

-    return cast(Tuple[ArrayT, ArrayT], (train_indices, test_indices))
+    return cast(tuple[NDArray, NDArray], (train_indices, test_indices))


 @overload
@@ -430,7 +378,7 @@ def check_X_y(
     multi_output: bool = False,
     estimator: str | object | None = None,
     copy: bool = False,
-) -> Tuple[NDArray, NDArray]: ...
+) -> tuple[NDArray, NDArray]: ...


 @overload
@@ -441,7 +389,7 @@ def check_X_y(
     multi_output: bool = False,
     estimator: str | object | None = None,
     copy: bool = False,
-) -> Tuple[Tensor, Tensor]: ...
+) -> tuple[Tensor, Tensor]: ...


 def check_X_y(
@@ -451,7 +399,7 @@
     multi_output: bool = False,
     estimator: str | object | None = None,
     copy: bool = False,
-) -> Tuple[NDArray | Tensor, NDArray | Tensor]:
+) -> tuple[NDArray, NDArray] | tuple[Tensor, Tensor]:
     """
     Validate X and y mimicking the functionality of sklearn's check_X_y.

@@ -483,7 +431,7 @@ def check_X_y(
             ),
         )
     return cast(
-        Tuple[NDArray, NDArray],
+        tuple[NDArray, NDArray],
         sklearn.utils.check_X_y(
             X, y, multi_output=multi_output, estimator=estimator, copy=copy
         ),
@@ -578,9 +526,7 @@ def check_X_y_torch(
     return X, y


-def array_count_nonzero(
-    x: Array,
-) -> int:
+def array_count_nonzero(x: NDArray | Tensor) -> int:
     """
     Count the number of non-zero elements in the array.

@@ -594,23 +540,21 @@ def array_count_nonzero(
         assert torch is not None
         tensor_array = cast(Tensor, x)
         return int(torch.count_nonzero(tensor_array).item())
-    else:  # Fallback to numpy approach
+    else:
         numpy_array = to_numpy(x)
         return int(np.count_nonzero(numpy_array))


-def array_nonzero(
-    x: ArrayT,
-) -> tuple[NDArray[np.int_], ...]:
+def array_nonzero(x: NDArray | Tensor) -> tuple[NDArray[np.int_], ...]:
     """
     Find the indices of non-zero elements.

     Args:
         x: Input array.

     Returns:
-        Tuple of arrays, one for each dimension of x,
-        containing the indices of the non-zero elements in that dimension.
+        Tuple of arrays, one for each dimension of x, containing the indices of the
+        non-zero elements in that dimension.
     """
     if is_tensor(x):
         assert torch is not None
@@ -619,7 +563,7 @@ def array_nonzero(
     return cast(tuple[NDArray, ...], np.nonzero(to_numpy(x)))


-def is_categorical(x: Array[Any]) -> bool:
+def is_categorical(x: NDArray | Tensor) -> bool:
     """
     Check if an array contains categorical data (suitable for unique labels).
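
As a quick illustration of how the new overloads above behave for a caller (illustrative usage only, assuming pyDVL with this change plus NumPy are installed; the exact split is delegated to scikit-learn's train_test_split):

import numpy as np
from pydvl.utils.array import stratified_split_indices

y = np.array([0, 0, 0, 0, 1, 1, 1, 1])
train_idx, test_idx = stratified_split_indices(y, train_size=0.5, random_state=0)
# The NDArray overload is selected, so a type checker infers
# tuple[NDArray, NDArray]; passing a torch.Tensor would select the
# tuple[Tensor, Tensor] overload instead.
print(sorted(np.concatenate([train_idx, test_idx]).tolist()))  # all indices 0..7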

src/pydvl/valuation/result.py

Lines changed: 6 additions & 6 deletions

@@ -104,7 +104,7 @@
 from typing_extensions import Self

 from pydvl.utils import log_running_moments
-from pydvl.utils.array import Array, is_tensor, to_numpy
+from pydvl.utils.array import is_tensor, to_numpy
 from pydvl.utils.status import Status
 from pydvl.utils.types import Seed
 from pydvl.valuation.dataset import Dataset
@@ -284,11 +284,11 @@ class ValuationResult(collections.abc.Sequence, Iterable[ValueItem]):
     def __init__(
         self,
         *,
-        values: Sequence[np.float64] | NDArray[np.float64] | Array,
-        variances: Sequence[np.float64] | NDArray[np.float64] | Array | None = None,
-        counts: Sequence[np.int_] | NDArray[np.int_] | Array | None = None,
-        indices: Sequence[IndexT] | NDArray[IndexT] | Array | None = None,
-        data_names: Sequence[NameT] | NDArray[NameT] | Array | None = None,
+        values: Sequence[np.float64] | NDArray[np.float64],
+        variances: Sequence[np.float64] | NDArray[np.float64] | None = None,
+        counts: Sequence[np.int_] | NDArray[np.int_] | None = None,
+        indices: Sequence[IndexT] | NDArray[IndexT] | None = None,
+        data_names: Sequence[NameT] | NDArray[NameT] | None = None,
         algorithm: str = "",
         status: Status = Status.Pending,
         sort: bool | None = None,
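
With the `Array` union gone from `ValuationResult.__init__`, values and friends are plain sequences or NumPy arrays; tensors would be converted up front (e.g. via `to_numpy`). A hedged usage sketch, assuming the keyword arguments behave as their names and defaults in the signature suggest:

import numpy as np
from pydvl.valuation.result import ValuationResult

result = ValuationResult(
    values=np.array([0.1, 0.3, 0.2], dtype=np.float64),
    variances=np.zeros(3, dtype=np.float64),
    counts=np.ones(3, dtype=np.int_),
    algorithm="example",
)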

src/pydvl/valuation/types.py

Lines changed: 3 additions & 4 deletions

@@ -30,7 +30,6 @@
 from typing_extensions import Self, TypeAlias

 from pydvl.utils.array import (
-    Array,
     ArrayRetT,
     ArrayT,
     array_concatenate,
@@ -171,7 +170,7 @@ def with_idx(self, idx: IndexT) -> Self:

         return replace(self, idx=idx)

-    def with_subset(self, subset: Array[IndexT]) -> Self:
+    def with_subset(self, subset: NDArray[IndexT]) -> Self:
         """Return a copy of sample with the subset changed.

         Args:
@@ -257,8 +256,8 @@ def __iter__(self): # No way to type the return Iterator properly
         return iter((self.idx, self.subset, self.evaluation))


-class LossFunction(Protocol):
-    def __call__(self, y_true: Array, y_pred: Array) -> Array: ...
+class LossFunction(Protocol[ArrayT, ArrayRetT]):
+    def __call__(self, y_true: ArrayT, y_pred: ArrayT) -> ArrayRetT: ...


 class SemivalueCoefficient(Protocol):
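
Because `LossFunction` is now generic in `ArrayT`/`ArrayRetT`, a concrete loss only has to use one of the constrained array types consistently. A minimal NumPy-only sketch (the `squared_error` function is hypothetical, not part of pyDVL):

import numpy as np
from numpy.typing import NDArray


def squared_error(y_true: NDArray, y_pred: NDArray) -> NDArray:
    # Structurally satisfies LossFunction[NDArray, NDArray]: arguments and
    # return value all resolve the constrained TypeVars to NDArray.
    return (y_true - y_pred) ** 2


print(squared_error(np.array([1.0, 2.0]), np.array([1.5, 1.0])))  # [0.25 1.  ]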
