diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py index 61a5d95b6926..c1fc1242bb10 100644 --- a/src/diffusers/utils/torch_utils.py +++ b/src/diffusers/utils/torch_utils.py @@ -175,6 +175,8 @@ def get_device(): return "npu" elif hasattr(torch, "xpu") and torch.xpu.is_available(): return "xpu" + elif torch.backends.mps.is_available(): + return "mps" else: return "cpu"