Skip to content

Commit c141eda

Browse files
committed
Remove redundant require_torch
1 parent 010a982 commit c141eda

File tree

1 file changed

+3
-9
lines changed

1 file changed

+3
-9
lines changed

src/pydvl/utils/array.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,6 @@
6969
"atleast1d",
7070
"check_X_y",
7171
"check_X_y_torch",
72-
"require_torch",
7372
"try_torch_import",
7473
]
7574

@@ -101,12 +100,6 @@ def try_torch_import(require: bool = False) -> ModuleType | None:
101100
Tensor = Any if torch is None else torch.Tensor
102101

103102

104-
def require_torch() -> ModuleType:
105-
torch = try_torch_import(require=True)
106-
assert torch is not None
107-
return torch
108-
109-
110103
def is_tensor(array: Any) -> bool:
111104
"""Check if an array is a PyTorch tensor."""
112105
return torch is not None and isinstance(array, torch.Tensor)
@@ -190,7 +183,8 @@ def to_tensor(array: Array | ArrayLike) -> Tensor:
190183
Raises:
191184
ImportError: If PyTorch is not available.
192185
"""
193-
torch = require_torch()
186+
assert torch is not None
187+
194188
if isinstance(array, torch.Tensor):
195189
return array
196190
return cast(Tensor, torch.as_tensor(array))
@@ -520,7 +514,7 @@ def check_X_y_torch(
520514
Raises:
521515
ValueError or TypeError if the inputs are invalid.
522516
"""
523-
torch = require_torch()
517+
assert torch is not None
524518

525519
estimator_name = (
526520
estimator.__class__.__name__

0 commit comments

Comments
 (0)