Skip to content

Commit d0e69db

Browse files
committed
adds solarize
1 parent 8da09a5 commit d0e69db

File tree

3 files changed

+112
-0
lines changed

3 files changed

+112
-0
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ A collection of post processing nodes for [ComfyUI](https://github.com/comfyanon
1818
- PixelSort: Rearranges the pixels in the input image based on their values, and input mask. Creates a cool glitch like effect.
1919
- Pixelize: Applies a pixelization effect, simulating the reducing of resolution
2020
- Sharpen: Enhances the details in an image by applying a sharpening filter
21+
- Solarize: Inverts image colors based on a threshold for a striking, high-contrast effect
2122

2223
## Example workflow
2324

post_processing/solarize.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import torch
2+
3+
class Solarize:
4+
def __init__(self):
5+
pass
6+
7+
@classmethod
8+
def INPUT_TYPES(s):
9+
return {
10+
"required": {
11+
"image": ("IMAGE",),
12+
"threshold": ("FLOAT", {
13+
"default": 0.5,
14+
"min": 0.0,
15+
"max": 1.0,
16+
"step": 0.01
17+
}),
18+
},
19+
}
20+
21+
RETURN_TYPES = ("IMAGE",)
22+
FUNCTION = "solarize_image"
23+
24+
CATEGORY = "postprocessing"
25+
26+
def solarize_image(self, image: torch.Tensor, threshold: float):
27+
solarized_image = torch.where(image > threshold, 1 - image, image)
28+
solarized_image = torch.clamp(solarized_image, 0, 1)
29+
return (solarized_image,)
30+
31+
NODE_CLASS_MAPPINGS = {
32+
"Solarize": Solarize,
33+
}

post_processing_nodes.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,54 @@ def subtract(self, img1, img2):
4646
def difference(self, img1, img2):
4747
return torch.abs(img1 - img2)
4848

49+
class Blend:
50+
def __init__(self):
51+
pass
52+
53+
@classmethod
54+
def INPUT_TYPES(s):
55+
return {
56+
"required": {
57+
"image1": ("IMAGE",),
58+
"image2": ("IMAGE",),
59+
"blend_factor": ("FLOAT", {
60+
"default": 0.5,
61+
"min": 0.0,
62+
"max": 1.0,
63+
"step": 0.01
64+
}),
65+
"blend_mode": (["normal", "multiply", "screen", "overlay", "soft_light"],),
66+
},
67+
}
68+
69+
RETURN_TYPES = ("IMAGE",)
70+
FUNCTION = "blend_images"
71+
72+
CATEGORY = "postprocessing"
73+
74+
def blend_images(self, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str):
75+
blended_image = self.blend_mode(image1, image2, blend_mode)
76+
blended_image = image1 * (1 - blend_factor) + blended_image * blend_factor
77+
blended_image = torch.clamp(blended_image, 0, 1)
78+
return (blended_image,)
79+
80+
def blend_mode(self, img1, img2, mode):
81+
if mode == "normal":
82+
return img2
83+
elif mode == "multiply":
84+
return img1 * img2
85+
elif mode == "screen":
86+
return 1 - (1 - img1) * (1 - img2)
87+
elif mode == "overlay":
88+
return torch.where(img1 <= 0.5, 2 * img1 * img2, 1 - 2 * (1 - img1) * (1 - img2))
89+
elif mode == "soft_light":
90+
return torch.where(img2 <= 0.5, img1 - (1 - 2 * img2) * img1 * (1 - img1), img1 + (2 * img2 - 1) * (self.g(img1) - img1))
91+
else:
92+
raise ValueError(f"Unsupported blend mode: {mode}")
93+
94+
def g(self, x):
95+
return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x))
96+
4997
class Blur:
5098
def __init__(self):
5199
pass
@@ -787,6 +835,34 @@ def sharpen(self, image: torch.Tensor, blur_radius: int, alpha: float):
787835

788836
return (result,)
789837

838+
class Solarize:
839+
def __init__(self):
840+
pass
841+
842+
@classmethod
843+
def INPUT_TYPES(s):
844+
return {
845+
"required": {
846+
"image": ("IMAGE",),
847+
"threshold": ("FLOAT", {
848+
"default": 0.5,
849+
"min": 0.0,
850+
"max": 1.0,
851+
"step": 0.01
852+
}),
853+
},
854+
}
855+
856+
RETURN_TYPES = ("IMAGE",)
857+
FUNCTION = "solarize_image"
858+
859+
CATEGORY = "postprocessing"
860+
861+
def solarize_image(self, image: torch.Tensor, threshold: float):
862+
solarized_image = torch.where(image > threshold, 1 - image, image)
863+
solarized_image = torch.clamp(solarized_image, 0, 1)
864+
return (solarized_image,)
865+
790866
def sort_span(span, sort_by, reverse_sorting):
791867
if sort_by == 'H':
792868
key = lambda x: x[1][0]
@@ -880,6 +956,7 @@ def pixel_sort(img, mask, horizontal_sort=False, span_limit=None, sort_by='H', r
880956

881957
NODE_CLASS_MAPPINGS = {
882958
"ArithmeticBlend": ArithmeticBlend,
959+
"Blend": Blend,
883960
"Blur": Blur,
884961
"CannyEdgeDetection": CannyEdgeDetection,
885962
"ColorCorrect": ColorCorrect,
@@ -892,4 +969,5 @@ def pixel_sort(img, mask, horizontal_sort=False, span_limit=None, sort_by='H', r
892969
"PixelSort": PixelSort,
893970
"Pixelize": Pixelize,
894971
"Sharpen": Sharpen,
972+
"Solarize": Solarize,
895973
}

0 commit comments

Comments
 (0)