4646from __future__ import annotations
4747
4848from itertools import cycle , islice
49- from typing import Generator , Iterable , Mapping , TypeVar , cast
49+ from typing import TYPE_CHECKING , Generator , Iterable , Mapping , TypeVar , overload
5050
5151import numpy as np
5252from more_itertools import chunked , flatten
53+ from numpy .typing import NDArray
5354
54- from pydvl .utils .array import DT , Array , array_unique , is_categorical
55+ from pydvl .utils .array import array_unique , is_categorical
5556from pydvl .valuation .dataset import Dataset
5657from pydvl .valuation .samplers .base import EvaluationStrategy , IndexSampler
5758from pydvl .valuation .samplers .powerset import NoIndexIteration , PowersetSampler
6768
6869__all__ = ["ClasswiseSampler" ]
6970
71+ if TYPE_CHECKING :
72+ from torch import Tensor
73+
74+
7075U = TypeVar ("U" )
7176V = TypeVar ("V" )
7277
@@ -104,7 +109,15 @@ def roundrobin(
104109 remaining_generators = cycle (islice (remaining_generators , n_active ))
105110
106111
107- def get_unique_labels (arr : Array [DT ]) -> Array [DT ]:
112+ @overload
113+ def get_unique_labels (arr : NDArray ) -> NDArray : ...
114+
115+
116+ @overload
117+ def get_unique_labels (arr : Tensor ) -> Tensor : ...
118+
119+
120+ def get_unique_labels (arr : NDArray | Tensor ) -> NDArray | Tensor :
108121 """Returns unique labels in a categorical dataset.
109122
110123 Args:
@@ -119,12 +132,11 @@ def get_unique_labels(arr: Array[DT]) -> Array[DT]:
119132 ValueError: If the input array is not of a categorical type.
120133 """
121134 if is_categorical (arr ):
122- return cast (Array [DT ], array_unique (arr ))
123- else :
124- raise ValueError (
125- f"Input array has an unsupported data type for categorical labels: { type (arr )} . "
126- "Expected types: Object, String, Unicode, Unsigned integer, Signed integer, or Boolean."
127- )
135+ return array_unique (arr )
136+ raise ValueError (
137+ f"Input array has an unsupported data type for categorical labels: { type (arr )} . "
138+ "Expected types: Object, String, Unicode, Unsigned integer, Signed integer, or Boolean."
139+ )
128140
129141
130142class ClasswiseSampler (IndexSampler [ClasswiseSample , ValueUpdate ]):
0 commit comments