Skip to content

Commit 5d0b3a1

Browse files
committed
Fix array protocol and some types
1 parent ae18a7a commit 5d0b3a1

File tree

2 files changed

+7
-9
lines changed

2 files changed

+7
-9
lines changed

src/pydvl/utils/array.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@
7373
"try_torch_import",
7474
]
7575

76-
DT = TypeVar("DT")
76+
DT = TypeVar("DT", bound=np.generic)
7777

7878

7979
def try_torch_import(require: bool = False) -> ModuleType | None:
@@ -138,7 +138,8 @@ class Array(Protocol[DT]):
138138
in NumPy and PyTorch arrays.
139139
"""
140140

141-
nbytes: int
141+
@property
142+
def nbytes(self) -> int: ...
142143

143144
@property
144145
def shape(self) -> tuple[int, ...]: ...
@@ -163,7 +164,7 @@ def __mul__(self, other) -> Array: ...
163164

164165
def __matmul__(self, other) -> Array: ...
165166

166-
def __array__(self, dtype: DT | None = None) -> NDArray: ...
167+
def __array__(self, dtype: DT | None = None) -> NDArray[DT]: ...
167168

168169
def flatten(self, *args, **kwargs) -> Self: ...
169170

src/pydvl/valuation/samplers/classwise.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
import numpy as np
5252
from more_itertools import chunked, flatten
5353

54-
from pydvl.utils.array import Array, array_unique, is_categorical
54+
from pydvl.utils.array import DT, Array, array_unique, is_categorical
5555
from pydvl.valuation.dataset import Dataset
5656
from pydvl.valuation.samplers.base import EvaluationStrategy, IndexSampler
5757
from pydvl.valuation.samplers.powerset import NoIndexIteration, PowersetSampler
@@ -104,10 +104,7 @@ def roundrobin(
104104
remaining_generators = cycle(islice(remaining_generators, n_active))
105105

106106

107-
T = TypeVar("T")
108-
109-
110-
def get_unique_labels(arr: Array[T]) -> Array[T]:
107+
def get_unique_labels(arr: Array[DT]) -> Array[DT]:
111108
"""Returns unique labels in a categorical dataset.
112109
113110
Args:
@@ -122,7 +119,7 @@ def get_unique_labels(arr: Array[T]) -> Array[T]:
122119
ValueError: If the input array is not of a categorical type.
123120
"""
124121
if is_categorical(arr):
125-
return cast(Array[T], array_unique(arr))
122+
return cast(Array[DT], array_unique(arr))
126123
else:
127124
raise ValueError(
128125
f"Input array has an unsupported data type for categorical labels: {type(arr)}. "

0 commit comments

Comments
 (0)