102102from numbers import Integral
103103from pathlib import Path
104104from tempfile import mkdtemp
105- from typing import Any , Generic , Sequence , cast , overload
105+ from typing import TYPE_CHECKING , Any , Generic , Sequence , cast , overload
106106
107107import numpy as np
108108from deprecate import deprecated
113113__all__ = ["Dataset" , "GroupedDataset" , "RawData" ]
114114
115115from pydvl .utils .array import (
116- Array ,
117116 ArrayT ,
118117 atleast1d ,
119118 check_X_y ,
126125
127126logger = logging .getLogger (__name__ )
128127
129- # Import torch if available for typing
130- torch = try_torch_import ()
131- Tensor = None if torch is None else torch .Tensor
128+ if TYPE_CHECKING :
129+ import torch
130+ else :
131+ torch = try_torch_import ()
132132
133133
134134@dataclass (frozen = True )
@@ -302,8 +302,10 @@ def __init__(
302302 "Make sure that the data has the proper shape before "
303303 "constructing a Dataset"
304304 )
305- _x , _y = x , y
305+ _x : NDArray | np .memmap = x
306+ _y : NDArray | np .memmap = y
306307 else :
308+ assert isinstance (x , np .ndarray ) and isinstance (y , np .ndarray )
307309 _x , _y = check_X_y (x , y , multi_output = multi_output , estimator = "Dataset" )
308310
309311 self ._x = cast (ArrayT , _maybe_create_memmap (_x ))
@@ -314,10 +316,10 @@ def __init__(
314316 )
315317
316318 # These are for __setstate__
317- self ._x_dtype , self ._y_dtype = self ._x .dtype , self ._y .dtype
318- self ._x_shape , self ._y_shape = self ._x .shape , self ._y .shape
319+ self ._x_dtype , self ._y_dtype = self ._x .dtype , self ._y .dtype # type: ignore
320+ self ._x_shape , self ._y_shape = self ._x .shape , self ._y .shape # type: ignore
319321
320- def make_names (s : str , a : Array ) -> NDArray [np .str_ ]:
322+ def make_names (s : str , a : ArrayT ) -> NDArray [np .str_ ]:
321323 n = a .shape [1 ] if len (a .shape ) > 1 else 1
322324 return np .array (
323325 [f"{ s } { i :0{1 + int (math .log10 (n ))}d} " for i in range (1 , n + 1 )],
@@ -480,7 +482,7 @@ def target(self, name: str) -> tuple[slice, int] | slice:
480482 ValueError: If the target name is not found.
481483 """
482484 try :
483- target_idx = np .where (self .target_names == name )[0 ][0 ]
485+ target_idx = np .where (self .target_names == name )[0 ][0 ]. item ()
484486 if self .n_targets == 1 :
485487 return slice (None )
486488 else :
@@ -742,8 +744,8 @@ def __init__(
742744 Added support for PyTorch tensors.
743745 """
744746 super ().__init__ (
745- x = x ,
746- y = y ,
747+ x = x , # type: ignore
748+ y = y , # type: ignore
747749 feature_names = feature_names ,
748750 target_names = target_names ,
749751 data_names = data_names ,
@@ -788,7 +790,7 @@ def __len__(self) -> int:
788790
789791 def __getitem__ (
790792 self , idx : int | slice | Sequence [int ] | NDArray [np .int_ ] | None = None
791- ) -> GroupedDataset :
793+ ) -> GroupedDataset [ ArrayT ] :
792794 if idx is None :
793795 idx = slice (None )
794796 elif isinstance (idx , int ):
@@ -818,7 +820,7 @@ def names(self) -> NDArray[np.str_]:
818820
819821 def data (
820822 self , indices : int | slice | Sequence [int ] | NDArray [np .int_ ] | None = None
821- ) -> RawData :
823+ ) -> RawData [ ArrayT ] :
822824 """Returns the data and labels of all samples in the given groups.
823825
824826 Args:
0 commit comments