Skip to content

Commit 86d7c22

Browse files
committed
Add the ability to reapply hue and saturation from the original image
Based on AUTOMATIC1111/stable-diffusion-webui-pixelization#23
1 parent 7342874 commit 86d7c22

File tree

2 files changed

+301
-235
lines changed

2 files changed

+301
-235
lines changed

__init__.py

Lines changed: 1 addition & 235 deletions
Original file line numberDiff line numberDiff line change
@@ -1,237 +1,3 @@
1-
import asyncio
2-
import os
3-
import sys
4-
5-
import comfy.utils
6-
import numpy as np
7-
import torch
8-
from PIL import Image
9-
from torchvision import transforms
10-
11-
from .Pixelization.models import c2pGen
12-
from .Pixelization.models.networks import define_G
13-
from .Pixelization.test_pro import MLP_code
14-
15-
16-
def has_mps() -> bool:
17-
if sys.platform != "darwin":
18-
return False
19-
return torch.backends.mps.is_available()
20-
21-
22-
def get_cuda_device_string():
23-
return "cuda"
24-
25-
26-
def get_optimal_device_name():
27-
if torch.cuda.is_available():
28-
return get_cuda_device_string()
29-
30-
if has_mps():
31-
return "mps"
32-
33-
return "cpu"
34-
35-
36-
def get_optimal_device():
37-
return torch.device(get_optimal_device_name())
38-
39-
40-
device = get_optimal_device()
41-
42-
43-
basedir = os.path.dirname(os.path.realpath(__file__))
44-
path_checkpoints = os.path.join(basedir, "checkpoints")
45-
path_pixelart_vgg19 = os.path.join(path_checkpoints, "pixelart_vgg19.pth")
46-
path_160_net_G_A = os.path.join(path_checkpoints, "160_net_G_A.pth")
47-
path_alias_net = os.path.join(path_checkpoints, "alias_net.pth")
48-
49-
50-
class TorchHijackForC2pGen:
51-
def __getattr__(self, item):
52-
if item == "load":
53-
return self.load
54-
55-
if hasattr(torch, item):
56-
return getattr(torch, item)
57-
58-
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
59-
60-
def load(self, filename, *args, **kwargs):
61-
if filename == "./pixelart_vgg19.pth":
62-
filename = path_pixelart_vgg19
63-
64-
return torch.load(filename, *args, **kwargs)
65-
66-
67-
c2pGen.torch = TorchHijackForC2pGen()
68-
69-
70-
class Model(torch.nn.Module):
71-
def __init__(self):
72-
super().__init__()
73-
74-
os.makedirs(path_checkpoints, exist_ok=True)
75-
76-
models_missing = False
77-
78-
if not os.path.exists(path_pixelart_vgg19):
79-
print(
80-
f"Missing {path_pixelart_vgg19} - download it from https://drive.google.com/uc?id=1VRYKQOsNlE1w1LXje3yTRU5THN2MGdMM"
81-
)
82-
models_missing = True
83-
84-
if not os.path.exists(path_160_net_G_A):
85-
print(
86-
f"Missing {path_160_net_G_A} - download it from https://drive.google.com/uc?id=1i_8xL3stbLWNF4kdQJ50ZhnRFhSDh3Az"
87-
)
88-
models_missing = True
89-
90-
if not os.path.exists(path_alias_net):
91-
print(
92-
f"Missing {path_alias_net} - download it from https://drive.google.com/uc?id=17f2rKnZOpnO9ATwRXgqLz5u5AZsyDvq_"
93-
)
94-
models_missing = True
95-
96-
if models_missing:
97-
error_message = "Missing checkpoints for pixelization - see console for download links."
98-
print(error_message)
99-
raise RuntimeError(error_message)
100-
101-
with torch.no_grad():
102-
self.G_A_net = define_G(3, 3, 64, "c2pGen", "instance", False, "normal", 0.02, [0])
103-
self.alias_net = define_G(3, 3, 64, "antialias", "instance", False, "normal", 0.02, [0])
104-
105-
G_A_state = torch.load(path_160_net_G_A)
106-
for p in list(G_A_state.keys()):
107-
G_A_state["module." + str(p)] = G_A_state.pop(p)
108-
self.G_A_net.load_state_dict(G_A_state)
109-
110-
alias_state = torch.load(path_alias_net)
111-
for p in list(alias_state.keys()):
112-
alias_state["module." + str(p)] = alias_state.pop(p)
113-
self.alias_net.load_state_dict(alias_state)
114-
115-
116-
def rescale_image(img):
117-
"""
118-
Preprocess the image for pixelization.
119-
120-
Crops the image to a size that is divisible by 4.
121-
"""
122-
orig_width, orig_height = img.size
123-
124-
new_width = int(round(orig_width / 4) * 4)
125-
new_height = int(round(orig_height / 4) * 4)
126-
127-
left = (orig_width - new_width) // 2
128-
top = (orig_height - new_height) // 2
129-
right = left + new_width
130-
bottom = top + new_height
131-
132-
img = img.crop((left, top, right, bottom))
133-
134-
trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
135-
136-
return trans(img)[None, :, :, :]
137-
138-
139-
def to_image(tensor, pixel_size, upscale_after):
140-
img = tensor.data[0].cpu().float().numpy()
141-
img = (np.transpose(img, (1, 2, 0)) + 1) / 2.0 * 255.0
142-
img = img.astype(np.uint8)
143-
img = Image.fromarray(img)
144-
img = img.resize((img.size[0] // 4, img.size[1] // 4), resample=Image.Resampling.NEAREST)
145-
if upscale_after:
146-
img = img.resize((img.size[0] * pixel_size, img.size[1] * pixel_size), resample=Image.Resampling.NEAREST)
147-
148-
return img
149-
150-
151-
def tensor2pil(image):
152-
return Image.fromarray(np.clip(255.0 * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8))
153-
154-
155-
def pil2tensor(image):
156-
return torch.from_numpy(np.array(image).astype(np.float32) / 255.0).unsqueeze(0)
157-
158-
159-
def wait_for_async(async_fn, loop=None):
160-
res = []
161-
162-
async def run_async():
163-
r = await async_fn()
164-
res.append(r)
165-
166-
if loop is None:
167-
try:
168-
loop = asyncio.get_event_loop()
169-
except:
170-
loop = asyncio.new_event_loop()
171-
asyncio.set_event_loop(loop)
172-
173-
loop.run_until_complete(run_async())
174-
175-
return res[0]
176-
177-
178-
class Pixelization:
179-
def __init__(self):
180-
if not hasattr(self, "model"):
181-
self.model = Model()
182-
183-
@classmethod
184-
def INPUT_TYPES(cls):
185-
return {
186-
"required": {
187-
"image": ("IMAGE",),
188-
"pixel_size": ("INT", {"default": 4, "min": 1, "max": 32}),
189-
"upscale_after": ("BOOLEAN", {"default": True}),
190-
}
191-
}
192-
193-
RETURN_TYPES = ("IMAGE",)
194-
195-
FUNCTION = "pixelize"
196-
197-
CATEGORY = "image"
198-
199-
OUTPUT_IS_LIST = (True,)
200-
OUTPUT_NODE = False
201-
202-
async def run_pixelatization(self, image, pixel_size, upscale_after):
203-
image = image.resize((image.width * 4 // pixel_size, image.height * 4 // pixel_size))
204-
205-
with torch.no_grad():
206-
in_t = rescale_image(image).to(device)
207-
208-
code = torch.asarray(MLP_code, device=device).reshape((1, 256, 1, 1))
209-
adain_params = self.model.G_A_net.module.MLP(code)
210-
211-
feature = self.model.G_A_net.module.RGBEnc(in_t)
212-
images = self.model.G_A_net.module.RGBDec(feature, adain_params)
213-
out_t = self.model.alias_net(images)
214-
215-
image = to_image(out_t, pixel_size=pixel_size, upscale_after=upscale_after)
216-
217-
image = pil2tensor(image)
218-
219-
return image
220-
221-
def pixelize(self, image, pixel_size, upscale_after):
222-
self.model.to(device)
223-
224-
tensor = image * 255
225-
tensor = np.array(tensor, dtype=np.uint8)
226-
227-
progressbar = comfy.utils.ProgressBar(tensor.shape[0])
228-
all_images = []
229-
for i in range(tensor.shape[0]):
230-
image = Image.fromarray(tensor[i])
231-
all_images.append(wait_for_async(lambda: self.run_pixelatization(image, pixel_size, upscale_after)))
232-
progressbar.update(1)
233-
234-
return (all_images,)
235-
1+
from .nodes import Pixelization
2362

2373
NODE_CLASS_MAPPINGS = {"Pixelization": Pixelization}

0 commit comments

Comments
 (0)