@@ -72,6 +72,9 @@ def INPUT_TYPES(s):
7272 CATEGORY = "postprocessing"
7373
7474 def blend_images (self , image1 : torch .Tensor , image2 : torch .Tensor , blend_factor : float , blend_mode : str ):
75+ if image1 .shape != image2 .shape :
76+ image2 = self .crop_and_resize (image2 , image1 .shape )
77+
7578 blended_image = self .blend_mode (image1 , image2 , blend_mode )
7679 blended_image = image1 * (1 - blend_factor ) + blended_image * blend_factor
7780 blended_image = torch .clamp (blended_image , 0 , 1 )
@@ -94,6 +97,29 @@ def blend_mode(self, img1, img2, mode):
9497 def g (self , x ):
9598 return torch .where (x <= 0.25 , ((16 * x - 12 ) * x + 4 ) * x , torch .sqrt (x ))
9699
100+ def crop_and_resize (self , img : torch .Tensor , target_shape : tuple ):
101+ batch_size , img_h , img_w , img_c = img .shape
102+ _ , target_h , target_w , _ = target_shape
103+ img_aspect_ratio = img_w / img_h
104+ target_aspect_ratio = target_w / target_h
105+
106+ # Crop center of the image to the target aspect ratio
107+ if img_aspect_ratio > target_aspect_ratio :
108+ new_width = int (img_h * target_aspect_ratio )
109+ left = (img_w - new_width ) // 2
110+ img = img [:, :, left :left + new_width , :]
111+ else :
112+ new_height = int (img_w / target_aspect_ratio )
113+ top = (img_h - new_height ) // 2
114+ img = img [:, top :top + new_height , :, :]
115+
116+ # Resize to target size
117+ img = img .permute (0 , 3 , 1 , 2 ) # Torch wants (B, C, H, W) we use (B, H, W, C)
118+ img = F .interpolate (img , size = (target_h , target_w ), mode = 'bilinear' , align_corners = False )
119+ img = img .permute (0 , 2 , 3 , 1 )
120+
121+ return img
122+
97123class Blur :
98124 def __init__ (self ):
99125 pass
@@ -124,7 +150,7 @@ def INPUT_TYPES(s):
124150 CATEGORY = "postprocessing"
125151
126152 def gaussian_kernel (self , kernel_size : int , sigma : float ):
127- x , y = torch .meshgrid (torch .linspace (- 1 , 1 , kernel_size ), torch .linspace (- 1 , 1 , kernel_size ))
153+ x , y = torch .meshgrid (torch .linspace (- 1 , 1 , kernel_size ), torch .linspace (- 1 , 1 , kernel_size ), indexing = "ij" )
128154 d = torch .sqrt (x * x + y * y )
129155 g = torch .exp (- (d * d ) / (2.0 * sigma * sigma ))
130156 return g / g .sum ()
@@ -324,63 +350,6 @@ def dissolve_images(self, image1: torch.Tensor, image2: torch.Tensor, dissolve_f
324350 dissolved_image = torch .clamp (dissolved_image , 0 , 1 )
325351 return (dissolved_image ,)
326352
327- class Dither :
328- def __init__ (self ):
329- pass
330-
331- @classmethod
332- def INPUT_TYPES (s ):
333- return {
334- "required" : {
335- "image" : ("IMAGE" ,),
336- "bits" : ("INT" , {
337- "default" : 4 ,
338- "min" : 1 ,
339- "max" : 8 ,
340- "step" : 1
341- }),
342- },
343- }
344-
345- RETURN_TYPES = ("IMAGE" ,)
346- FUNCTION = "dither"
347-
348- CATEGORY = "postprocessing"
349-
350- def dither (self , image : torch .Tensor , bits : int ):
351- batch_size , height , width , _ = image .shape
352- result = torch .zeros_like (image )
353-
354- for b in range (batch_size ):
355- tensor_image = image [b ]
356- img = (tensor_image * 255 )
357- height , width , _ = img .shape
358-
359- scale = 255 / (2 ** bits - 1 )
360-
361- for y in range (height ):
362- for x in range (width ):
363- old_pixel = img [y , x ].clone ()
364- new_pixel = torch .round (old_pixel / scale ) * scale
365- img [y , x ] = new_pixel
366-
367- quant_error = old_pixel - new_pixel
368-
369- if x + 1 < width :
370- img [y , x + 1 ] += quant_error * 7 / 16
371- if y + 1 < height :
372- if x - 1 >= 0 :
373- img [y + 1 , x - 1 ] += quant_error * 3 / 16
374- img [y + 1 , x ] += quant_error * 5 / 16
375- if x + 1 < width :
376- img [y + 1 , x + 1 ] += quant_error * 1 / 16
377-
378- dithered = img / 255
379- tensor = dithered .unsqueeze (0 )
380- result [b ] = tensor
381-
382- return (result ,)
383-
384353class DodgeAndBurn :
385354 def __init__ (self ):
386355 pass
@@ -645,62 +614,6 @@ def gaussian_blur(self, image: torch.Tensor, kernel_size: int):
645614 def add_glow (self , img , blurred_img , intensity ):
646615 return img + blurred_img * intensity
647616
648- class KMeansQuantize :
649- def __init__ (self ):
650- pass
651-
652- @classmethod
653- def INPUT_TYPES (s ):
654- return {
655- "required" : {
656- "image" : ("IMAGE" ,),
657- "colors" : ("INT" , {
658- "default" : 16 ,
659- "min" : 1 ,
660- "max" : 256 ,
661- "step" : 1
662- }),
663- "precision" : ("INT" , {
664- "default" : 10 ,
665- "min" : 1 ,
666- "max" : 100 ,
667- "step" : 1
668- }),
669- },
670- }
671-
672- RETURN_TYPES = ("IMAGE" ,)
673- FUNCTION = "kmeans_quantize"
674-
675- CATEGORY = "postprocessing"
676-
677- def kmeans_quantize (self , image : torch .Tensor , colors : int , precision : int ):
678- batch_size , height , width , _ = image .shape
679- result = torch .zeros_like (image )
680-
681- for b in range (batch_size ):
682- tensor_image = image [b ].numpy ().astype (np .float32 )
683- img = tensor_image
684-
685- height , width , c = img .shape
686-
687- criteria = (
688- cv2 .TERM_CRITERIA_EPS + cv2 .TERM_CRITERIA_MAX_ITER ,
689- precision * 5 , 0.01
690- )
691-
692- img_copy = img .reshape (- 1 , c )
693- _ , label , center = cv2 .kmeans (
694- img_copy , colors , None ,
695- criteria , 1 , cv2 .KMEANS_PP_CENTERS
696- )
697-
698- img = center [label .flatten ()].reshape (* img .shape )
699- tensor = torch .from_numpy (img ).unsqueeze (0 )
700- result [b ] = tensor
701-
702- return (result ,)
703-
704617class PixelSort :
705618 def __init__ (self ):
706619 pass
@@ -785,6 +698,49 @@ def pixelize_image(self, image: torch.Tensor, pixel_size: int):
785698
786699 return image
787700
701+ class Quantize :
702+ def __init__ (self ):
703+ pass
704+
705+ @classmethod
706+ def INPUT_TYPES (s ):
707+ return {
708+ "required" : {
709+ "image" : ("IMAGE" ,),
710+ "colors" : ("INT" , {
711+ "default" : 256 ,
712+ "min" : 1 ,
713+ "max" : 256 ,
714+ "step" : 1
715+ }),
716+ "dither" : (["none" , "floyd-steinberg" ],),
717+ },
718+ }
719+
720+ RETURN_TYPES = ("IMAGE" ,)
721+ FUNCTION = "quantize"
722+
723+ CATEGORY = "postprocessing"
724+
725+ def quantize (self , image : torch .Tensor , colors : int = 256 , dither : str = "FLOYDSTEINBERG" ):
726+ batch_size , height , width , _ = image .shape
727+ result = torch .zeros_like (image )
728+
729+ dither_option = Image .Dither .FLOYDSTEINBERG if dither == "floyd-steinberg" else Image .Dither .NONE
730+
731+ for b in range (batch_size ):
732+ tensor_image = image [b ]
733+ img = (tensor_image * 255 ).to (torch .uint8 ).numpy ()
734+ pil_image = Image .fromarray (img , mode = 'RGB' )
735+
736+ palette = pil_image .quantize (colors = colors ) # Required as described in https://github.com/python-pillow/Pillow/issues/5836
737+ quantized_image = pil_image .quantize (colors = colors , palette = palette , dither = dither_option )
738+
739+ quantized_array = torch .tensor (np .array (quantized_image .convert ("RGB" ))).float () / 255
740+ result [b ] = quantized_array
741+
742+ return (result ,)
743+
788744class Sharpen :
789745 def __init__ (self ):
790746 pass
@@ -961,13 +917,12 @@ def pixel_sort(img, mask, horizontal_sort=False, span_limit=None, sort_by='H', r
961917 "CannyEdgeDetection" : CannyEdgeDetection ,
962918 "ColorCorrect" : ColorCorrect ,
963919 "Dissolve" : Dissolve ,
964- "Dither" : Dither ,
965920 "DodgeAndBurn" : DodgeAndBurn ,
966921 "FilmGrain" : FilmGrain ,
967922 "Glow" : Glow ,
968- "KMeansQuantize" : KMeansQuantize ,
969923 "PixelSort" : PixelSort ,
970924 "Pixelize" : Pixelize ,
925+ "Quantize" : Quantize ,
971926 "Sharpen" : Sharpen ,
972927 "Solarize" : Solarize ,
973928}
0 commit comments