Skip to content

Commit 2103e39

Browse files
authored
convert nodes_post_processing to V3 schema (#9491)
1 parent d20576e commit 2103e39

File tree

1 file changed

+112
-139
lines changed

1 file changed

+112
-139
lines changed
Lines changed: 112 additions & 139 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from typing_extensions import override
12
import numpy as np
23
import torch
34
import torch.nn.functional as F
@@ -7,46 +8,41 @@
78
import comfy.utils
89
import comfy.model_management
910
import node_helpers
11+
from comfy_api.latest import ComfyExtension, io
1012

11-
class Blend:
12-
def __init__(self):
13-
pass
13+
class Blend(io.ComfyNode):
14+
@classmethod
15+
def define_schema(cls):
16+
return io.Schema(
17+
node_id="ImageBlend",
18+
category="image/postprocessing",
19+
inputs=[
20+
io.Image.Input("image1"),
21+
io.Image.Input("image2"),
22+
io.Float.Input("blend_factor", default=0.5, min=0.0, max=1.0, step=0.01),
23+
io.Combo.Input("blend_mode", options=["normal", "multiply", "screen", "overlay", "soft_light", "difference"]),
24+
],
25+
outputs=[
26+
io.Image.Output(),
27+
],
28+
)
1429

1530
@classmethod
16-
def INPUT_TYPES(s):
17-
return {
18-
"required": {
19-
"image1": ("IMAGE",),
20-
"image2": ("IMAGE",),
21-
"blend_factor": ("FLOAT", {
22-
"default": 0.5,
23-
"min": 0.0,
24-
"max": 1.0,
25-
"step": 0.01
26-
}),
27-
"blend_mode": (["normal", "multiply", "screen", "overlay", "soft_light", "difference"],),
28-
},
29-
}
30-
31-
RETURN_TYPES = ("IMAGE",)
32-
FUNCTION = "blend_images"
33-
34-
CATEGORY = "image/postprocessing"
35-
36-
def blend_images(self, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str):
31+
def execute(cls, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str) -> io.NodeOutput:
3732
image1, image2 = node_helpers.image_alpha_fix(image1, image2)
3833
image2 = image2.to(image1.device)
3934
if image1.shape != image2.shape:
4035
image2 = image2.permute(0, 3, 1, 2)
4136
image2 = comfy.utils.common_upscale(image2, image1.shape[2], image1.shape[1], upscale_method='bicubic', crop='center')
4237
image2 = image2.permute(0, 2, 3, 1)
4338

44-
blended_image = self.blend_mode(image1, image2, blend_mode)
39+
blended_image = cls.blend_mode(image1, image2, blend_mode)
4540
blended_image = image1 * (1 - blend_factor) + blended_image * blend_factor
4641
blended_image = torch.clamp(blended_image, 0, 1)
47-
return (blended_image,)
42+
return io.NodeOutput(blended_image)
4843

49-
def blend_mode(self, img1, img2, mode):
44+
@classmethod
45+
def blend_mode(cls, img1, img2, mode):
5046
if mode == "normal":
5147
return img2
5248
elif mode == "multiply":
@@ -56,13 +52,13 @@ def blend_mode(self, img1, img2, mode):
5652
elif mode == "overlay":
5753
return torch.where(img1 <= 0.5, 2 * img1 * img2, 1 - 2 * (1 - img1) * (1 - img2))
5854
elif mode == "soft_light":
59-
return torch.where(img2 <= 0.5, img1 - (1 - 2 * img2) * img1 * (1 - img1), img1 + (2 * img2 - 1) * (self.g(img1) - img1))
55+
return torch.where(img2 <= 0.5, img1 - (1 - 2 * img2) * img1 * (1 - img1), img1 + (2 * img2 - 1) * (cls.g(img1) - img1))
6056
elif mode == "difference":
6157
return img1 - img2
62-
else:
63-
raise ValueError(f"Unsupported blend mode: {mode}")
58+
raise ValueError(f"Unsupported blend mode: {mode}")
6459

65-
def g(self, x):
60+
@classmethod
61+
def g(cls, x):
6662
return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x))
6763

6864
def gaussian_kernel(kernel_size: int, sigma: float, device=None):
@@ -71,38 +67,26 @@ def gaussian_kernel(kernel_size: int, sigma: float, device=None):
7167
g = torch.exp(-(d * d) / (2.0 * sigma * sigma))
7268
return g / g.sum()
7369

