
Commit 0ab0a50

committed
change formatting to match the main program in devices.py
1 parent c62d17a commit 0ab0a50

File tree

1 file changed (+16 -5 lines)

modules/devices.py

Lines changed: 16 additions & 5 deletions
@@ -3,23 +3,27 @@
 import torch
 from modules import errors
 
+
 # has_mps is only available in nightly pytorch (for now) and MasOS 12.3+.
 # check `getattr` and try it for compatibility
 def has_mps() -> bool:
-    if not getattr(torch, 'has_mps', False): return False
+    if not getattr(torch, 'has_mps', False):
+        return False
     try:
         torch.zeros(1).to(torch.device("mps"))
         return True
     except Exception:
         return False
 
-cpu = torch.device("cpu")
 
 def extract_device_id(args, name):
     for x in range(len(args)):
-        if name in args[x]: return args[x+1]
+        if name in args[x]:
+            return args[x + 1]
+
     return None
 
+
 def get_optimal_device():
     if torch.cuda.is_available():
         from modules import shared
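
For context, a minimal usage sketch of the two helpers touched in this hunk (the argument list and flag names below are made-up examples, not values from this repository):

    # hypothetical CLI tokens; extract_device_id returns the token that
    # follows the first argument whose text contains the given name
    args = ["--device-id", "0", "--no-half"]
    print(extract_device_id(args, "--device-id"))  # -> "0"
    print(extract_device_id(args, "--unknown"))    # -> None

    # has_mps() only reports True if torch exposes the attribute and a
    # tensor can actually be placed on the "mps" device
    print(has_mps())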
@@ -52,10 +56,12 @@ def enable_tf32():
 
 errors.run(enable_tf32, "Enabling TF32")
 
+cpu = torch.device("cpu")
 device = device_interrogate = device_gfpgan = device_swinir = device_esrgan = device_scunet = device_codeformer = None
 dtype = torch.float16
 dtype_vae = torch.float16
 
+
 def randn(seed, shape):
     # Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
     if device.type == 'mps':
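
The hunk cuts off before the body of randn; as a hedged sketch of the kind of workaround the comment describes (not necessarily the exact body in this commit), the noise can be seeded and generated on the CPU and then moved to the MPS device:

    def randn(seed, shape):
        # MPS does not handle manual seeding reliably, so seed and sample
        # on the CPU, then transfer the result to the target device
        if device.type == 'mps':
            generator = torch.Generator(device=cpu)
            generator.manual_seed(seed)
            return torch.randn(shape, generator=generator, device=cpu).to(device)

        torch.manual_seed(seed)
        return torch.randn(shape, device=device)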
@@ -89,6 +95,11 @@ def autocast(disable=False):
 
     return torch.autocast("cuda")
 
+
 # MPS workaround for https://github.com/pytorch/pytorch/issues/79383
-def mps_contiguous(input_tensor, device): return input_tensor.contiguous() if device.type == 'mps' else input_tensor
-def mps_contiguous_to(input_tensor, device): return mps_contiguous(input_tensor, device).to(device)
+def mps_contiguous(input_tensor, device):
+    return input_tensor.contiguous() if device.type == 'mps' else input_tensor
+
+
+def mps_contiguous_to(input_tensor, device):
+    return mps_contiguous(input_tensor, device).to(device)
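
A short usage sketch of the reformatted workaround helpers (assuming an MPS-capable machine; the tensor below is purely illustrative):

    mps = torch.device("mps")
    x = torch.zeros(1, 3, 64, 64).permute(0, 1, 3, 2)  # non-contiguous view
    # forces .contiguous() only when the target is MPS, then moves the
    # tensor there; see the pytorch issue linked in the comment above
    y = mps_contiguous_to(x, mps)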

0 commit comments