Skip to content

Commit 98ca437

Browse files
committed
Refactor and instead check if mps is being used, not availability
1 parent 0b5dcb3 commit 98ca437

File tree

1 file changed

+1
-5
lines changed

1 file changed

+1
-5
lines changed

modules/sd_hijack.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -182,11 +182,7 @@ def register_buffer(self, name, attr):
182182

183183
if type(attr) == torch.Tensor:
184184
if attr.device != devices.device:
185-
186-
if devices.has_mps():
187-
attr = attr.to(device="mps", dtype=torch.float32)
188-
else:
189-
attr = attr.to(devices.device)
185+
attr = attr.to(device=devices.device, dtype=(torch.float32 if devices.device.type == 'mps' else None))
190186

191187
setattr(self, name, attr)
192188

0 commit comments

Comments
 (0)