74-
class Blur:
75-
def __init__(self):
76-
pass
70+
class Blur(io.ComfyNode):
71+
@classmethod
72+
def define_schema(cls):
73+
return io.Schema(
74+
node_id="ImageBlur",
75+
category="image/postprocessing",
76+
inputs=[
77+
io.Image.Input("image"),
78+
io.Int.Input("blur_radius", default=1, min=1, max=31, step=1),
79+
io.Float.Input("sigma", default=1.0, min=0.1, max=10.0, step=0.1),
80+
],
81+
outputs=[
82+
io.Image.Output(),
83+
],
84+
)
7785

7886
@classmethod
79-
def INPUT_TYPES(s):
80-
return {
81-
"required": {
82-
"image": ("IMAGE",),
83-
"blur_radius": ("INT", {
84-
"default": 1,
85-
"min": 1,
86-
"max": 31,
87-
"step": 1
88-
}),
89-
"sigma": ("FLOAT", {
90-
"default": 1.0,
91-
"min": 0.1,
92-
"max": 10.0,
93-
"step": 0.1
94-
}),
95-
},
96-
}
97-
98-
RETURN_TYPES = ("IMAGE",)
99-
FUNCTION = "blur"
100-
101-
CATEGORY = "image/postprocessing"
102-
103-
def blur(self, image: torch.Tensor, blur_radius: int, sigma: float):
87+
def execute(cls, image: torch.Tensor, blur_radius: int, sigma: float) -> io.NodeOutput:
10488
if blur_radius == 0:
105-
return (image,)
89+
return io.NodeOutput(image)
10690

