Skip to content

Commit f657b6b

Browse files
committed
up
1 parent 8448bdf commit f657b6b

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,6 @@
6767
numpy_to_pil,
6868
)
6969
from ..utils.hub_utils import _check_legacy_sharding_variant_format, load_or_create_model_card, populate_model_card
70-
from ..utils.testing_utils import torch_device
7170
from ..utils.torch_utils import empty_device_cache, get_device, is_compiled_module
7271

7372

@@ -109,8 +108,7 @@
109108
for library in LOADABLE_CLASSES:
110109
LIBRARIES.append(library)
111110

112-
# TODO: support single-device namings
113-
SUPPORTED_DEVICE_MAP = ["balanced"] + [torch_device]
111+
SUPPORTED_DEVICE_MAP = ["balanced"] + [get_device()]
114112

115113
logger = logging.get_logger(__name__)
116114

src/diffusers/utils/torch_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
PyTorch utilities: Utilities related to PyTorch
1616
"""
1717

18+
import functools
1819
from typing import List, Optional, Tuple, Union
1920

2021
from . import logging
@@ -168,6 +169,7 @@ def get_torch_cuda_device_capability():
168169
return None
169170

170171

172+
@functools.lru_cache
171173
def get_device():
172174
if torch.cuda.is_available():
173175
return "cuda"

0 commit comments

Comments
 (0)