1
+ from typing_extensions import override
1
2
import numpy as np
2
3
import torch
3
4
import torch .nn .functional as F
7
8
import comfy .utils
8
9
import comfy .model_management
9
10
import node_helpers
11
+ from comfy_api .latest import ComfyExtension , io
10
12
11
- class Blend :
12
- def __init__ (self ):
13
- pass
13
+ class Blend (io .ComfyNode ):
14
+ @classmethod
15
+ def define_schema (cls ):
16
+ return io .Schema (
17
+ node_id = "ImageBlend" ,
18
+ category = "image/postprocessing" ,
19
+ inputs = [
20
+ io .Image .Input ("image1" ),
21
+ io .Image .Input ("image2" ),
22
+ io .Float .Input ("blend_factor" , default = 0.5 , min = 0.0 , max = 1.0 , step = 0.01 ),
23
+ io .Combo .Input ("blend_mode" , options = ["normal" , "multiply" , "screen" , "overlay" , "soft_light" , "difference" ]),
24
+ ],
25
+ outputs = [
26
+ io .Image .Output (),
27
+ ],
28
+ )
14
29
15
30
@classmethod
16
- def INPUT_TYPES (s ):
17
- return {
18
- "required" : {
19
- "image1" : ("IMAGE" ,),
20
- "image2" : ("IMAGE" ,),
21
- "blend_factor" : ("FLOAT" , {
22
- "default" : 0.5 ,
23
- "min" : 0.0 ,
24
- "max" : 1.0 ,
25
- "step" : 0.01
26
- }),
27
- "blend_mode" : (["normal" , "multiply" , "screen" , "overlay" , "soft_light" , "difference" ],),
28
- },
29
- }
30
-
31
- RETURN_TYPES = ("IMAGE" ,)
32
- FUNCTION = "blend_images"
33
-
34
- CATEGORY = "image/postprocessing"
35
-
36
- def blend_images (self , image1 : torch .Tensor , image2 : torch .Tensor , blend_factor : float , blend_mode : str ):
31
+ def execute (cls , image1 : torch .Tensor , image2 : torch .Tensor , blend_factor : float , blend_mode : str ) -> io .NodeOutput :
37
32
image1 , image2 = node_helpers .image_alpha_fix (image1 , image2 )
38
33
image2 = image2 .to (image1 .device )
39
34
if image1 .shape != image2 .shape :
40
35
image2 = image2 .permute (0 , 3 , 1 , 2 )
41
36
image2 = comfy .utils .common_upscale (image2 , image1 .shape [2 ], image1 .shape [1 ], upscale_method = 'bicubic' , crop = 'center' )
42
37
image2 = image2 .permute (0 , 2 , 3 , 1 )
43
38
44
- blended_image = self .blend_mode (image1 , image2 , blend_mode )
39
+ blended_image = cls .blend_mode (image1 , image2 , blend_mode )
45
40
blended_image = image1 * (1 - blend_factor ) + blended_image * blend_factor
46
41
blended_image = torch .clamp (blended_image , 0 , 1 )
47
- return (blended_image , )
42
+ return io . NodeOutput (blended_image )
48
43
49
- def blend_mode (self , img1 , img2 , mode ):
44
+ @classmethod
45
+ def blend_mode (cls , img1 , img2 , mode ):
50
46
if mode == "normal" :
51
47
return img2
52
48
elif mode == "multiply" :
@@ -56,13 +52,13 @@ def blend_mode(self, img1, img2, mode):
56
52
elif mode == "overlay" :
57
53
return torch .where (img1 <= 0.5 , 2 * img1 * img2 , 1 - 2 * (1 - img1 ) * (1 - img2 ))
58
54
elif mode == "soft_light" :
59
- return torch .where (img2 <= 0.5 , img1 - (1 - 2 * img2 ) * img1 * (1 - img1 ), img1 + (2 * img2 - 1 ) * (self .g (img1 ) - img1 ))
55
+ return torch .where (img2 <= 0.5 , img1 - (1 - 2 * img2 ) * img1 * (1 - img1 ), img1 + (2 * img2 - 1 ) * (cls .g (img1 ) - img1 ))
60
56
elif mode == "difference" :
61
57
return img1 - img2
62
- else :
63
- raise ValueError (f"Unsupported blend mode: { mode } " )
58
+ raise ValueError (f"Unsupported blend mode: { mode } " )
64
59
65
- def g (self , x ):
60
+ @classmethod
61
+ def g (cls , x ):
66
62
return torch .where (x <= 0.25 , ((16 * x - 12 ) * x + 4 ) * x , torch .sqrt (x ))
67
63
68
64
def gaussian_kernel (kernel_size : int , sigma : float , device = None ):
@@ -71,38 +67,26 @@ def gaussian_kernel(kernel_size: int, sigma: float, device=None):
71
67
g = torch .exp (- (d * d ) / (2.0 * sigma * sigma ))
72
68
return g / g .sum ()
73
69
74
- class Blur :
75
- def __init__ (self ):
76
- pass
70
+ class Blur (io .ComfyNode ):
71
+ @classmethod
72
+ def define_schema (cls ):
73
+ return io .Schema (
74
+ node_id = "ImageBlur" ,
75
+ category = "image/postprocessing" ,
76
+ inputs = [
77
+ io .Image .Input ("image" ),
78
+ io .Int .Input ("blur_radius" , default = 1 , min = 1 , max = 31 , step = 1 ),
79
+ io .Float .Input ("sigma" , default = 1.0 , min = 0.1 , max = 10.0 , step = 0.1 ),
80
+ ],
81
+ outputs = [
82
+ io .Image .Output (),
83
+ ],
84
+ )
77
85
78
86
@classmethod
79
- def INPUT_TYPES (s ):
80
- return {
81
- "required" : {
82
- "image" : ("IMAGE" ,),
83
- "blur_radius" : ("INT" , {
84
- "default" : 1 ,
85
- "min" : 1 ,
86
- "max" : 31 ,
87
- "step" : 1
88
- }),
89
- "sigma" : ("FLOAT" , {
90
- "default" : 1.0 ,
91
- "min" : 0.1 ,
92
- "max" : 10.0 ,
93
- "step" : 0.1
94
- }),
95
- },
96
- }
97
-
98
- RETURN_TYPES = ("IMAGE" ,)
99
- FUNCTION = "blur"
100
-
101
- CATEGORY = "image/postprocessing"
102
-
103
- def blur (self , image : torch .Tensor , blur_radius : int , sigma : float ):
87
+ def execute (cls , image : torch .Tensor , blur_radius : int , sigma : float ) -> io .NodeOutput :
104
88
if blur_radius == 0 :
105
- return (image , )
89
+ return io . NodeOutput (image )
106
90
107
91
image = image .to (comfy .model_management .get_torch_device ())
108
92
batch_size , height , width , channels = image .shape
@@ -115,31 +99,24 @@ def blur(self, image: torch.Tensor, blur_radius: int, sigma: float):
115
99
blurred = F .conv2d (padded_image , kernel , padding = kernel_size // 2 , groups = channels )[:,:,blur_radius :- blur_radius , blur_radius :- blur_radius ]
116
100
blurred = blurred .permute (0 , 2 , 3 , 1 )
117
101
118
- return (blurred .to (comfy .model_management .intermediate_device ()), )
102
+ return io . NodeOutput (blurred .to (comfy .model_management .intermediate_device ()))
119
103
120
- class Quantize :
121
- def __init__ (self ):
122
- pass
123
104
105
+ class Quantize (io .ComfyNode ):
124
106
@classmethod
125
- def INPUT_TYPES (s ):
126
- return {
127
- "required" : {
128
- "image" : ("IMAGE" ,),
129
- "colors" : ("INT" , {
130
- "default" : 256 ,
131
- "min" : 1 ,
132
- "max" : 256 ,
133
- "step" : 1
134
- }),
135
- "dither" : (["none" , "floyd-steinberg" , "bayer-2" , "bayer-4" , "bayer-8" , "bayer-16" ],),
136
- },
137
- }
138
-
139
- RETURN_TYPES = ("IMAGE" ,)
140
- FUNCTION = "quantize"
141
-
142
- CATEGORY = "image/postprocessing"
107
+ def define_schema (cls ):
108
+ return io .Schema (
109
+ node_id = "ImageQuantize" ,
110
+ category = "image/postprocessing" ,
111
+ inputs = [
112
+ io .Image .Input ("image" ),
113
+ io .Int .Input ("colors" , default = 256 , min = 1 , max = 256 , step = 1 ),
114
+ io .Combo .Input ("dither" , options = ["none" , "floyd-steinberg" , "bayer-2" , "bayer-4" , "bayer-8" , "bayer-16" ]),
115
+ ],
116
+ outputs = [
117
+ io .Image .Output (),
118
+ ],
119
+ )
143
120
144
121
@staticmethod
145
122
def bayer (im , pal_im , order ):
@@ -167,7 +144,8 @@ def normalized_bayer_matrix(n):
167
144
im = im .quantize (palette = pal_im , dither = Image .Dither .NONE )
168
145
return im
169
146
170
- def quantize (self , image : torch .Tensor , colors : int , dither : str ):
147
+ @classmethod
148
+ def execute (cls , image : torch .Tensor , colors : int , dither : str ) -> io .NodeOutput :
171
149
batch_size , height , width , _ = image .shape
172
150
result = torch .zeros_like (image )
173
151
@@ -187,46 +165,29 @@ def quantize(self, image: torch.Tensor, colors: int, dither: str):
187
165
quantized_array = torch .tensor (np .array (quantized_image .convert ("RGB" ))).float () / 255
188
166
result [b ] = quantized_array
189
167
190
- return (result , )
168
+ return io . NodeOutput (result )
191
169
192
- class Sharpen :
193
- def __init__ (self ):
194
- pass
170
+ class Sharpen (io .ComfyNode ):
171
+ @classmethod
172
+ def define_schema (cls ):
173
+ return io .Schema (
174
+ node_id = "ImageSharpen" ,
175
+ category = "image/postprocessing" ,
176
+ inputs = [
177
+ io .Image .Input ("image" ),
178
+ io .Int .Input ("sharpen_radius" , default = 1 , min = 1 , max = 31 , step = 1 ),
179
+ io .Float .Input ("sigma" , default = 1.0 , min = 0.1 , max = 10.0 , step = 0.01 ),
180
+ io .Float .Input ("alpha" , default = 1.0 , min = 0.0 , max = 5.0 , step = 0.01 ),
181
+ ],
182
+ outputs = [
183
+ io .Image .Output (),
184
+ ],
185
+ )
195
186
196
187
@classmethod
197
- def INPUT_TYPES (s ):
198
- return {
199
- "required" : {
200
- "image" : ("IMAGE" ,),
201
- "sharpen_radius" : ("INT" , {
202
- "default" : 1 ,
203
- "min" : 1 ,
204
- "max" : 31 ,
205
- "step" : 1
206
- }),
207
- "sigma" : ("FLOAT" , {
208
- "default" : 1.0 ,
209
- "min" : 0.1 ,
210
- "max" : 10.0 ,
211
- "step" : 0.01
212
- }),
213
- "alpha" : ("FLOAT" , {
214
- "default" : 1.0 ,
215
- "min" : 0.0 ,
216
- "max" : 5.0 ,
217
- "step" : 0.01
218
- }),
219
- },
220
- }
221
-
222
- RETURN_TYPES = ("IMAGE" ,)
223
- FUNCTION = "sharpen"
224
-
225
- CATEGORY = "image/postprocessing"
226
-
227
- def sharpen (self , image : torch .Tensor , sharpen_radius : int , sigma :float , alpha : float ):
188
+ def execute (cls , image : torch .Tensor , sharpen_radius : int , sigma :float , alpha : float ) -> io .NodeOutput :
228
189
if sharpen_radius == 0 :
229
- return (image , )
190
+ return io . NodeOutput (image )
230
191
231
192
batch_size , height , width , channels = image .shape
232
193
image = image .to (comfy .model_management .get_torch_device ())
@@ -245,23 +206,29 @@ def sharpen(self, image: torch.Tensor, sharpen_radius: int, sigma:float, alpha:
245
206
246
207
result = torch .clamp (sharpened , 0 , 1 )
247
208
248
- return (result .to (comfy .model_management .intermediate_device ()), )
209
+ return io . NodeOutput (result .to (comfy .model_management .intermediate_device ()))
249
210
250
- class ImageScaleToTotalPixels :
211
+ class ImageScaleToTotalPixels ( io . ComfyNode ) :
251
212
upscale_methods = ["nearest-exact" , "bilinear" , "area" , "bicubic" , "lanczos" ]
252
213
crop_methods = ["disabled" , "center" ]
253
214
254
215
@classmethod
255
- def INPUT_TYPES (s ):
256
- return {"required" : { "image" : ("IMAGE" ,), "upscale_method" : (s .upscale_methods ,),
257
- "megapixels" : ("FLOAT" , {"default" : 1.0 , "min" : 0.01 , "max" : 16.0 , "step" : 0.01 }),
258
- }}
259
- RETURN_TYPES = ("IMAGE" ,)
260
- FUNCTION = "upscale"
216
+ def define_schema (cls ):
217
+ return io .Schema (
218
+ node_id = "ImageScaleToTotalPixels" ,
219
+ category = "image/upscaling" ,
220
+ inputs = [
221
+ io .Image .Input ("image" ),
222
+ io .Combo .Input ("upscale_method" , options = cls .upscale_methods ),
223
+ io .Float .Input ("megapixels" , default = 1.0 , min = 0.01 , max = 16.0 , step = 0.01 ),
224
+ ],
225
+ outputs = [
226
+ io .Image .Output (),
227
+ ],
228
+ )
261
229
262
- CATEGORY = "image/upscaling"
263
-
264
- def upscale (self , image , upscale_method , megapixels ):
230
+ @classmethod
231
+ def execute (cls , image , upscale_method , megapixels ) -> io .NodeOutput :
265
232
samples = image .movedim (- 1 ,1 )
266
233
total = int (megapixels * 1024 * 1024 )
267
234
@@ -271,12 +238,18 @@ def upscale(self, image, upscale_method, megapixels):
271
238
272
239
s = comfy .utils .common_upscale (samples , width , height , upscale_method , "disabled" )
273
240
s = s .movedim (1 ,- 1 )
274
- return (s ,)
275
-
276
- NODE_CLASS_MAPPINGS = {
277
- "ImageBlend" : Blend ,
278
- "ImageBlur" : Blur ,
279
- "ImageQuantize" : Quantize ,
280
- "ImageSharpen" : Sharpen ,
281
- "ImageScaleToTotalPixels" : ImageScaleToTotalPixels ,
282
- }
241
+ return io .NodeOutput (s )
242
+
243
+ class PostProcessingExtension (ComfyExtension ):
244
+ @override
245
+ async def get_node_list (self ) -> list [type [io .ComfyNode ]]:
246
+ return [
247
+ Blend ,
248
+ Blur ,
249
+ Quantize ,
250
+ Sharpen ,
251
+ ImageScaleToTotalPixels ,
252
+ ]
253
+
254
+ async def comfy_entrypoint () -> PostProcessingExtension :
255
+ return PostProcessingExtension ()
0 commit comments