|
1 | | -import asyncio |
2 | | -import os |
3 | | -import sys |
4 | | - |
5 | | -import comfy.utils |
6 | | -import numpy as np |
7 | | -import torch |
8 | | -from PIL import Image |
9 | | -from torchvision import transforms |
10 | | - |
11 | | -from .Pixelization.models import c2pGen |
12 | | -from .Pixelization.models.networks import define_G |
13 | | -from .Pixelization.test_pro import MLP_code |
14 | | - |
15 | | - |
16 | | -def has_mps() -> bool: |
17 | | - if sys.platform != "darwin": |
18 | | - return False |
19 | | - return torch.backends.mps.is_available() |
20 | | - |
21 | | - |
22 | | -def get_cuda_device_string(): |
23 | | - return "cuda" |
24 | | - |
25 | | - |
26 | | -def get_optimal_device_name(): |
27 | | - if torch.cuda.is_available(): |
28 | | - return get_cuda_device_string() |
29 | | - |
30 | | - if has_mps(): |
31 | | - return "mps" |
32 | | - |
33 | | - return "cpu" |
34 | | - |
35 | | - |
36 | | -def get_optimal_device(): |
37 | | - return torch.device(get_optimal_device_name()) |
38 | | - |
39 | | - |
40 | | -device = get_optimal_device() |
41 | | - |
42 | | - |
43 | | -basedir = os.path.dirname(os.path.realpath(__file__)) |
44 | | -path_checkpoints = os.path.join(basedir, "checkpoints") |
45 | | -path_pixelart_vgg19 = os.path.join(path_checkpoints, "pixelart_vgg19.pth") |
46 | | -path_160_net_G_A = os.path.join(path_checkpoints, "160_net_G_A.pth") |
47 | | -path_alias_net = os.path.join(path_checkpoints, "alias_net.pth") |
48 | | - |
49 | | - |
50 | | -class TorchHijackForC2pGen: |
51 | | - def __getattr__(self, item): |
52 | | - if item == "load": |
53 | | - return self.load |
54 | | - |
55 | | - if hasattr(torch, item): |
56 | | - return getattr(torch, item) |
57 | | - |
58 | | - raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'") |
59 | | - |
60 | | - def load(self, filename, *args, **kwargs): |
61 | | - if filename == "./pixelart_vgg19.pth": |
62 | | - filename = path_pixelart_vgg19 |
63 | | - |
64 | | - return torch.load(filename, *args, **kwargs) |
65 | | - |
66 | | - |
67 | | -c2pGen.torch = TorchHijackForC2pGen() |
68 | | - |
69 | | - |
70 | | -class Model(torch.nn.Module): |
71 | | - def __init__(self): |
72 | | - super().__init__() |
73 | | - |
74 | | - os.makedirs(path_checkpoints, exist_ok=True) |
75 | | - |
76 | | - models_missing = False |
77 | | - |
78 | | - if not os.path.exists(path_pixelart_vgg19): |
79 | | - print( |
80 | | - f"Missing {path_pixelart_vgg19} - download it from https://drive.google.com/uc?id=1VRYKQOsNlE1w1LXje3yTRU5THN2MGdMM" |
81 | | - ) |
82 | | - models_missing = True |
83 | | - |
84 | | - if not os.path.exists(path_160_net_G_A): |
85 | | - print( |
86 | | - f"Missing {path_160_net_G_A} - download it from https://drive.google.com/uc?id=1i_8xL3stbLWNF4kdQJ50ZhnRFhSDh3Az" |
87 | | - ) |
88 | | - models_missing = True |
89 | | - |
90 | | - if not os.path.exists(path_alias_net): |
91 | | - print( |
92 | | - f"Missing {path_alias_net} - download it from https://drive.google.com/uc?id=17f2rKnZOpnO9ATwRXgqLz5u5AZsyDvq_" |
93 | | - ) |
94 | | - models_missing = True |
95 | | - |
96 | | - if models_missing: |
97 | | - error_message = "Missing checkpoints for pixelization - see console for download links." |
98 | | - print(error_message) |
99 | | - raise RuntimeError(error_message) |
100 | | - |
101 | | - with torch.no_grad(): |
102 | | - self.G_A_net = define_G(3, 3, 64, "c2pGen", "instance", False, "normal", 0.02, [0]) |
103 | | - self.alias_net = define_G(3, 3, 64, "antialias", "instance", False, "normal", 0.02, [0]) |
104 | | - |
105 | | - G_A_state = torch.load(path_160_net_G_A) |
106 | | - for p in list(G_A_state.keys()): |
107 | | - G_A_state["module." + str(p)] = G_A_state.pop(p) |
108 | | - self.G_A_net.load_state_dict(G_A_state) |
109 | | - |
110 | | - alias_state = torch.load(path_alias_net) |
111 | | - for p in list(alias_state.keys()): |
112 | | - alias_state["module." + str(p)] = alias_state.pop(p) |
113 | | - self.alias_net.load_state_dict(alias_state) |
114 | | - |
115 | | - |
116 | | -def rescale_image(img): |
117 | | - """ |
118 | | - Preprocess the image for pixelization. |
119 | | -
|
120 | | - Crops the image to a size that is divisible by 4. |
121 | | - """ |
122 | | - orig_width, orig_height = img.size |
123 | | - |
124 | | - new_width = int(round(orig_width / 4) * 4) |
125 | | - new_height = int(round(orig_height / 4) * 4) |
126 | | - |
127 | | - left = (orig_width - new_width) // 2 |
128 | | - top = (orig_height - new_height) // 2 |
129 | | - right = left + new_width |
130 | | - bottom = top + new_height |
131 | | - |
132 | | - img = img.crop((left, top, right, bottom)) |
133 | | - |
134 | | - trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) |
135 | | - |
136 | | - return trans(img)[None, :, :, :] |
137 | | - |
138 | | - |
139 | | -def to_image(tensor, pixel_size, upscale_after): |
140 | | - img = tensor.data[0].cpu().float().numpy() |
141 | | - img = (np.transpose(img, (1, 2, 0)) + 1) / 2.0 * 255.0 |
142 | | - img = img.astype(np.uint8) |
143 | | - img = Image.fromarray(img) |
144 | | - img = img.resize((img.size[0] // 4, img.size[1] // 4), resample=Image.Resampling.NEAREST) |
145 | | - if upscale_after: |
146 | | - img = img.resize((img.size[0] * pixel_size, img.size[1] * pixel_size), resample=Image.Resampling.NEAREST) |
147 | | - |
148 | | - return img |
149 | | - |
150 | | - |
151 | | -def tensor2pil(image): |
152 | | - return Image.fromarray(np.clip(255.0 * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8)) |
153 | | - |
154 | | - |
155 | | -def pil2tensor(image): |
156 | | - return torch.from_numpy(np.array(image).astype(np.float32) / 255.0).unsqueeze(0) |
157 | | - |
158 | | - |
159 | | -def wait_for_async(async_fn, loop=None): |
160 | | - res = [] |
161 | | - |
162 | | - async def run_async(): |
163 | | - r = await async_fn() |
164 | | - res.append(r) |
165 | | - |
166 | | - if loop is None: |
167 | | - try: |
168 | | - loop = asyncio.get_event_loop() |
169 | | - except: |
170 | | - loop = asyncio.new_event_loop() |
171 | | - asyncio.set_event_loop(loop) |
172 | | - |
173 | | - loop.run_until_complete(run_async()) |
174 | | - |
175 | | - return res[0] |
176 | | - |
177 | | - |
178 | | -class Pixelization: |
179 | | - def __init__(self): |
180 | | - if not hasattr(self, "model"): |
181 | | - self.model = Model() |
182 | | - |
183 | | - @classmethod |
184 | | - def INPUT_TYPES(cls): |
185 | | - return { |
186 | | - "required": { |
187 | | - "image": ("IMAGE",), |
188 | | - "pixel_size": ("INT", {"default": 4, "min": 1, "max": 32}), |
189 | | - "upscale_after": ("BOOLEAN", {"default": True}), |
190 | | - } |
191 | | - } |
192 | | - |
193 | | - RETURN_TYPES = ("IMAGE",) |
194 | | - |
195 | | - FUNCTION = "pixelize" |
196 | | - |
197 | | - CATEGORY = "image" |
198 | | - |
199 | | - OUTPUT_IS_LIST = (True,) |
200 | | - OUTPUT_NODE = False |
201 | | - |
202 | | - async def run_pixelatization(self, image, pixel_size, upscale_after): |
203 | | - image = image.resize((image.width * 4 // pixel_size, image.height * 4 // pixel_size)) |
204 | | - |
205 | | - with torch.no_grad(): |
206 | | - in_t = rescale_image(image).to(device) |
207 | | - |
208 | | - code = torch.asarray(MLP_code, device=device).reshape((1, 256, 1, 1)) |
209 | | - adain_params = self.model.G_A_net.module.MLP(code) |
210 | | - |
211 | | - feature = self.model.G_A_net.module.RGBEnc(in_t) |
212 | | - images = self.model.G_A_net.module.RGBDec(feature, adain_params) |
213 | | - out_t = self.model.alias_net(images) |
214 | | - |
215 | | - image = to_image(out_t, pixel_size=pixel_size, upscale_after=upscale_after) |
216 | | - |
217 | | - image = pil2tensor(image) |
218 | | - |
219 | | - return image |
220 | | - |
221 | | - def pixelize(self, image, pixel_size, upscale_after): |
222 | | - self.model.to(device) |
223 | | - |
224 | | - tensor = image * 255 |
225 | | - tensor = np.array(tensor, dtype=np.uint8) |
226 | | - |
227 | | - progressbar = comfy.utils.ProgressBar(tensor.shape[0]) |
228 | | - all_images = [] |
229 | | - for i in range(tensor.shape[0]): |
230 | | - image = Image.fromarray(tensor[i]) |
231 | | - all_images.append(wait_for_async(lambda: self.run_pixelatization(image, pixel_size, upscale_after))) |
232 | | - progressbar.update(1) |
233 | | - |
234 | | - return (all_images,) |
235 | | - |
| 1 | +from .nodes import Pixelization |
236 | 2 |
|
237 | 3 | NODE_CLASS_MAPPINGS = {"Pixelization": Pixelization} |
0 commit comments