Skip to content

Commit 57a5e56

Browse files
author
Beat Buesser
committed
Update format and typing
Signed-off-by: Beat Buesser <[email protected]>
1 parent 1cd67ca commit 57a5e56

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

art/utils.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,12 @@
3636
from scipy.special import gammainc
3737
import six
3838
from tqdm.auto import tqdm
39-
from torch import Tensor
4039

4140
from art import config
4241

42+
if TYPE_CHECKING:
43+
import torch
44+
4345
logger = logging.getLogger(__name__)
4446

4547

@@ -1241,28 +1243,30 @@ def pad_sequence_input(x: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
12411243
# -------------------------------------------------------------------------------------------------------- CUDA SUPPORT
12421244

12431245

1244-
def to_cuda(x: Tensor) -> Tensor:
1246+
def to_cuda(x: "torch.Tensor") -> "torch.Tensor":
12451247
"""
12461248
Move the tensor from the CPU to the GPU if a GPU is available.
12471249
12481250
:param x: CPU Tensor to move to GPU if available.
12491251
:return: The CPU Tensor moved to a GPU Tensor.
12501252
"""
12511253
from torch.cuda import is_available
1254+
12521255
use_cuda = is_available()
12531256
if use_cuda:
12541257
x = x.cuda()
12551258
return x
12561259

12571260

1258-
def from_cuda(x: Tensor) -> Tensor:
1261+
def from_cuda(x: "torch.Tensor") -> "torch.Tensor":
12591262
"""
12601263
Move the tensor from the GPU to the CPU if a GPU is available.
12611264
12621265
:param x: GPU Tensor to move to CPU if available.
12631266
:return: The GPU Tensor moved to a CPU Tensor.
12641267
"""
12651268
from torch.cuda import is_available
1269+
12661270
use_cuda = is_available()
12671271
if use_cuda:
12681272
x = x.cpu()

0 commit comments

Comments
 (0)