Skip to content

Commit 4571fe4

Browse files
committed
Fix some types
1 parent c141eda commit 4571fe4

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

src/pydvl/utils/array.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ def to_tensor(array: Array | ArrayLike) -> Tensor:
186186
assert torch is not None
187187

188188
if isinstance(array, torch.Tensor):
189-
return array
189+
return cast(Tensor, array)
190190
return cast(Tensor, torch.as_tensor(array))
191191

192192

@@ -473,7 +473,7 @@ def check_X_y(
473473
if is_tensor(X) and is_tensor(y):
474474
assert torch is not None
475475
return cast(
476-
Tuple[Array, Array],
476+
tuple[Tensor, Tensor],
477477
check_X_y_torch(
478478
cast(Tensor, X),
479479
cast(Tensor, y),
@@ -483,7 +483,7 @@ def check_X_y(
483483
),
484484
)
485485
return cast(
486-
Tuple[Array, Array],
486+
Tuple[NDArray, NDArray],
487487
sklearn.utils.check_X_y(
488488
X, y, multi_output=multi_output, estimator=estimator, copy=copy
489489
),

0 commit comments

Comments
 (0)