diff --git a/scripts/postprocessing_pixelization.py b/scripts/postprocessing_pixelization.py index b4b98ac..4c99887 100644 --- a/scripts/postprocessing_pixelization.py +++ b/scripts/postprocessing_pixelization.py @@ -13,6 +13,8 @@ from pixelization.models.networks import define_G import pixelization.models.c2pGen +import colorsys + pixelize_code = [ 233356.8125, -27387.5918, -32866.8008, 126575.0312, -181590.0156, -31543.1289, 50374.1289, 99631.4062, -188897.3750, 138322.7031, @@ -153,17 +155,46 @@ def process(img): return trans(img)[None, :, :, :] -def to_image(tensor, pixel_size, upscale_after): +def to_image(tensor, pixel_size, upscale_after, original_img, copy_hue, copy_sat): img = tensor.data[0].cpu().float().numpy() img = (np.transpose(img, (1, 2, 0)) + 1) / 2.0 * 255.0 img = img.astype(np.uint8) img = Image.fromarray(img) - img = img.resize((img.size[0]//4, img.size[1]//4), resample=Image.Resampling.NEAREST) + width = img.size[0]//4 + height = img.size[1]//4 + img = img.resize((width, height), resample=Image.Resampling.NEAREST) + + if copy_hue or copy_sat: + original_img = original_img.resize((width, height), resample=Image.Resampling.NEAREST) + img = color_image(img, original_img, copy_hue, copy_sat); + if upscale_after: img = img.resize((img.size[0]*pixel_size, img.size[1]*pixel_size), resample=Image.Resampling.NEAREST) return img + +def color_image(img, original_img, copy_hue, copy_sat): + img = img.convert("RGB") + original_img = original_img.convert("RGB") + + colored_img = Image.new("RGB", img.size) + + print(img.width, img.height) + + for x in range(img.width): + for y in range(img.height): + pixel = original_img.getpixel((x, y)) + r, g, b = pixel + original_h, original_s, original_v = colorsys.rgb_to_hsv(r / 255, g / 255, b / 255) + + pixel = img.getpixel((x, y)) + r, g, b = pixel + h, s, v = colorsys.rgb_to_hsv(r / 255, g / 255, b / 255) + r, g, b = colorsys.hsv_to_rgb(original_h if copy_hue else h, original_s if copy_sat else s, v) + colored_img.putpixel((x, y), (int(r * 255), int(g * 255), int(b * 255))) + + return colored_img class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing): name = "Pixelization" @@ -176,17 +207,21 @@ def ui(self): with FormRow(): enable = gr.Checkbox(False, label="Enable pixelization") upscale_after = gr.Checkbox(False, label="Keep resolution") - + with FormRow(): + copy_hue = gr.Checkbox(False, label="Restore hue") + copy_sat = gr.Checkbox(False, label="Restore saturation") with gr.Column(): pixel_size = gr.Slider(minimum=1, maximum=16, step=1, label="Pixel size", value=4, elem_id="pixelization_pixel_size") - + return { "enable": enable, "upscale_after": upscale_after, "pixel_size": pixel_size, + "copy_hue": copy_hue, + "copy_sat": copy_sat, } - def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, upscale_after, pixel_size): + def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, upscale_after, pixel_size, copy_hue, copy_sat): if not enable: return @@ -199,6 +234,7 @@ def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, upscale self.model.to(devices.device) pp.image = pp.image.resize((pp.image.width * 4 // pixel_size, pp.image.height * 4 // pixel_size)) + original_img = pp.image.copy() with torch.no_grad(): in_t = process(pp.image).to(devices.device) @@ -209,7 +245,7 @@ def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, upscale images = self.model.G_A_net.module.RGBDec(feature, adain_params) out_t = self.model.alias_net(images) - pp.image = to_image(out_t, pixel_size=pixel_size, upscale_after=upscale_after) + pp.image = to_image(out_t, pixel_size=pixel_size, upscale_after=upscale_after, original_img=original_img, copy_hue=copy_hue, copy_sat=copy_sat) self.model.to(devices.cpu)