10791
image = image.to(comfy.model_management.get_torch_device())
10892
batch_size, height, width, channels = image.shape
@@ -115,31 +99,24 @@ def blur(self, image: torch.Tensor, blur_radius: int, sigma: float):
11599
blurred = F.conv2d(padded_image, kernel, padding=kernel_size // 2, groups=channels)[:,:,blur_radius:-blur_radius, blur_radius:-blur_radius]
116100
blurred = blurred.permute(0, 2, 3, 1)
117101

118-
return (blurred.to(comfy.model_management.intermediate_device()),)
102+
return io.NodeOutput(blurred.to(comfy.model_management.intermediate_device()))
119103

120-
class Quantize:
121-
def __init__(self):
122-
pass
123104

105+
class Quantize(io.ComfyNode):
124106
@classmethod
125-
def INPUT_TYPES(s):
126-
return {
127-
"required": {
128-
"image": ("IMAGE",),
129-
"colors": ("INT", {
130-
"default": 256,
131-
"min": 1,
132-
"max": 256,
133-
"step": 1
134-
}),
135-
"dither": (["none", "floyd-steinberg", "bayer-2", "bayer-4", "bayer-8", "bayer-16"],),
136-
},
137-
}
138-
139-
RETURN_TYPES = ("IMAGE",)
140-
FUNCTION = "quantize"
141-
142-
CATEGORY = "image/postprocessing"
107+
def define_schema(cls):
108+
return io.Schema(
109+
node_id="ImageQuantize",
110+
category="image/postprocessing",
111+
inputs=[
112+
io.Image.Input("image"),
113+
io.Int.Input("colors", default=256, min=1, max=256, step=1),
114+
io.Combo.Input("dither", options=["none", "floyd-steinberg", "bayer-2", "bayer-4", "bayer-8", "bayer-16"]),
115+
],
116+
outputs=[
117+
io.Image.Output(),
118+
],
119+
)
143120

144121
@staticmethod
145122
def bayer(im, pal_im, order):
@@ -167,7 +144,8 @@ def normalized_bayer_matrix(n):
167144
im = im.quantize(palette=pal_im, dither=Image.Dither.NONE)
168145
return im
169146

170-
def quantize(self, image: torch.Tensor, colors: int, dither: str):
147+
@classmethod
148+
def execute(cls, image: torch.Tensor, colors: int, dither: str) -> io.NodeOutput:
171149
batch_size, height, width, _ = image.shape
172150
result = torch.zeros_like(image)
173151

@@ -187,46 +165,29 @@ def quantize(self, image: torch.Tensor, colors: int, dither: str):
187165
quantized_array = torch.tensor(np.array(quantized_image.convert("RGB"))).float() / 255
188166
result[b] = quantized_array
189167

190-
return (result,)
168+
return io.NodeOutput(result)
191169

192-
class Sharpen:
193-
def __init__(self):
194-
pass
170+
class Sharpen(io.ComfyNode):
171+
@classmethod
172+
def define_schema(cls):
173+
return io.Schema(
174+
node_id="ImageSharpen",
175+
category="image/postprocessing",
176+
inputs=[
177+
io.Image.Input("image"),
178+
io.Int.Input("sharpen_radius", default=1, min=1, max=31, step=1),
179+
io.Float.Input("sigma", default=1.0, min=0.1, max=10.0, step=0.01),
180+
io.Float.Input("alpha", default=1.0, min=0.0, max=5.0, step=0.01),
181+
],
182+
outputs=[
183+
io.Image.Output(),
184+
],
185+
)
195186

196187
@classmethod
197-
def INPUT_TYPES(s):
198-
return {
199-
"required": {
200-
"image": ("IMAGE",),
201-
"sharpen_radius": ("INT", {
202-
"default": 1,
203-
"min": 1,
204-
"max": 31,
205-
"step": 1
206-
}),
207-
"sigma": ("FLOAT", {
208-
"default": 1.0,
209-
"min": 0.1,
210-
"max": 10.0,
211-
"step": 0.01
212-
}),
213-
"alpha": ("FLOAT", {
214-
"default": 1.0,
215-
"min": 0.0,
216-
"max": 5.0,
217-
"step": 0.01
218-
}),
219-
},
220-
}
221-
222-
RETURN_TYPES = ("IMAGE",)
223-
FUNCTION = "sharpen"
224-
225-
CATEGORY = "image/postprocessing"
226-
227-
def sharpen(self, image: torch.Tensor, sharpen_radius: int, sigma:float, alpha: float):
188+
def execute(cls, image: torch.Tensor, sharpen_radius: int, sigma:float, alpha: float) -> io.NodeOutput:
228189
if sharpen_radius == 0:
229-
return (image,)
190+
return io.NodeOutput(image)
230191

231192
batch_size, height, width, channels = image.shape
232193
image = image.to(comfy.model_management.get_torch_device())
@@ -245,23 +206,29 @@ def sharpen(self, image: torch.Tensor, sharpen_radius: int, sigma:float, alpha:
245206

246207
result = torch.clamp(sharpened, 0, 1)
247208

248-
return (result.to(comfy.model_management.intermediate_device()),)
209+
return io.NodeOutput(result.to(comfy.model_management.intermediate_device()))
249210

250-
class ImageScaleToTotalPixels:
211+
class ImageScaleToTotalPixels(io.ComfyNode):
251212
upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
252213
crop_methods = ["disabled", "center"]
253214

254215
@classmethod
255-
def INPUT_TYPES(s):
256-
return {"required": { "image": ("IMAGE",), "upscale_method": (s.upscale_methods,),
257-
"megapixels": ("FLOAT", {"default": 1.0, "min": 0.01, "max": 16.0, "step": 0.01}),
258-
}}
259-
RETURN_TYPES = ("IMAGE",)
260-
FUNCTION = "upscale"
216+
def define_schema(cls):
217+
return io.Schema(
218+
node_id="ImageScaleToTotalPixels",
219+
category="image/upscaling",
220+
inputs=[
221+
io.Image.Input("image"),
222+
io.Combo.Input("upscale_method", options=cls.upscale_methods),
223+
io.Float.Input("megapixels", default=1.0, min=0.01, max=16.0, step=0.01),
224+
],
225+
outputs=[
226+
io.Image.Output(),
227+
],
228+
)
261229

262-
CATEGORY = "image/upscaling"
263-
264-
def upscale(self, image, upscale_method, megapixels):
230+
@classmethod
231+
def execute(cls, image, upscale_method, megapixels) -> io.NodeOutput:
265232
samples = image.movedim(-1,1)
266233
total = int(megapixels * 1024 * 1024)
267234

@@ -271,12 +238,18 @@ def upscale(self, image, upscale_method, megapixels):
271238

272239
s = comfy.utils.common_upscale(samples, width, height, upscale_method, "disabled")
273240
s = s.movedim(1,-1)
274-
return (s,)
275-
276-
NODE_CLASS_MAPPINGS = {
277-
"ImageBlend": Blend,
278-
"ImageBlur": Blur,
279-
"ImageQuantize": Quantize,
280-
"ImageSharpen": Sharpen,
281-
"ImageScaleToTotalPixels": ImageScaleToTotalPixels,
282-
}
241+
return io.NodeOutput(s)
242+
243+
class PostProcessingExtension(ComfyExtension):
244+
@override
245+
async def get_node_list(self) -> list[type[io.ComfyNode]]:
246+
return [
247+
Blend,
248+
Blur,
249+
Quantize,
250+
Sharpen,
251+
ImageScaleToTotalPixels,
252+
]
253+
254+
async def comfy_entrypoint() -> PostProcessingExtension:
255+
return PostProcessingExtension()

0 commit comments

Comments
 (0)