File tree Expand file tree Collapse file tree 2 files changed +7
-9
lines changed Expand file tree Collapse file tree 2 files changed +7
-9
lines changed Original file line number Diff line number Diff line change 7373 "try_torch_import" ,
7474]
7575
76- DT = TypeVar ("DT" )
76+ DT = TypeVar ("DT" , bound = np . generic )
7777
7878
7979def try_torch_import (require : bool = False ) -> ModuleType | None :
@@ -138,7 +138,8 @@ class Array(Protocol[DT]):
138138 in NumPy and PyTorch arrays.
139139 """
140140
141- nbytes : int
141+ @property
142+ def nbytes (self ) -> int : ...
142143
143144 @property
144145 def shape (self ) -> tuple [int , ...]: ...
@@ -163,7 +164,7 @@ def __mul__(self, other) -> Array: ...
163164
164165 def __matmul__ (self , other ) -> Array : ...
165166
166- def __array__ (self , dtype : DT | None = None ) -> NDArray : ...
167+ def __array__ (self , dtype : DT | None = None ) -> NDArray [ DT ] : ...
167168
168169 def flatten (self , * args , ** kwargs ) -> Self : ...
169170
Original file line number Diff line number Diff line change 5151import numpy as np
5252from more_itertools import chunked , flatten
5353
54- from pydvl .utils .array import Array , array_unique , is_categorical
54+ from pydvl .utils .array import DT , Array , array_unique , is_categorical
5555from pydvl .valuation .dataset import Dataset
5656from pydvl .valuation .samplers .base import EvaluationStrategy , IndexSampler
5757from pydvl .valuation .samplers .powerset import NoIndexIteration , PowersetSampler
@@ -104,10 +104,7 @@ def roundrobin(
104104 remaining_generators = cycle (islice (remaining_generators , n_active ))
105105
106106
107- T = TypeVar ("T" )
108-
109-
110- def get_unique_labels (arr : Array [T ]) -> Array [T ]:
107+ def get_unique_labels (arr : Array [DT ]) -> Array [DT ]:
111108 """Returns unique labels in a categorical dataset.
112109
113110 Args:
@@ -122,7 +119,7 @@ def get_unique_labels(arr: Array[T]) -> Array[T]:
122119 ValueError: If the input array is not of a categorical type.
123120 """
124121 if is_categorical (arr ):
125- return cast (Array [T ], array_unique (arr ))
122+ return cast (Array [DT ], array_unique (arr ))
126123 else :
127124 raise ValueError (
128125 f"Input array has an unsupported data type for categorical labels: { type (arr )} . "
You can’t perform that action at this time.
0 commit comments