File tree Expand file tree Collapse file tree 1 file changed +3
-9
lines changed Expand file tree Collapse file tree 1 file changed +3
-9
lines changed Original file line number Diff line number Diff line change 6969 "atleast1d" ,
7070 "check_X_y" ,
7171 "check_X_y_torch" ,
72- "require_torch" ,
7372 "try_torch_import" ,
7473]
7574
@@ -101,12 +100,6 @@ def try_torch_import(require: bool = False) -> ModuleType | None:
101100 Tensor = Any if torch is None else torch .Tensor
102101
103102
104- def require_torch () -> ModuleType :
105- torch = try_torch_import (require = True )
106- assert torch is not None
107- return torch
108-
109-
110103def is_tensor (array : Any ) -> bool :
111104 """Check if an array is a PyTorch tensor."""
112105 return torch is not None and isinstance (array , torch .Tensor )
@@ -190,7 +183,8 @@ def to_tensor(array: Array | ArrayLike) -> Tensor:
190183 Raises:
191184 ImportError: If PyTorch is not available.
192185 """
193- torch = require_torch ()
186+ assert torch is not None
187+
194188 if isinstance (array , torch .Tensor ):
195189 return array
196190 return cast (Tensor , torch .as_tensor (array ))
@@ -520,7 +514,7 @@ def check_X_y_torch(
520514 Raises:
521515 ValueError or TypeError if the inputs are invalid.
522516 """
523- torch = require_torch ()
517+ assert torch is not None
524518
525519 estimator_name = (
526520 estimator .__class__ .__name__
You can’t perform that action at this time.
0 commit comments