From cc7d96b8e030cb9b5dd19f42d5411693bf7b67b4 Mon Sep 17 00:00:00 2001
From: y4my4my4m <8145020+y4my4my4m@users.noreply.github.com>
Date: Mon, 8 Jul 2024 00:45:48 +0900
Subject: [PATCH] RGBA with proper hue/saturation and upscaling with nice ui

---
 scripts/postprocessing_pixelization.py | 103 +++++++++++++++++++------
 1 file changed, 79 insertions(+), 24 deletions(-)

diff --git a/scripts/postprocessing_pixelization.py b/scripts/postprocessing_pixelization.py
index e62c962..5609bd8 100644
--- a/scripts/postprocessing_pixelization.py
+++ b/scripts/postprocessing_pixelization.py
@@ -12,7 +12,8 @@
 from pixelization.models.networks import define_G
 import pixelization.models.c2pGen
 
-import gdown
+
+import colorsys
 
 pixelize_code = [
     233356.8125, -27387.5918, -32866.8008, 126575.0312, -181590.0156,
@@ -107,20 +108,19 @@ def load(self):
 
         missing = False
 
-        models = (
-            (path_pixelart_vgg19, "https://drive.google.com/uc?id=1VRYKQOsNlE1w1LXje3yTRU5THN2MGdMM"),
-            (path_160_net_G_A, "https://drive.google.com/uc?id=1i_8xL3stbLWNF4kdQJ50ZhnRFhSDh3Az"),
-            (path_alias_net, "https://drive.google.com/uc?id=17f2rKnZOpnO9ATwRXgqLz5u5AZsyDvq_"),
-        )
+        if not os.path.exists(path_pixelart_vgg19):
+            print(f"Missing {path_pixelart_vgg19} - download it from https://drive.google.com/uc?id=1VRYKQOsNlE1w1LXje3yTRU5THN2MGdMM")
+            missing = True
 
-        for path, url in models:
-            if not os.path.exists(path):
-                gdown.download(url, path)
+        if not os.path.exists(path_160_net_G_A):
+            print(f"Missing {path_160_net_G_A} - download it from https://drive.google.com/uc?id=1i_8xL3stbLWNF4kdQJ50ZhnRFhSDh3Az")
+            missing = True
 
-            if not os.path.exists(path):
-                missing = True
+        if not os.path.exists(path_alias_net):
+            print(f"Missing {path_alias_net} - download it from https://drive.google.com/uc?id=17f2rKnZOpnO9ATwRXgqLz5u5AZsyDvq_")
+            missing = True
 
-        assert not missing, f'Missing checkpoints for pixelization - see console for download links. Download checkpoints manually and place them in {path_checkpoints}.'
+        assert not missing, 'Missing checkpoints for pixelization - see console for download links.'
 
         with torch.no_grad():
            self.G_A_net = define_G(3, 3, 64, "c2pGen", "instance", False, "normal", 0.02, [0])
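Note on the checkpoint hunk above: the patch drops the automatic gdown download in favor of console messages pointing at the Google Drive links. For anyone who still wants an automated fetch, below is a minimal sketch built on the same gdown.download call the removed code used; the checkpoint filenames and target directory are assumptions inferred from the path_* variable names, not something this patch defines.

    # Optional helper, not part of the patch: fetch the three checkpoints with gdown.
    # Filenames and directory are assumed from the path_* variables; adjust as needed.
    import os
    import gdown

    CHECKPOINTS = {
        "pixelart_vgg19.pth": "https://drive.google.com/uc?id=1VRYKQOsNlE1w1LXje3yTRU5THN2MGdMM",
        "160_net_G_A.pth": "https://drive.google.com/uc?id=1i_8xL3stbLWNF4kdQJ50ZhnRFhSDh3Az",
        "alias_net.pth": "https://drive.google.com/uc?id=17f2rKnZOpnO9ATwRXgqLz5u5AZsyDvq_",
    }

    def download_checkpoints(checkpoints_dir):
        os.makedirs(checkpoints_dir, exist_ok=True)
        for name, url in CHECKPOINTS.items():
            target = os.path.join(checkpoints_dir, name)
            if not os.path.exists(target):
                gdown.download(url, target)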
@@ -136,7 +136,6 @@ def load(self):
                 alias_state["module." + str(p)] = alias_state.pop(p)
             self.alias_net.load_state_dict(alias_state)
 
-
 
 def process(img):
     ow, oh = img.size
@@ -150,22 +149,71 @@ def process(img):
 
     img = img.crop((left, top, right, bottom))
 
-    trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
+    # Split the RGBA image into RGB and alpha channels
+    img_rgba = img.convert('RGBA')
+    r, g, b, a = img_rgba.split()
 
-    return trans(img)[None, :, :, :]
+    # Convert RGB to tensor and normalize
+    rgb_img = Image.merge('RGB', (r, g, b))
+    trans_rgb = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+    ])
+    rgb_tensor = trans_rgb(rgb_img)
 
+    # Convert alpha channel to tensor (scale from 0-255 to 0-1)
+    alpha_tensor = transforms.ToTensor()(a)[None, :, :]  # Add an extra dimension for batch size
 
-def to_image(tensor, pixel_size, upscale_after):
+    return rgb_tensor[None, :, :, :], alpha_tensor
+
+def to_image(tensor, alpha_tensor, pixel_size, upscale_after, original_img, copy_hue, copy_sat):
     img = tensor.data[0].cpu().float().numpy()
     img = (np.transpose(img, (1, 2, 0)) + 1) / 2.0 * 255.0
     img = img.astype(np.uint8)
     img = Image.fromarray(img)
-    img = img.resize((img.size[0]//4, img.size[1]//4), resample=Image.Resampling.NEAREST)
+    width = img.size[0] // 4
+    height = img.size[1] // 4
+    img = img.resize((width, height), resample=Image.Resampling.NEAREST)
+
+    # Resize the alpha channel to match the new dimensions
+    alpha_img = alpha_tensor.data[0].cpu().numpy()
+    alpha_img = (alpha_img * 255).astype(np.uint8)
+    alpha_img = Image.fromarray(alpha_img.squeeze(), mode='L')
+    alpha_img = alpha_img.resize((width, height), resample=Image.Resampling.NEAREST)
+
+    if copy_hue or copy_sat:
+        original_img = original_img.resize((width, height), resample=Image.Resampling.NEAREST)
+        img = color_image(img, original_img, copy_hue, copy_sat)
+
+
+    # Merge the processed RGB image with the alpha channel
+    img.putalpha(alpha_img)
+
 
     if upscale_after:
         img = img.resize((img.size[0]*pixel_size, img.size[1]*pixel_size), resample=Image.Resampling.NEAREST)
 
     return img
 
+def color_image(img, original_img, copy_hue, copy_sat):
+    img = img.convert("RGB")
+    original_img = original_img.convert("RGB")
+
+    colored_img = Image.new("RGB", img.size)
+
+    for x in range(img.width):
+        for y in range(img.height):
+            pixel = original_img.getpixel((x, y))
+            r, g, b = pixel
+            original_h, original_s, original_v = colorsys.rgb_to_hsv(r / 255, g / 255, b / 255)
+
+            pixel = img.getpixel((x, y))
+            r, g, b = pixel
+            h, s, v = colorsys.rgb_to_hsv(r / 255, g / 255, b / 255)
+
+            r, g, b = colorsys.hsv_to_rgb(original_h if copy_hue else h, original_s if copy_sat else s, v)
+            colored_img.putpixel((x, y), (int(r * 255), int(g * 255), int(b * 255)))
+
+    return colored_img
 
 class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing):
     name = "Pixelization"
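Note on color_image above: the per-pixel getpixel/putpixel loop is correct but slow on large images. Below is a vectorized sketch of the same hue/saturation transfer using Pillow's built-in RGB-to-HSV conversion and numpy; the name color_image_fast is hypothetical, both images are assumed to already be the same size (as in to_image), and Pillow's 8-bit HSV channels may round slightly differently than the colorsys float math.

    import numpy as np
    from PIL import Image

    def color_image_fast(img, original_img, copy_hue, copy_sat):
        # Copy hue and/or saturation from original_img, keep value (brightness)
        # from the pixelized img. Pillow's HSV mode stores channels as H, S, V.
        hsv = np.array(img.convert("HSV"))
        hsv_orig = np.array(original_img.convert("HSV"))
        if copy_hue:
            hsv[..., 0] = hsv_orig[..., 0]
        if copy_sat:
            hsv[..., 1] = hsv_orig[..., 1]
        return Image.fromarray(hsv, mode="HSV").convert("RGB")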
@@ -175,16 +223,21 @@ class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing):
     def ui(self):
         with ui_components.InputAccordion(False, label="Pixelize") as enable:
             with gr.Row():
-                upscale_after = gr.Checkbox(False, label="Keep resolution")
+                upscale_after = gr.Checkbox(False, label="Keep resolution")
+                copy_hue = gr.Checkbox(False, label="Restore hue")
+                copy_sat = gr.Checkbox(False, label="Restore saturation")
+            with gr.Column():
                 pixel_size = gr.Slider(minimum=1, maximum=16, step=1, label="Pixel size", value=4, elem_id="pixelization_pixel_size")
 
         return {
             "enable": enable,
             "upscale_after": upscale_after,
             "pixel_size": pixel_size,
+            "copy_hue": copy_hue,
+            "copy_sat": copy_sat,
         }
 
-    def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, upscale_after, pixel_size):
+    def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, upscale_after, pixel_size, copy_hue, copy_sat):
         if not enable:
             return
 
@@ -196,20 +249,22 @@ def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, upscale
 
         self.model.to(devices.device)
 
-        pp.image = pp.image.resize((pp.image.width * 4 // pixel_size, pp.image.height * 4 // pixel_size)).convert('RGB')
+        pp.image = pp.image.resize((pp.image.width * 4 // pixel_size, pp.image.height * 4 // pixel_size))
+        original_img = pp.image.copy()
 
         with torch.no_grad():
-            in_t = process(pp.image).to(devices.device)
+            in_t, alpha_t = process(pp.image)
+            in_t = in_t.to(devices.device)
+            alpha_t = alpha_t.to(devices.device)
 
             feature = self.model.G_A_net.module.RGBEnc(in_t)
-            code = torch.asarray(pixelize_code, device=devices.device).reshape((1, 256, 1, 1))
+            code = torch.tensor(pixelize_code, device=devices.device).reshape((1, 256, 1, 1))
             adain_params = self.model.G_A_net.module.MLP(code)
             images = self.model.G_A_net.module.RGBDec(feature, adain_params)
             out_t = self.model.alias_net(images)
 
-        pp.image = to_image(out_t, pixel_size=pixel_size, upscale_after=upscale_after)
+        pp.image = to_image(out_t, alpha_t, pixel_size=pixel_size, upscale_after=upscale_after, original_img=original_img, copy_hue=copy_hue, copy_sat=copy_sat)
 
         self.model.to(devices.cpu)
 
         pp.info["Pixelization pixel size"] = pixel_size
-
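For reference, the RGBA handling this patch adds to process() and to_image() amounts to a simple round trip: split the alpha channel off before the model sees the image, nearest-downscale it by the same factor as the RGB output, and reattach it with putalpha. Below is a standalone sketch of just that flow in pure Pillow with no model; alpha_roundtrip and its arguments are a hypothetical illustration, and the nearest-neighbor resize merely stands in for the pixelization network.

    from PIL import Image

    def alpha_roundtrip(img, pixel_size=4, keep_resolution=False):
        # Split transparency off; the model only ever sees RGB.
        rgba = img.convert("RGBA")
        rgb = rgba.convert("RGB")
        alpha = rgba.getchannel("A")

        # Stand-in for the pixelization model: nearest-downscale RGB by pixel_size,
        # and downscale alpha by the same factor so the two stay aligned.
        w, h = rgb.size[0] // pixel_size, rgb.size[1] // pixel_size
        rgb_small = rgb.resize((w, h), resample=Image.Resampling.NEAREST)
        alpha_small = alpha.resize((w, h), resample=Image.Resampling.NEAREST)

        # Reattach transparency, then optionally blow the result back up,
        # which is what the "Keep resolution" checkbox does.
        rgb_small.putalpha(alpha_small)
        if keep_resolution:
            rgb_small = rgb_small.resize((w * pixel_size, h * pixel_size), resample=Image.Resampling.NEAREST)
        return rgb_small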