Skip to content

Commit c62d17a

Browse files
committed
use the new devices.has_mps() function in register_buffer for DDIM/PLMS fix for OSX
1 parent 526f0aa commit c62d17a

File tree

1 file changed

+1
-2
lines changed

1 file changed

+1
-2
lines changed

modules/sd_hijack.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -418,8 +418,7 @@ def register_buffer(self, name, attr):
418418
if type(attr) == torch.Tensor:
419419
if attr.device != devices.device:
420420

421-
# would this not break cuda when torch adds has_mps() to main version?
422-
if getattr(torch, 'has_mps', False):
421+
if devices.has_mps():
423422
attr = attr.to(device="mps", dtype=torch.float32)
424423
else:
425424
attr = attr.to(devices.device)

0 commit comments

Comments
 (0)