|
7 | 7 | from basicsr.utils.download_util import load_file_from_url
|
8 | 8 | from tqdm import tqdm
|
9 | 9 |
|
10 |
| -from modules import modelloader |
11 |
| -from modules.shared import cmd_opts, opts, device |
| 10 | +from modules import modelloader, devices |
| 11 | +from modules.shared import cmd_opts, opts |
12 | 12 | from modules.swinir_model_arch import SwinIR as net
|
13 | 13 | from modules.swinir_model_arch_v2 import Swin2SR as net2
|
14 | 14 | from modules.upscaler import Upscaler, UpscalerData
|
@@ -42,7 +42,7 @@ def do_upscale(self, img, model_file):
|
42 | 42 | model = self.load_model(model_file)
|
43 | 43 | if model is None:
|
44 | 44 | return img
|
45 |
| - model = model.to(device) |
| 45 | + model = model.to(devices.device_swinir) |
46 | 46 | img = upscale(img, model)
|
47 | 47 | try:
|
48 | 48 | torch.cuda.empty_cache()
|
@@ -111,7 +111,7 @@ def upscale(
|
111 | 111 | img = img[:, :, ::-1]
|
112 | 112 | img = np.moveaxis(img, 2, 0) / 255
|
113 | 113 | img = torch.from_numpy(img).float()
|
114 |
| - img = img.unsqueeze(0).to(device) |
| 114 | + img = img.unsqueeze(0).to(devices.device_swinir) |
115 | 115 | with torch.no_grad(), precision_scope("cuda"):
|
116 | 116 | _, _, h_old, w_old = img.size()
|
117 | 117 | h_pad = (h_old // window_size + 1) * window_size - h_old
|
@@ -139,8 +139,8 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
|
139 | 139 | stride = tile - tile_overlap
|
140 | 140 | h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
|
141 | 141 | w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
|
142 |
| - E = torch.zeros(b, c, h * sf, w * sf, dtype=torch.half, device=device).type_as(img) |
143 |
| - W = torch.zeros_like(E, dtype=torch.half, device=device) |
| 142 | + E = torch.zeros(b, c, h * sf, w * sf, dtype=torch.half, device=devices.device_swinir).type_as(img) |
| 143 | + W = torch.zeros_like(E, dtype=torch.half, device=devices.device_swinir) |
144 | 144 |
|
145 | 145 | with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar:
|
146 | 146 | for h_idx in h_idx_list:
|
|
0 commit comments