Skip to content

Commit faed465

Browse files
brkirchAUTOMATIC1111
authored andcommitted
MPS Upscalers Fix
Get ESRGAN, SCUNet, and SwinIR working correctly on MPS by ensuring memory is contiguous for tensor views before sending to MPS device.
1 parent 4c24347 commit faed465

File tree

4 files changed

+7
-4
lines changed

4 files changed

+7
-4
lines changed

modules/devices.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,3 +81,7 @@ def autocast(disable=False):
8181
return contextlib.nullcontext()
8282

8383
return torch.autocast("cuda")
84+
85+
# MPS workaround for https://github.com/pytorch/pytorch/issues/79383
86+
def mps_contiguous(input_tensor, device): return input_tensor.contiguous() if device.type == 'mps' else input_tensor
87+
def mps_contiguous_to(input_tensor, device): return mps_contiguous(input_tensor, device).to(device)

modules/esrgan_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ def upscale_without_tiling(model, img):
190190
img = img[:, :, ::-1]
191191
img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
192192
img = torch.from_numpy(img).float()
193-
img = img.unsqueeze(0).to(devices.device_esrgan)
193+
img = devices.mps_contiguous_to(img.unsqueeze(0), devices.device_esrgan)
194194
with torch.no_grad():
195195
output = model(img)
196196
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()

modules/scunet_model.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,8 @@ def do_upscale(self, img: PIL.Image, selected_file):
5454
img = img[:, :, ::-1]
5555
img = np.moveaxis(img, 2, 0) / 255
5656
img = torch.from_numpy(img).float()
57-
img = img.unsqueeze(0).to(device)
57+
img = devices.mps_contiguous_to(img.unsqueeze(0), device)
5858

59-
img = img.to(device)
6059
with torch.no_grad():
6160
output = model(img)
6261
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()

modules/swinir_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def upscale(
111111
img = img[:, :, ::-1]
112112
img = np.moveaxis(img, 2, 0) / 255
113113
img = torch.from_numpy(img).float()
114-
img = img.unsqueeze(0).to(devices.device_swinir)
114+
img = devices.mps_contiguous_to(img.unsqueeze(0), devices.device_swinir)
115115
with torch.no_grad(), precision_scope("cuda"):
116116
_, _, h_old, w_old = img.size()
117117
h_pad = (h_old // window_size + 1) * window_size - h_old

0 commit comments

Comments
 (0)