|
3 | 3 | import torch
|
4 | 4 | from modules import errors
|
5 | 5 |
|
| 6 | + |
6 | 7 | # has_mps is only available in nightly pytorch (for now) and MasOS 12.3+.
|
7 | 8 | # check `getattr` and try it for compatibility
|
8 | 9 | def has_mps() -> bool:
|
9 |
| - if not getattr(torch, 'has_mps', False): return False |
| 10 | + if not getattr(torch, 'has_mps', False): |
| 11 | + return False |
10 | 12 | try:
|
11 | 13 | torch.zeros(1).to(torch.device("mps"))
|
12 | 14 | return True
|
13 | 15 | except Exception:
|
14 | 16 | return False
|
15 | 17 |
|
16 |
| -cpu = torch.device("cpu") |
17 | 18 |
|
18 | 19 | def extract_device_id(args, name):
|
19 | 20 | for x in range(len(args)):
|
20 |
| - if name in args[x]: return args[x+1] |
| 21 | + if name in args[x]: |
| 22 | + return args[x + 1] |
| 23 | + |
21 | 24 | return None
|
22 | 25 |
|
| 26 | + |
23 | 27 | def get_optimal_device():
|
24 | 28 | if torch.cuda.is_available():
|
25 | 29 | from modules import shared
|
@@ -52,10 +56,12 @@ def enable_tf32():
|
52 | 56 |
|
53 | 57 | errors.run(enable_tf32, "Enabling TF32")
|
54 | 58 |
|
| 59 | +cpu = torch.device("cpu") |
55 | 60 | device = device_interrogate = device_gfpgan = device_swinir = device_esrgan = device_scunet = device_codeformer = None
|
56 | 61 | dtype = torch.float16
|
57 | 62 | dtype_vae = torch.float16
|
58 | 63 |
|
| 64 | + |
59 | 65 | def randn(seed, shape):
|
60 | 66 | # Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
|
61 | 67 | if device.type == 'mps':
|
@@ -89,6 +95,11 @@ def autocast(disable=False):
|
89 | 95 |
|
90 | 96 | return torch.autocast("cuda")
|
91 | 97 |
|
| 98 | + |
92 | 99 | # MPS workaround for https://github.com/pytorch/pytorch/issues/79383
|
93 |
| -def mps_contiguous(input_tensor, device): return input_tensor.contiguous() if device.type == 'mps' else input_tensor |
94 |
| -def mps_contiguous_to(input_tensor, device): return mps_contiguous(input_tensor, device).to(device) |
| 100 | +def mps_contiguous(input_tensor, device): |
| 101 | + return input_tensor.contiguous() if device.type == 'mps' else input_tensor |
| 102 | + |
| 103 | + |
| 104 | +def mps_contiguous_to(input_tensor, device): |
| 105 | + return mps_contiguous(input_tensor, device).to(device) |
0 commit comments