diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py index 053a3d99b9f9..bb5674092d09 100644 --- a/src/diffusers/utils/torch_utils.py +++ b/src/diffusers/utils/torch_utils.py @@ -18,7 +18,7 @@ from typing import List, Optional, Tuple, Union from . import logging -from .import_utils import is_torch_available, is_torch_version +from .import_utils import is_torch_available, is_torch_npu_available, is_torch_version if is_torch_available(): @@ -166,6 +166,8 @@ def get_torch_cuda_device_capability(): def get_device(): if torch.cuda.is_available(): return "cuda" + elif is_torch_npu_available(): + return "npu" elif hasattr(torch, "xpu") and torch.xpu.is_available(): return "xpu" else: