@@ -1204,32 +1204,37 @@ class AILab_ImageStitch:
12041204 def INPUT_TYPES (s ):
12051205 tooltips = {
12061206 "image1" : "First image to stitch" ,
1207- "direction " : "Direction to stitch the second image " ,
1207+ "stitch_mode " : "Mode for stitching images together " ,
12081208 "match_image_size" : "If True, resize image2 to match image1's aspect ratio" ,
1209- "max_width" : "Maximum width of output image (0 = no limit)" ,
1210- "max_height" : "Maximum height of output image (0 = no limit)" ,
1209+ "megapixels" : "Target megapixels for final output (0 = no limit, overrides max_width/max_height)" ,
1210+ "max_width" : "Maximum width of output image (0 = no limit, ignored if megapixels > 0)" ,
1211+ "max_height" : "Maximum height of output image (0 = no limit, ignored if megapixels > 0)" ,
1212+ "upscale_method" : "Upscaling method for all resize operations" ,
12111213 "spacing_width" : "Width of spacing between images" ,
12121214 "background_color" : "Color for spacing between images and padding background" ,
12131215 "kontext_mode" : "Special mode that arranges 3 images in a specific layout (image1 and image2 stacked vertically, image3 on the right)"
12141216 }
1215-
12161217 return {
12171218 "required" : {
12181219 "image1" : ("IMAGE" ,),
1219- "direction " : (["right" , "down" , "left" , "up" , "kontext_mode" ], {"default" : "right" , "tooltip" : tooltips ["direction " ]}),
1220+ "stitch_mode " : (["right" , "down" , "left" , "up" , "2x2" , " kontext_mode" ], {"default" : "right" , "tooltip" : tooltips ["stitch_mode " ]}),
12201221 "match_image_size" : ("BOOLEAN" , {"default" : True , "tooltip" : tooltips ["match_image_size" ]}),
1222+ "megapixels" : ("FLOAT" , {"default" : 0.0 , "min" : 0.0 , "max" : 16.0 , "step" : 0.01 , "tooltip" : tooltips ["megapixels" ]}),
12211223 "max_width" : ("INT" , {"default" : 0 , "min" : 0 , "max" : MAX_RESOLUTION , "step" : 8 , "tooltip" : tooltips ["max_width" ]}),
12221224 "max_height" : ("INT" , {"default" : 0 , "min" : 0 , "max" : MAX_RESOLUTION , "step" : 8 , "tooltip" : tooltips ["max_height" ]}),
1225+ "upscale_method" : (["nearest-exact" , "bilinear" , "area" , "bicubic" , "lanczos" ], {"default" : "lanczos" , "tooltip" : tooltips ["upscale_method" ]}),
12231226 "spacing_width" : ("INT" , {"default" : 0 , "min" : 0 , "max" : 512 , "step" : 1 , "tooltip" : tooltips ["spacing_width" ]}),
12241227 "background_color" : ("COLOR" , {"default" : "#FFFFFF" , "tooltip" : tooltips ["background_color" ]}),
12251228 },
12261229 "optional" : {
12271230 "image2" : ("IMAGE" ,),
12281231 "image3" : ("IMAGE" ,),
1232+ "image4" : ("IMAGE" ,),
12291233 },
12301234 }
12311235
1232- RETURN_TYPES = ("IMAGE" ,)
1236+ RETURN_TYPES = ("IMAGE" , "INT" , "INT" )
1237+ RETURN_NAMES = ("IMAGE" , "WIDTH" , "HEIGHT" )
12331238 FUNCTION = "stitch"
12341239 CATEGORY = "🧪AILab/🖼️IMAGE"
12351240
@@ -1241,33 +1246,25 @@ def hex_to_rgb(self, hex_color):
12411246 return (r , g , b )
12421247
12431248 def pad_with_color (self , image , padding , color_val ):
1244- """Pad image with specified color"""
12451249 batch , height , width , channels = image .shape
12461250 r , g , b = color_val
1247-
12481251 pad_top , pad_bottom , pad_left , pad_right = padding
1249-
12501252 new_height = height + pad_top + pad_bottom
12511253 new_width = width + pad_left + pad_right
1252-
12531254 result = torch .zeros ((batch , new_height , new_width , channels ), device = image .device )
1254-
12551255 if channels >= 3 :
12561256 result [..., 0 ] = r
12571257 result [..., 1 ] = g
12581258 result [..., 2 ] = b
12591259 if channels == 4 :
12601260 result [..., 3 ] = 1.0
1261-
12621261 result [:, pad_top :pad_top + height , pad_left :pad_left + width , :] = image
1263-
12641262 return result
12651263
1266- def match_dimensions (self , image1 , image2 , direction , color_val ):
1264+ def match_dimensions (self , image1 , image2 , stitch_mode , color_val ):
12671265 h1 , w1 = image1 .shape [1 :3 ]
12681266 h2 , w2 = image2 .shape [1 :3 ]
1269-
1270- if direction in ["left" , "right" ]:
1267+ if stitch_mode in ["left" , "right" ]:
12711268 if h1 != h2 :
12721269 target_h = max (h1 , h2 )
12731270 if h1 < target_h :
@@ -1278,7 +1275,7 @@ def match_dimensions(self, image1, image2, direction, color_val):
12781275 pad_h = target_h - h2
12791276 pad_top , pad_bottom = pad_h // 2 , pad_h - pad_h // 2
12801277 image2 = self .pad_with_color (image2 , (pad_top , pad_bottom , 0 , 0 ), color_val )
1281- else :
1278+ else :
12821279 if w1 != w2 :
12831280 target_w = max (w1 , w2 )
12841281 if w1 < target_w :
@@ -1289,7 +1286,6 @@ def match_dimensions(self, image1, image2, direction, color_val):
12891286 pad_w = target_w - w2
12901287 pad_left , pad_right = pad_w // 2 , pad_w - pad_w // 2
12911288 image2 = self .pad_with_color (image2 , (0 , 0 , pad_left , pad_right ), color_val )
1292-
12931289 return image1 , image2
12941290
12951291 def ensure_same_channels (self , image1 , image2 ):
@@ -1307,13 +1303,11 @@ def ensure_same_channels(self, image1, image2):
13071303 ], dim = - 1 )
13081304 return image1 , image2
13091305
1310- def create_spacing (self , image1 , image2 , spacing_width , direction , color_val ):
1306+ def create_spacing (self , image1 , image2 , spacing_width , stitch_mode , color_val ):
13111307 if spacing_width <= 0 :
13121308 return None
1313-
13141309 spacing_width = spacing_width + (spacing_width % 2 )
1315-
1316- if direction in ["left" , "right" ]:
1310+ if stitch_mode in ["left" , "right" ]:
13171311 spacing_shape = (
13181312 image1 .shape [0 ],
13191313 max (image1 .shape [1 ], image2 .shape [1 ]),
@@ -1327,88 +1321,86 @@ def create_spacing(self, image1, image2, spacing_width, direction, color_val):
13271321 max (image1 .shape [2 ], image2 .shape [2 ]),
13281322 image1 .shape [- 1 ],
13291323 )
1330-
13311324 spacing = torch .zeros (spacing_shape , device = image1 .device )
1332-
13331325 r , g , b = color_val
13341326 if spacing .shape [- 1 ] >= 3 :
13351327 spacing [..., 0 ] = r
13361328 spacing [..., 1 ] = g
13371329 spacing [..., 2 ] = b
13381330 if spacing .shape [- 1 ] == 4 :
13391331 spacing [..., 3 ] = 1.0
1340-
13411332 return spacing
13421333
1343- def stitch_kontext_mode (self , image1 , image2 , image3 , match_image_size , spacing_width , color_val ):
1344- if image1 is None or image2 is None or image3 is None :
1345- if image3 is None :
1346- return self .stitch_two_images (image1 , image2 , "down" , match_image_size , spacing_width , color_val )
1347- elif image2 is None :
1348- return self .stitch_two_images (image1 , image3 , "right" , match_image_size , spacing_width , color_val )
1334+ def stitch_kontext_mode (self , image1 , image2 , image3 , match_image_size , spacing_width , color_val , upscale_method , image4 = None ):
1335+ has_image4 = image4 is not None
1336+ if image1 is None or image2 is None :
1337+ if image2 is None and image3 is not None :
1338+ return self .stitch_two_images (image1 , image3 , "right" , match_image_size , spacing_width , color_val , upscale_method )
13491339 else :
13501340 return image1
1351-
1352- max_batch = max (image1 .shape [0 ], image2 .shape [0 ], image3 .shape [0 ])
1353- if image1 .shape [0 ] < max_batch :
1354- image1 = torch .cat ([image1 , image1 [- 1 :].repeat (max_batch - image1 .shape [0 ], 1 , 1 , 1 )])
1355- if image2 .shape [0 ] < max_batch :
1356- image2 = torch .cat ([image2 , image2 [- 1 :].repeat (max_batch - image2 .shape [0 ], 1 , 1 , 1 )])
1357- if image3 .shape [0 ] < max_batch :
1358- image3 = torch .cat ([image3 , image3 [- 1 :].repeat (max_batch - image3 .shape [0 ], 1 , 1 , 1 )])
1359-
1341+ images_to_align = [image1 , image2 ]
1342+ if image3 is not None :
1343+ images_to_align .append (image3 )
1344+ if has_image4 :
1345+ images_to_align .append (image4 )
1346+ max_batch = max (img .shape [0 ] for img in images_to_align )
1347+ for i , img in enumerate (images_to_align ):
1348+ if img .shape [0 ] < max_batch :
1349+ images_to_align [i ] = torch .cat ([img , img [- 1 :].repeat (max_batch - img .shape [0 ], 1 , 1 , 1 )])
1350+ image1 , image2 = images_to_align [0 ], images_to_align [1 ]
1351+ image3 = images_to_align [2 ] if len (images_to_align ) > 2 else None
1352+ image4 = images_to_align [3 ] if len (images_to_align ) > 3 else None
1353+ if has_image4 :
1354+ left_images = [image1 , image2 , image3 ]
1355+ right_image = image4
1356+ else :
1357+ left_images = [image1 , image2 ]
1358+ right_image = image3
13601359 if match_image_size :
13611360 w1 = image1 .shape [2 ]
1362- h2 , w2 = image2 . shape [1 :3 ]
1363- aspect_ratio = h2 / w2
1364- target_w = w1
1365- target_h = int ( w1 * aspect_ratio )
1366-
1367- image2 = common_upscale (
1368- image2 .movedim (- 1 , 1 ), target_w , target_h , "lanczos" , "disabled"
1369- ).movedim (1 , - 1 )
1361+ for i , img in enumerate ( left_images [1 :], 1 ):
1362+ h , w = img . shape [ 1 : 3 ]
1363+ aspect_ratio = h / w
1364+ target_w = w1
1365+ target_h = int ( w1 * aspect_ratio )
1366+ left_images [ i ] = common_upscale (
1367+ img .movedim (- 1 , 1 ), target_w , target_h , upscale_method , "disabled"
1368+ ).movedim (1 , - 1 )
13701369 else :
1371- image1 , image2 = self .match_dimensions (image1 , image2 , "down" , color_val )
1372-
1373- image1 , image2 = self .ensure_same_channels (image1 , image2 )
1374-
1375- v_spacing = self .create_spacing (image1 , image2 , spacing_width , "down" , color_val )
1376-
1377- v_images = [image1 , image2 ]
1378- if v_spacing is not None :
1379- v_images .insert (1 , v_spacing )
1380-
1381- left_column = torch .cat (v_images , dim = 1 )
1382-
1370+ for i in range (1 , len (left_images )):
1371+ left_images [0 ], left_images [i ] = self .match_dimensions (left_images [0 ], left_images [i ], "down" , color_val )
1372+ for i in range (1 , len (left_images )):
1373+ left_images [0 ], left_images [i ] = self .ensure_same_channels (left_images [0 ], left_images [i ])
1374+ left_column_parts = [left_images [0 ]]
1375+ for i in range (1 , len (left_images )):
1376+ spacing = self .create_spacing (left_images [i - 1 ], left_images [i ], spacing_width , "down" , color_val )
1377+ if spacing is not None :
1378+ left_column_parts .append (spacing )
1379+ left_column_parts .append (left_images [i ])
1380+ left_column = torch .cat (left_column_parts , dim = 1 )
13831381 if match_image_size :
13841382 h_left = left_column .shape [1 ]
1385- h3 , w3 = image3 .shape [1 :3 ]
1386- aspect_ratio = w3 / h3
1383+ hr , wr = right_image .shape [1 :3 ]
1384+ aspect_ratio = wr / hr
13871385 target_h = h_left
13881386 target_w = int (h_left * aspect_ratio )
1389-
1390- image3 = common_upscale (
1391- image3 .movedim (- 1 , 1 ), target_w , target_h , "lanczos" , "disabled"
1387+ right_image = common_upscale (
1388+ right_image .movedim (- 1 , 1 ), target_w , target_h , upscale_method , "disabled"
13921389 ).movedim (1 , - 1 )
13931390 else :
1394- left_column , image3 = self .match_dimensions (left_column , image3 , "right" , color_val )
1395-
1396- left_column , image3 = self .ensure_same_channels (left_column , image3 )
1397-
1398- h_spacing = self .create_spacing (left_column , image3 , spacing_width , "right" , color_val )
1399-
1400- h_images = [left_column , image3 ]
1391+ left_column , right_image = self .match_dimensions (left_column , right_image , "right" , color_val )
1392+ left_column , right_image = self .ensure_same_channels (left_column , right_image )
1393+ h_spacing = self .create_spacing (left_column , right_image , spacing_width , "right" , color_val )
1394+ h_images = [left_column ]
14011395 if h_spacing is not None :
1402- h_images .insert ( 1 , h_spacing )
1403-
1396+ h_images .append ( h_spacing )
1397+ h_images . append ( right_image )
14041398 result = torch .cat (h_images , dim = 2 )
1405-
14061399 return result
14071400
1408- def stitch_two_images (self , image1 , image2 , direction , match_image_size , spacing_width , color_val ):
1401+ def stitch_two_images (self , image1 , image2 , stitch_mode , match_image_size , spacing_width , color_val , upscale_method ):
14091402 if image2 is None :
14101403 return image1
1411-
14121404 if image1 .shape [0 ] != image2 .shape [0 ]:
14131405 max_batch = max (image1 .shape [0 ], image2 .shape [0 ])
14141406 if image1 .shape [0 ] < max_batch :
@@ -1419,72 +1411,99 @@ def stitch_two_images(self, image1, image2, direction, match_image_size, spacing
14191411 image2 = torch .cat (
14201412 [image2 , image2 [- 1 :].repeat (max_batch - image2 .shape [0 ], 1 , 1 , 1 )]
14211413 )
1422-
14231414 if match_image_size :
14241415 h1 , w1 = image1 .shape [1 :3 ]
14251416 h2 , w2 = image2 .shape [1 :3 ]
14261417 aspect_ratio = w2 / h2
1427-
1428- if direction in ["left" , "right" ]:
1418+ if stitch_mode in ["left" , "right" ]:
14291419 target_h , target_w = h1 , int (h1 * aspect_ratio )
14301420 else :
14311421 target_w , target_h = w1 , int (w1 / aspect_ratio )
1432-
14331422 image2 = common_upscale (
1434- image2 .movedim (- 1 , 1 ), target_w , target_h , "lanczos" , "disabled"
1423+ image2 .movedim (- 1 , 1 ), target_w , target_h , upscale_method , "disabled"
14351424 ).movedim (1 , - 1 )
14361425 else :
1437- image1 , image2 = self .match_dimensions (image1 , image2 , direction , color_val )
1438-
1426+ image1 , image2 = self .match_dimensions (image1 , image2 , stitch_mode , color_val )
14391427 image1 , image2 = self .ensure_same_channels (image1 , image2 )
1440-
1441- spacing = self .create_spacing (image1 , image2 , spacing_width , direction , color_val )
1442-
1443- images = [image2 , image1 ] if direction in ["left" , "up" ] else [image1 , image2 ]
1428+ spacing = self .create_spacing (image1 , image2 , spacing_width , stitch_mode , color_val )
1429+ images = [image2 , image1 ] if stitch_mode in ["left" , "up" ] else [image1 , image2 ]
14441430 if spacing is not None :
14451431 images .insert (1 , spacing )
1446-
1447- concat_dim = 2 if direction in ["left" , "right" ] else 1
1432+ concat_dim = 2 if stitch_mode in ["left" , "right" ] else 1
14481433 result = torch .cat (images , dim = concat_dim )
1449-
14501434 return result
14511435
1452- def stitch (self , image1 , direction , match_image_size , max_width , max_height , spacing_width , background_color , image2 = None , image3 = None ,):
1436+ def create_blank_like (self , reference_image , color_val ):
1437+ batch , height , width , channels = reference_image .shape
1438+ result = torch .zeros ((batch , height , width , channels ), device = reference_image .device )
1439+ r , g , b = color_val
1440+ if channels >= 3 :
1441+ result [..., 0 ] = r
1442+ result [..., 1 ] = g
1443+ result [..., 2 ] = b
1444+ if channels == 4 :
1445+ result [..., 3 ] = 1.0
1446+ return result
1447+
1448+ def stitch_multi_mode (self , image1 , image2 , image3 , image4 , stitch_mode , match_image_size , spacing_width , color_val , upscale_method ):
1449+ images = [image for image in [image1 , image2 , image3 , image4 ] if image is not None ]
1450+ if len (images ) == 0 :
1451+ return torch .zeros ((1 , 64 , 64 , 3 ))
1452+ if len (images ) == 1 :
1453+ return images [0 ]
1454+ current = images [0 ]
1455+ for next_img in images [1 :]:
1456+ current = self .stitch_two_images (current , next_img , stitch_mode , match_image_size , spacing_width , color_val , upscale_method )
1457+ return current
1458+
1459+ def stitch_grid_2x2 (self , image1 , image2 , image3 , image4 , match_image_size , spacing_width , color_val , upscale_method ):
1460+ ref = image1
1461+ img2 = image2 if image2 is not None else self .create_blank_like (ref , color_val )
1462+ row1 = self .stitch_two_images (ref , img2 , "right" , match_image_size , spacing_width , color_val , upscale_method )
1463+ img3 = image3 if image3 is not None else self .create_blank_like (ref , color_val )
1464+ img4 = image4 if image4 is not None else self .create_blank_like (ref , color_val )
1465+ row2 = self .stitch_two_images (img3 , img4 , "right" , match_image_size , spacing_width , color_val , upscale_method )
1466+ result = self .stitch_two_images (row1 , row2 , "down" , match_image_size , spacing_width , color_val , upscale_method )
1467+ return result
14531468
1469+ def stitch (self , image1 , stitch_mode , match_image_size , megapixels , max_width , max_height , upscale_method , spacing_width , background_color , image2 = None , image3 = None , image4 = None ,):
14541470 if image1 is None :
14551471 return (torch .zeros ((1 , 64 , 64 , 3 )),)
1456-
14571472 color_val = self .hex_to_rgb (background_color )
1458-
1459- if direction == "kontext_mode" :
1460- result = self .stitch_kontext_mode (image1 , image2 , image3 , match_image_size , spacing_width , color_val )
1473+ if stitch_mode == "kontext_mode" :
1474+ result = self .stitch_kontext_mode (image1 , image2 , image3 , match_image_size , spacing_width , color_val , upscale_method , image4 = image4 )
1475+ elif stitch_mode == "2x2" :
1476+ result = self .stitch_grid_2x2 (image1 , image2 , image3 , image4 , match_image_size , spacing_width , color_val , upscale_method )
14611477 else :
1462- result = self .stitch_two_images (image1 , image2 , direction , match_image_size , spacing_width , color_val )
1463-
1464- if max_width > 0 or max_height > 0 :
1465- h , w = result .shape [1 :3 ]
1466- need_resize = False
1467-
1478+ result = self .stitch_multi_mode (image1 , image2 , image3 , image4 , stitch_mode , match_image_size , spacing_width , color_val , upscale_method )
1479+ h , w = result .shape [1 :3 ]
1480+ need_resize = False
1481+ target_w , target_h = w , h
1482+ if megapixels > 0 :
1483+ aspect_ratio = w / h
1484+ target_pixels = int (megapixels * 1024 * 1024 )
1485+ target_h = int ((target_pixels / aspect_ratio ) ** 0.5 )
1486+ target_w = int (aspect_ratio * target_h )
1487+ need_resize = True
1488+ elif max_width > 0 or max_height > 0 :
14681489 if max_width > 0 and w > max_width :
14691490 scale_factor = max_width / w
14701491 target_w = max_width
14711492 target_h = int (h * scale_factor )
14721493 need_resize = True
14731494 else :
14741495 target_w , target_h = w , h
1475-
14761496 if max_height > 0 and (target_h > max_height or (target_h == h and h > max_height )):
14771497 scale_factor = max_height / target_h
14781498 target_h = max_height
14791499 target_w = int (target_w * scale_factor )
14801500 need_resize = True
1481-
1482- if need_resize :
1483- result = common_upscale (
1484- result .movedim (- 1 , 1 ), target_w , target_h , "lanczos" , "disabled"
1485- ).movedim (1 , - 1 )
1486-
1487- return (result ,)
1501+ if need_resize :
1502+ result = common_upscale (
1503+ result .movedim (- 1 , 1 ), target_w , target_h , upscale_method , "disabled"
1504+ ).movedim (1 , - 1 )
1505+ final_height , final_width = result .shape [1 :3 ]
1506+ return (result , final_width , final_height )
14881507
14891508# Image Crop node
14901509class AILab_ImageCrop :
@@ -2156,4 +2175,4 @@ def resize(self, image, width, height, scale_by, upscale_method, resize_mode, pa
21562175 "AILab_ImageCompare" : "Image Compare (RMBG) 🖼️🖼️" ,
21572176 "AILab_ColorInput" : "Color Input (RMBG) 🎨" ,
21582177 "AILab_ImageMaskResize" : "Image Mask Resize (RMBG) 🖼️🎭"
2159- }
2178+ }
0 commit comments