File tree Expand file tree Collapse file tree 1 file changed +7
-3
lines changed Expand file tree Collapse file tree 1 file changed +7
-3
lines changed Original file line number Diff line number Diff line change 3636from scipy .special import gammainc
3737import six
3838from tqdm .auto import tqdm
39- from torch import Tensor
4039
4140from art import config
4241
42+ if TYPE_CHECKING :
43+ import torch
44+
4345logger = logging .getLogger (__name__ )
4446
4547
@@ -1241,28 +1243,30 @@ def pad_sequence_input(x: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
12411243# -------------------------------------------------------------------------------------------------------- CUDA SUPPORT
12421244
12431245
1244- def to_cuda (x : Tensor ) -> Tensor :
1246+ def to_cuda (x : "torch. Tensor" ) -> "torch. Tensor" :
12451247 """
12461248 Move the tensor from the CPU to the GPU if a GPU is available.
12471249
12481250 :param x: CPU Tensor to move to GPU if available.
12491251 :return: The CPU Tensor moved to a GPU Tensor.
12501252 """
12511253 from torch .cuda import is_available
1254+
12521255 use_cuda = is_available ()
12531256 if use_cuda :
12541257 x = x .cuda ()
12551258 return x
12561259
12571260
1258- def from_cuda (x : Tensor ) -> Tensor :
1261+ def from_cuda (x : "torch. Tensor" ) -> "torch. Tensor" :
12591262 """
12601263 Move the tensor from the GPU to the CPU if a GPU is available.
12611264
12621265 :param x: GPU Tensor to move to CPU if available.
12631266 :return: The GPU Tensor moved to a CPU Tensor.
12641267 """
12651268 from torch .cuda import is_available
1269+
12661270 use_cuda = is_available ()
12671271 if use_cuda :
12681272 x = x .cpu ()
You can’t perform that action at this time.
0 commit comments