Skip to content

Commit fa74c06

Browse files
committed
Further stroking of mypy's ego
1 parent 9de2f2b commit fa74c06

File tree

2 files changed

+18
-15
lines changed

2 files changed

+18
-15
lines changed

src/pydvl/valuation/dataset.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@
102102
from numbers import Integral
103103
from pathlib import Path
104104
from tempfile import mkdtemp
105-
from typing import Any, Generic, Sequence, cast, overload
105+
from typing import TYPE_CHECKING, Any, Generic, Sequence, cast, overload
106106

107107
import numpy as np
108108
from deprecate import deprecated
@@ -113,7 +113,6 @@
113113
__all__ = ["Dataset", "GroupedDataset", "RawData"]
114114

115115
from pydvl.utils.array import (
116-
Array,
117116
ArrayT,
118117
atleast1d,
119118
check_X_y,
@@ -126,9 +125,10 @@
126125

127126
logger = logging.getLogger(__name__)
128127

129-
# Import torch if available for typing
130-
torch = try_torch_import()
131-
Tensor = None if torch is None else torch.Tensor
128+
if TYPE_CHECKING:
129+
import torch
130+
else:
131+
torch = try_torch_import()
132132

133133

134134
@dataclass(frozen=True)
@@ -302,8 +302,10 @@ def __init__(
302302
"Make sure that the data has the proper shape before "
303303
"constructing a Dataset"
304304
)
305-
_x, _y = x, y
305+
_x: NDArray | np.memmap = x
306+
_y: NDArray | np.memmap = y
306307
else:
308+
assert isinstance(x, np.ndarray) and isinstance(y, np.ndarray)
307309
_x, _y = check_X_y(x, y, multi_output=multi_output, estimator="Dataset")
308310

309311
self._x = cast(ArrayT, _maybe_create_memmap(_x))
@@ -314,10 +316,10 @@ def __init__(
314316
)
315317

316318
# These are for __setstate__
317-
self._x_dtype, self._y_dtype = self._x.dtype, self._y.dtype
318-
self._x_shape, self._y_shape = self._x.shape, self._y.shape
319+
self._x_dtype, self._y_dtype = self._x.dtype, self._y.dtype # type: ignore
320+
self._x_shape, self._y_shape = self._x.shape, self._y.shape # type: ignore
319321

320-
def make_names(s: str, a: Array) -> NDArray[np.str_]:
322+
def make_names(s: str, a: ArrayT) -> NDArray[np.str_]:
321323
n = a.shape[1] if len(a.shape) > 1 else 1
322324
return np.array(
323325
[f"{s}{i:0{1 + int(math.log10(n))}d}" for i in range(1, n + 1)],
@@ -480,7 +482,7 @@ def target(self, name: str) -> tuple[slice, int] | slice:
480482
ValueError: If the target name is not found.
481483
"""
482484
try:
483-
target_idx = np.where(self.target_names == name)[0][0]
485+
target_idx = np.where(self.target_names == name)[0][0].item()
484486
if self.n_targets == 1:
485487
return slice(None)
486488
else:
@@ -742,8 +744,8 @@ def __init__(
742744
Added support for PyTorch tensors.
743745
"""
744746
super().__init__(
745-
x=x,
746-
y=y,
747+
x=x, # type: ignore
748+
y=y, # type: ignore
747749
feature_names=feature_names,
748750
target_names=target_names,
749751
data_names=data_names,
@@ -788,7 +790,7 @@ def __len__(self) -> int:
788790

789791
def __getitem__(
790792
self, idx: int | slice | Sequence[int] | NDArray[np.int_] | None = None
791-
) -> GroupedDataset:
793+
) -> GroupedDataset[ArrayT]:
792794
if idx is None:
793795
idx = slice(None)
794796
elif isinstance(idx, int):
@@ -818,7 +820,7 @@ def names(self) -> NDArray[np.str_]:
818820

819821
def data(
820822
self, indices: int | slice | Sequence[int] | NDArray[np.int_] | None = None
821-
) -> RawData:
823+
) -> RawData[ArrayT]:
822824
"""Returns the data and labels of all samples in the given groups.
823825
824826
Args:

src/pydvl/valuation/scorers/classwise.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,8 @@ def __init__(
9999
disc_score_out_of_class = out_of_class_discount_fn(range[1])
100100
transformed_range = (0, disc_score_in_class * disc_score_out_of_class)
101101
super().__init__(
102-
scoring=scoring,
102+
# FIXME: no idea why this makes mypy unhappy
103+
scoring=scoring, # type: ignore[arg-type]
103104
test_data=test_data,
104105
range=transformed_range,
105106
default=default,

0 commit comments

Comments
 (0)