
Commit 9995668

replaces dither and kmeans quantize with quantize
much faster dithering and uses PIL
1 parent 08a83e3 commit 9995668

6 files changed: +133 −120 lines

README.md
Lines changed: 1 addition & 2 deletions

@@ -10,13 +10,12 @@ A collection of post processing nodes for [ComfyUI](https://github.com/comfyanonymous/ComfyUI)
 - CannyEdgeDetection: Applies Canny edge detection to the input image
 - ColorCorrect: Adjusts the color balance, temperature, hue, brightness, contrast, saturation, and gamma of an image
 - Dissolve: Creates a grainy blend of two images using random pixels based on a dissolve factor
-- Dither: Reduces the color information in an image by dithering, resulting in a patterned, pixelated appearance
 - DodgeAndBurn: Adjusts image brightness using dodge and burn effects based on a mask and intensity
 - FilmGrain: Adds a film grain effect to the image, with options to control the temperature and vignetting
 - Glow: Applies a blur with a specified radius and then blends it with the original image, creating a glowing effect
-- KMeansQuantize: Reduce the amount of colors in an image from 0-256
 - PixelSort: Rearranges the pixels in the input image based on their values and an input mask, creating a glitch-like effect
 - Pixelize: Applies a pixelization effect, simulating a reduction in resolution
+- Quantize: Reduces the number of colors in an image to 1-256, with optional dithering
 - Sharpen: Enhances the details in an image by applying a sharpening filter
 - Solarize: Inverts image colors based on a threshold for a striking, high-contrast effect

combine_files.py
Lines changed: 11 additions & 2 deletions

@@ -6,11 +6,20 @@
 import ast
 import argparse

-
+ignore_dirs = ["old"]

 def get_python_files(path, recursive=False, args=None):
     search_pattern = "**/*.py" if recursive else "*.py"
-    files = sorted([str(file) for file in Path(path).glob(search_pattern) if file.is_file() and not file.name.startswith("combine") and not args.output in str(file)])
+
+    def should_include(file):
+        if file.is_file() and not file.name.startswith("combine") and not args.output in str(file):
+            for ignore_dir in ignore_dirs:
+                if ignore_dir in str(file.parent):
+                    return False
+            return True
+        return False
+
+    files = sorted([str(file) for file in Path(path).glob(search_pattern) if should_include(file)])
     yield from files

 def parse_files(files):
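
A quick way to exercise the new filter; the import path and the args stand-in are assumptions for illustration, not part of the commit:

from types import SimpleNamespace

from combine_files import get_python_files  # assumed importable from the repo root

# Minimal stand-in for the argparse result; only .output is consulted by the filter
args = SimpleNamespace(output="combined.py")

for f in get_python_files(".", recursive=True, args=args):
    print(f)  # paths under old/ and the combined output itself are skipped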
File renamed without changes.

post_processing/quantize.py
Lines changed: 50 additions & 0 deletions

@@ -0,0 +1,50 @@
+import torch
+from PIL import Image
+import numpy as np
+
+class Quantize:
+    def __init__(self):
+        pass
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "image": ("IMAGE",),
+                "colors": ("INT", {
+                    "default": 256,
+                    "min": 1,
+                    "max": 256,
+                    "step": 1
+                }),
+                "dither": (["none", "floyd-steinberg"],),
+            },
+        }
+
+    RETURN_TYPES = ("IMAGE",)
+    FUNCTION = "quantize"
+
+    CATEGORY = "postprocessing"
+
+    def quantize(self, image: torch.Tensor, colors: int = 256, dither: str = "floyd-steinberg"):
+        batch_size, height, width, _ = image.shape
+        result = torch.zeros_like(image)
+
+        dither_option = Image.Dither.FLOYDSTEINBERG if dither == "floyd-steinberg" else Image.Dither.NONE
+
+        for b in range(batch_size):
+            tensor_image = image[b]
+            img = (tensor_image * 255).to(torch.uint8).numpy()
+            pil_image = Image.fromarray(img, mode='RGB')
+
+            palette = pil_image.quantize(colors=colors)  # Two-step quantize so dither= takes effect, see https://github.com/python-pillow/Pillow/issues/5836
+            quantized_image = pil_image.quantize(colors=colors, palette=palette, dither=dither_option)
+
+            quantized_array = torch.tensor(np.array(quantized_image.convert("RGB"))).float() / 255
+            result[b] = quantized_array
+
+        return (result,)
+
+NODE_CLASS_MAPPINGS = {
+    "Quantize": Quantize,
+}
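
The two-step quantize above is the standard workaround for Pillow ignoring dither= when no explicit palette is given. A minimal standalone sketch of the same pattern, assuming Pillow >= 9.1 (for the Image.Dither enum) and a hypothetical input.png:

from PIL import Image

img = Image.open("input.png").convert("RGB")  # hypothetical input file

# First pass builds an adaptive palette; feeding it back makes dither= take effect
palette = img.quantize(colors=16)
dithered = img.quantize(colors=16, palette=palette, dither=Image.Dither.FLOYDSTEINBERG)
flat = img.quantize(colors=16, palette=palette, dither=Image.Dither.NONE)

dithered.convert("RGB").save("dithered.png")
flat.convert("RGB").save("flat.png")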

post_processing_nodes.py
Lines changed: 71 additions & 116 deletions
@@ -72,6 +72,9 @@ def INPUT_TYPES(s):
     CATEGORY = "postprocessing"

     def blend_images(self, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str):
+        if image1.shape != image2.shape:
+            image2 = self.crop_and_resize(image2, image1.shape)
+
         blended_image = self.blend_mode(image1, image2, blend_mode)
         blended_image = image1 * (1 - blend_factor) + blended_image * blend_factor
         blended_image = torch.clamp(blended_image, 0, 1)
@@ -94,6 +97,29 @@ def blend_mode(self, img1, img2, mode):
     def g(self, x):
         return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x))

+    def crop_and_resize(self, img: torch.Tensor, target_shape: tuple):
+        batch_size, img_h, img_w, img_c = img.shape
+        _, target_h, target_w, _ = target_shape
+        img_aspect_ratio = img_w / img_h
+        target_aspect_ratio = target_w / target_h
+
+        # Crop the center of the image to the target aspect ratio
+        if img_aspect_ratio > target_aspect_ratio:
+            new_width = int(img_h * target_aspect_ratio)
+            left = (img_w - new_width) // 2
+            img = img[:, :, left:left + new_width, :]
+        else:
+            new_height = int(img_w / target_aspect_ratio)
+            top = (img_h - new_height) // 2
+            img = img[:, top:top + new_height, :, :]
+
+        # Resize to the target size
+        img = img.permute(0, 3, 1, 2)  # Torch wants (B, C, H, W); we use (B, H, W, C)
+        img = F.interpolate(img, size=(target_h, target_w), mode='bilinear', align_corners=False)
+        img = img.permute(0, 2, 3, 1)
+
+        return img
+
 class Blur:
     def __init__(self):
         pass
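
With these two hunks, Blend no longer requires matching shapes: the second image is center-cropped to the first image's aspect ratio and resized before blending. A rough usage sketch, assuming the class is importable and "normal" is among its blend modes:

import torch

from post_processing_nodes import Blend  # assumed import

blend = Blend()
a = torch.rand(1, 256, 256, 3)  # (B, H, W, C), the layout these nodes use
b = torch.rand(1, 384, 512, 3)  # different size and aspect ratio

out, = blend.blend_images(a, b, blend_factor=0.5, blend_mode="normal")  # node functions return a 1-tuple
print(out.shape)  # torch.Size([1, 256, 256, 3]); image2 was cropped and resized to match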
@@ -124,7 +150,7 @@ def INPUT_TYPES(s):
     CATEGORY = "postprocessing"

     def gaussian_kernel(self, kernel_size: int, sigma: float):
-        x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size), torch.linspace(-1, 1, kernel_size))
+        x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size), torch.linspace(-1, 1, kernel_size), indexing="ij")
         d = torch.sqrt(x * x + y * y)
         g = torch.exp(-(d * d) / (2.0 * sigma * sigma))
         return g / g.sum()
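
The indexing="ij" argument pins the current meshgrid behavior and silences the deprecation warning newer PyTorch builds emit; since the kernel depends only on the radial distance, the output values are unchanged. A small check:

import torch

xs = torch.linspace(-1, 1, 5)
x, y = torch.meshgrid(xs, xs, indexing="ij")  # explicit indexing, no warning
d = torch.sqrt(x * x + y * y)                 # radially symmetric, so "ij" vs "xy" is irrelevant here
print(torch.allclose(d, d.T))                 # True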
@@ -324,63 +350,6 @@ def dissolve_images(self, image1: torch.Tensor, image2: torch.Tensor, dissolve_f
         dissolved_image = torch.clamp(dissolved_image, 0, 1)
         return (dissolved_image,)

-class Dither:
-    def __init__(self):
-        pass
-
-    @classmethod
-    def INPUT_TYPES(s):
-        return {
-            "required": {
-                "image": ("IMAGE",),
-                "bits": ("INT", {
-                    "default": 4,
-                    "min": 1,
-                    "max": 8,
-                    "step": 1
-                }),
-            },
-        }
-
-    RETURN_TYPES = ("IMAGE",)
-    FUNCTION = "dither"
-
-    CATEGORY = "postprocessing"
-
-    def dither(self, image: torch.Tensor, bits: int):
-        batch_size, height, width, _ = image.shape
-        result = torch.zeros_like(image)
-
-        for b in range(batch_size):
-            tensor_image = image[b]
-            img = (tensor_image * 255)
-            height, width, _ = img.shape
-
-            scale = 255 / (2**bits - 1)
-
-            for y in range(height):
-                for x in range(width):
-                    old_pixel = img[y, x].clone()
-                    new_pixel = torch.round(old_pixel / scale) * scale
-                    img[y, x] = new_pixel
-
-                    quant_error = old_pixel - new_pixel
-
-                    if x + 1 < width:
-                        img[y, x + 1] += quant_error * 7 / 16
-                    if y + 1 < height:
-                        if x - 1 >= 0:
-                            img[y + 1, x - 1] += quant_error * 3 / 16
-                        img[y + 1, x] += quant_error * 5 / 16
-                        if x + 1 < width:
-                            img[y + 1, x + 1] += quant_error * 1 / 16
-
-            dithered = img / 255
-            tensor = dithered.unsqueeze(0)
-            result[b] = tensor
-
-        return (result,)
-
 class DodgeAndBurn:
     def __init__(self):
         pass
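
The removed loop implemented classic Floyd-Steinberg error diffusion (weights 7/16, 3/16, 5/16, 1/16); visiting every pixel from Python is what made it slow, and Pillow's quantize(dither=...) performs the same walk in C. For reference, the algorithm reduced to a single channel in NumPy (a sketch, not part of the commit):

import numpy as np

def floyd_steinberg(gray: np.ndarray, bits: int = 1) -> np.ndarray:
    img = gray.astype(np.float32).copy()
    h, w = img.shape
    scale = 255 / (2 ** bits - 1)
    for y in range(h):
        for x in range(w):
            old = img[y, x]
            new = np.round(old / scale) * scale
            img[y, x] = new
            err = old - new
            # Diffuse the quantization error onto unvisited neighbours
            if x + 1 < w:
                img[y, x + 1] += err * 7 / 16
            if y + 1 < h:
                if x > 0:
                    img[y + 1, x - 1] += err * 3 / 16
                img[y + 1, x] += err * 5 / 16
                if x + 1 < w:
                    img[y + 1, x + 1] += err * 1 / 16
    return np.clip(img, 0, 255).astype(np.uint8)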
@@ -645,62 +614,6 @@ def gaussian_blur(self, image: torch.Tensor, kernel_size: int):
     def add_glow(self, img, blurred_img, intensity):
         return img + blurred_img * intensity

-class KMeansQuantize:
-    def __init__(self):
-        pass
-
-    @classmethod
-    def INPUT_TYPES(s):
-        return {
-            "required": {
-                "image": ("IMAGE",),
-                "colors": ("INT", {
-                    "default": 16,
-                    "min": 1,
-                    "max": 256,
-                    "step": 1
-                }),
-                "precision": ("INT", {
-                    "default": 10,
-                    "min": 1,
-                    "max": 100,
-                    "step": 1
-                }),
-            },
-        }
-
-    RETURN_TYPES = ("IMAGE",)
-    FUNCTION = "kmeans_quantize"
-
-    CATEGORY = "postprocessing"
-
-    def kmeans_quantize(self, image: torch.Tensor, colors: int, precision: int):
-        batch_size, height, width, _ = image.shape
-        result = torch.zeros_like(image)
-
-        for b in range(batch_size):
-            tensor_image = image[b].numpy().astype(np.float32)
-            img = tensor_image
-
-            height, width, c = img.shape
-
-            criteria = (
-                cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER,
-                precision * 5, 0.01
-            )
-
-            img_copy = img.reshape(-1, c)
-            _, label, center = cv2.kmeans(
-                img_copy, colors, None,
-                criteria, 1, cv2.KMEANS_PP_CENTERS
-            )
-
-            img = center[label.flatten()].reshape(*img.shape)
-            tensor = torch.from_numpy(img).unsqueeze(0)
-            result[b] = tensor
-
-        return (result,)
-
 class PixelSort:
     def __init__(self):
         pass
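
Pillow's quantize() also accepts a kmeans parameter that refines the generated palette, which loosely covers the removed node without the OpenCV dependency. The parameter exists in the signature but is only lightly documented, so treat this as an approximation:

from PIL import Image

img = Image.open("input.png").convert("RGB")  # hypothetical input file
quantized = img.quantize(colors=16, kmeans=16).convert("RGB")  # k-means refinement of the palette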
@@ -785,6 +698,49 @@ def pixelize_image(self, image: torch.Tensor, pixel_size: int):

         return image

+class Quantize:
+    def __init__(self):
+        pass
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "image": ("IMAGE",),
+                "colors": ("INT", {
+                    "default": 256,
+                    "min": 1,
+                    "max": 256,
+                    "step": 1
+                }),
+                "dither": (["none", "floyd-steinberg"],),
+            },
+        }
+
+    RETURN_TYPES = ("IMAGE",)
+    FUNCTION = "quantize"
+
+    CATEGORY = "postprocessing"
+
+    def quantize(self, image: torch.Tensor, colors: int = 256, dither: str = "floyd-steinberg"):
+        batch_size, height, width, _ = image.shape
+        result = torch.zeros_like(image)
+
+        dither_option = Image.Dither.FLOYDSTEINBERG if dither == "floyd-steinberg" else Image.Dither.NONE
+
+        for b in range(batch_size):
+            tensor_image = image[b]
+            img = (tensor_image * 255).to(torch.uint8).numpy()
+            pil_image = Image.fromarray(img, mode='RGB')
+
+            palette = pil_image.quantize(colors=colors)  # Two-step quantize so dither= takes effect, see https://github.com/python-pillow/Pillow/issues/5836
+            quantized_image = pil_image.quantize(colors=colors, palette=palette, dither=dither_option)
+
+            quantized_array = torch.tensor(np.array(quantized_image.convert("RGB"))).float() / 255
+            result[b] = quantized_array
+
+        return (result,)
+
 class Sharpen:
     def __init__(self):
         pass
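
A quick smoke test of the new node outside ComfyUI, calling it directly on the (B, H, W, C) float tensors the nodes expect (a sketch; the import is assumed):

import torch

from post_processing_nodes import Quantize  # assumed import

node = Quantize()
batch = torch.rand(1, 64, 64, 3)  # a 1-image batch of random RGB noise in [0, 1]
out, = node.quantize(batch, colors=8, dither="floyd-steinberg")
print(out.shape)                                    # torch.Size([1, 64, 64, 3])
print(out[0].reshape(-1, 3).unique(dim=0).size(0))  # at most 8 distinct RGB colors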
@@ -961,13 +917,12 @@ def pixel_sort(img, mask, horizontal_sort=False, span_limit=None, sort_by='H', r
     "CannyEdgeDetection": CannyEdgeDetection,
     "ColorCorrect": ColorCorrect,
     "Dissolve": Dissolve,
-    "Dither": Dither,
     "DodgeAndBurn": DodgeAndBurn,
     "FilmGrain": FilmGrain,
     "Glow": Glow,
-    "KMeansQuantize": KMeansQuantize,
     "PixelSort": PixelSort,
     "Pixelize": Pixelize,
+    "Quantize": Quantize,
     "Sharpen": Sharpen,
     "Solarize": Solarize,
 }
