11import torch
2+ import torch .nn .functional as F
23
34class Blend :
45 def __init__ (self ):
@@ -26,6 +27,9 @@ def INPUT_TYPES(s):
2627 CATEGORY = "postprocessing"
2728
2829 def blend_images (self , image1 : torch .Tensor , image2 : torch .Tensor , blend_factor : float , blend_mode : str ):
30+ if image1 .shape != image2 .shape :
31+ image2 = self .crop_and_resize (image2 , image1 .shape )
32+
2933 blended_image = self .blend_mode (image1 , image2 , blend_mode )
3034 blended_image = image1 * (1 - blend_factor ) + blended_image * blend_factor
3135 blended_image = torch .clamp (blended_image , 0 , 1 )
@@ -48,6 +52,29 @@ def blend_mode(self, img1, img2, mode):
4852 def g (self , x ):
4953 return torch .where (x <= 0.25 , ((16 * x - 12 ) * x + 4 ) * x , torch .sqrt (x ))
5054
55+ def crop_and_resize (self , img : torch .Tensor , target_shape : tuple ):
56+ batch_size , img_h , img_w , img_c = img .shape
57+ _ , target_h , target_w , _ = target_shape
58+ img_aspect_ratio = img_w / img_h
59+ target_aspect_ratio = target_w / target_h
60+
61+ # Crop center of the image to the target aspect ratio
62+ if img_aspect_ratio > target_aspect_ratio :
63+ new_width = int (img_h * target_aspect_ratio )
64+ left = (img_w - new_width ) // 2
65+ img = img [:, :, left :left + new_width , :]
66+ else :
67+ new_height = int (img_w / target_aspect_ratio )
68+ top = (img_h - new_height ) // 2
69+ img = img [:, top :top + new_height , :, :]
70+
71+ # Resize to target size
72+ img = img .permute (0 , 3 , 1 , 2 ) # Torch wants (B, C, H, W) we use (B, H, W, C)
73+ img = F .interpolate (img , size = (target_h , target_w ), mode = 'bilinear' , align_corners = False )
74+ img = img .permute (0 , 2 , 3 , 1 )
75+
76+ return img
77+
5178NODE_CLASS_MAPPINGS = {
5279 "Blend" : Blend ,
5380}
0 commit comments