Skip to content

Commit d780974

Browse files
committed
Use proper overloads
1 parent 4b8113e commit d780974

File tree

1 file changed

+21
-9
lines changed

1 file changed

+21
-9
lines changed

src/pydvl/valuation/samplers/classwise.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,13 @@
4646
from __future__ import annotations
4747

4848
from itertools import cycle, islice
49-
from typing import Generator, Iterable, Mapping, TypeVar, cast
49+
from typing import TYPE_CHECKING, Generator, Iterable, Mapping, TypeVar, overload
5050

5151
import numpy as np
5252
from more_itertools import chunked, flatten
53+
from numpy.typing import NDArray
5354

54-
from pydvl.utils.array import DT, Array, array_unique, is_categorical
55+
from pydvl.utils.array import array_unique, is_categorical
5556
from pydvl.valuation.dataset import Dataset
5657
from pydvl.valuation.samplers.base import EvaluationStrategy, IndexSampler
5758
from pydvl.valuation.samplers.powerset import NoIndexIteration, PowersetSampler
@@ -67,6 +68,10 @@
6768

6869
__all__ = ["ClasswiseSampler"]
6970

71+
if TYPE_CHECKING:
72+
from torch import Tensor
73+
74+
7075
U = TypeVar("U")
7176
V = TypeVar("V")
7277

@@ -104,7 +109,15 @@ def roundrobin(
104109
remaining_generators = cycle(islice(remaining_generators, n_active))
105110

106111

107-
def get_unique_labels(arr: Array[DT]) -> Array[DT]:
112+
@overload
113+
def get_unique_labels(arr: NDArray) -> NDArray: ...
114+
115+
116+
@overload
117+
def get_unique_labels(arr: Tensor) -> Tensor: ...
118+
119+
120+
def get_unique_labels(arr: NDArray | Tensor) -> NDArray | Tensor:
108121
"""Returns unique labels in a categorical dataset.
109122
110123
Args:
@@ -119,12 +132,11 @@ def get_unique_labels(arr: Array[DT]) -> Array[DT]:
119132
ValueError: If the input array is not of a categorical type.
120133
"""
121134
if is_categorical(arr):
122-
return cast(Array[DT], array_unique(arr))
123-
else:
124-
raise ValueError(
125-
f"Input array has an unsupported data type for categorical labels: {type(arr)}. "
126-
"Expected types: Object, String, Unicode, Unsigned integer, Signed integer, or Boolean."
127-
)
135+
return array_unique(arr)
136+
raise ValueError(
137+
f"Input array has an unsupported data type for categorical labels: {type(arr)}. "
138+
"Expected types: Object, String, Unicode, Unsigned integer, Signed integer, or Boolean."
139+
)
128140

129141

130142
class ClasswiseSampler(IndexSampler[ClasswiseSample, ValueUpdate]):

0 commit comments

Comments
 (0)