Commit 8126289

Update AILab_ImageMaskTools.py
1 parent 4862917 commit 8126289


AILab_ImageMaskTools.py

Lines changed: 130 additions & 111 deletions
@@ -1204,32 +1204,37 @@ class AILab_ImageStitch:
     def INPUT_TYPES(s):
         tooltips = {
             "image1": "First image to stitch",
-            "direction": "Direction to stitch the second image",
+            "stitch_mode": "Mode for stitching images together",
             "match_image_size": "If True, resize image2 to match image1's aspect ratio",
-            "max_width": "Maximum width of output image (0 = no limit)",
-            "max_height": "Maximum height of output image (0 = no limit)",
+            "megapixels": "Target megapixels for final output (0 = no limit, overrides max_width/max_height)",
+            "max_width": "Maximum width of output image (0 = no limit, ignored if megapixels > 0)",
+            "max_height": "Maximum height of output image (0 = no limit, ignored if megapixels > 0)",
+            "upscale_method": "Upscaling method for all resize operations",
             "spacing_width": "Width of spacing between images",
             "background_color": "Color for spacing between images and padding background",
             "kontext_mode": "Special mode that arranges 3 images in a specific layout (image1 and image2 stacked vertically, image3 on the right)"
         }
-
         return {
             "required": {
                 "image1": ("IMAGE",),
-                "direction": (["right", "down", "left", "up", "kontext_mode"], {"default": "right", "tooltip": tooltips["direction"]}),
+                "stitch_mode": (["right", "down", "left", "up", "2x2", "kontext_mode"], {"default": "right", "tooltip": tooltips["stitch_mode"]}),
                 "match_image_size": ("BOOLEAN", {"default": True, "tooltip": tooltips["match_image_size"]}),
+                "megapixels": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 16.0, "step": 0.01, "tooltip": tooltips["megapixels"]}),
                 "max_width": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8, "tooltip": tooltips["max_width"]}),
                 "max_height": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8, "tooltip": tooltips["max_height"]}),
+                "upscale_method": (["nearest-exact", "bilinear", "area", "bicubic", "lanczos"], {"default": "lanczos", "tooltip": tooltips["upscale_method"]}),
                 "spacing_width": ("INT", {"default": 0, "min": 0, "max": 512, "step": 1, "tooltip": tooltips["spacing_width"]}),
                 "background_color": ("COLOR", {"default": "#FFFFFF", "tooltip": tooltips["background_color"]}),
             },
             "optional": {
                 "image2": ("IMAGE",),
                 "image3": ("IMAGE",),
+                "image4": ("IMAGE",),
             },
         }

-    RETURN_TYPES = ("IMAGE",)
+    RETURN_TYPES = ("IMAGE", "INT", "INT")
+    RETURN_NAMES = ("IMAGE", "WIDTH", "HEIGHT")
     FUNCTION = "stitch"
     CATEGORY = "🧪AILab/🖼️IMAGE"
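
For reference, the widget lists in this hunk pin down the node's new interface: stitch_mode gains a "2x2" option, megapixels takes precedence over the two pixel caps, and the node now emits WIDTH and HEIGHT alongside the image. A minimal sketch of those constraints as plain Python (the validate_stitch_inputs helper is hypothetical, not part of the commit):

# Hypothetical helper mirroring the widget constraints declared above.
STITCH_MODES = ["right", "down", "left", "up", "2x2", "kontext_mode"]
UPSCALE_METHODS = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]

def validate_stitch_inputs(stitch_mode, upscale_method, megapixels, max_width, max_height):
    assert stitch_mode in STITCH_MODES
    assert upscale_method in UPSCALE_METHODS
    assert 0.0 <= megapixels <= 16.0           # FLOAT widget range
    assert max_width >= 0 and max_height >= 0  # 0 disables each cap

validate_stitch_inputs("2x2", "lanczos", 1.0, 0, 0)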

@@ -1241,33 +1246,25 @@ def hex_to_rgb(self, hex_color):
         return (r, g, b)

     def pad_with_color(self, image, padding, color_val):
-        """Pad image with specified color"""
         batch, height, width, channels = image.shape
         r, g, b = color_val
-
         pad_top, pad_bottom, pad_left, pad_right = padding
-
         new_height = height + pad_top + pad_bottom
         new_width = width + pad_left + pad_right
-
         result = torch.zeros((batch, new_height, new_width, channels), device=image.device)
-
         if channels >= 3:
             result[..., 0] = r
             result[..., 1] = g
             result[..., 2] = b
         if channels == 4:
             result[..., 3] = 1.0
-
         result[:, pad_top:pad_top+height, pad_left:pad_left+width, :] = image
-
         return result

-    def match_dimensions(self, image1, image2, direction, color_val):
+    def match_dimensions(self, image1, image2, stitch_mode, color_val):
         h1, w1 = image1.shape[1:3]
         h2, w2 = image2.shape[1:3]
-
-        if direction in ["left", "right"]:
+        if stitch_mode in ["left", "right"]:
             if h1 != h2:
                 target_h = max(h1, h2)
                 if h1 < target_h:
@@ -1278,7 +1275,7 @@ def match_dimensions(self, image1, image2, direction, color_val):
                 pad_h = target_h - h2
                 pad_top, pad_bottom = pad_h // 2, pad_h - pad_h // 2
                 image2 = self.pad_with_color(image2, (pad_top, pad_bottom, 0, 0), color_val)
-        else:
+        else:
             if w1 != w2:
                 target_w = max(w1, w2)
                 if w1 < target_w:
@@ -1289,7 +1286,6 @@ def match_dimensions(self, image1, image2, direction, color_val):
                 pad_w = target_w - w2
                 pad_left, pad_right = pad_w // 2, pad_w - pad_w // 2
                 image2 = self.pad_with_color(image2, (0, 0, pad_left, pad_right), color_val)
-
         return image1, image2

     def ensure_same_channels(self, image1, image2):
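
The padding split above centers the smaller image: an odd difference puts the extra pixel on the second side (bottom or right). A standalone check of that arithmetic:

# Same split as pad_h // 2, pad_h - pad_h // 2 in match_dimensions.
def split_padding(diff):
    first = diff // 2
    return first, diff - first

assert split_padding(4) == (2, 2)
assert split_padding(5) == (2, 3)  # extra pixel lands on the second side
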
@@ -1307,13 +1303,11 @@ def ensure_same_channels(self, image1, image2):
             ], dim=-1)
         return image1, image2

-    def create_spacing(self, image1, image2, spacing_width, direction, color_val):
+    def create_spacing(self, image1, image2, spacing_width, stitch_mode, color_val):
         if spacing_width <= 0:
             return None
-
         spacing_width = spacing_width + (spacing_width % 2)
-
-        if direction in ["left", "right"]:
+        if stitch_mode in ["left", "right"]:
             spacing_shape = (
                 image1.shape[0],
                 max(image1.shape[1], image2.shape[1]),
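
Note that create_spacing still rounds an odd spacing_width up to the next even value; a one-line check of that rounding:

# spacing_width + (spacing_width % 2): odd requests round up to even.
assert [w + (w % 2) for w in (0, 1, 2, 7)] == [0, 2, 2, 8]
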
@@ -1327,88 +1321,86 @@ def create_spacing(self, image1, image2, spacing_width, direction, color_val):
                 max(image1.shape[2], image2.shape[2]),
                 image1.shape[-1],
             )
-
         spacing = torch.zeros(spacing_shape, device=image1.device)
-
         r, g, b = color_val
         if spacing.shape[-1] >= 3:
             spacing[..., 0] = r
             spacing[..., 1] = g
             spacing[..., 2] = b
         if spacing.shape[-1] == 4:
             spacing[..., 3] = 1.0
-
         return spacing

-    def stitch_kontext_mode(self, image1, image2, image3, match_image_size, spacing_width, color_val):
-        if image1 is None or image2 is None or image3 is None:
-            if image3 is None:
-                return self.stitch_two_images(image1, image2, "down", match_image_size, spacing_width, color_val)
-            elif image2 is None:
-                return self.stitch_two_images(image1, image3, "right", match_image_size, spacing_width, color_val)
+    def stitch_kontext_mode(self, image1, image2, image3, match_image_size, spacing_width, color_val, upscale_method, image4=None):
+        has_image4 = image4 is not None
+        if image1 is None or image2 is None:
+            if image2 is None and image3 is not None:
+                return self.stitch_two_images(image1, image3, "right", match_image_size, spacing_width, color_val, upscale_method)
             else:
                 return image1
-
-        max_batch = max(image1.shape[0], image2.shape[0], image3.shape[0])
-        if image1.shape[0] < max_batch:
-            image1 = torch.cat([image1, image1[-1:].repeat(max_batch - image1.shape[0], 1, 1, 1)])
-        if image2.shape[0] < max_batch:
-            image2 = torch.cat([image2, image2[-1:].repeat(max_batch - image2.shape[0], 1, 1, 1)])
-        if image3.shape[0] < max_batch:
-            image3 = torch.cat([image3, image3[-1:].repeat(max_batch - image3.shape[0], 1, 1, 1)])
-
+        images_to_align = [image1, image2]
+        if image3 is not None:
+            images_to_align.append(image3)
+        if has_image4:
+            images_to_align.append(image4)
+        max_batch = max(img.shape[0] for img in images_to_align)
+        for i, img in enumerate(images_to_align):
+            if img.shape[0] < max_batch:
+                images_to_align[i] = torch.cat([img, img[-1:].repeat(max_batch - img.shape[0], 1, 1, 1)])
+        image1, image2 = images_to_align[0], images_to_align[1]
+        image3 = images_to_align[2] if len(images_to_align) > 2 else None
+        image4 = images_to_align[3] if len(images_to_align) > 3 else None
+        if has_image4:
+            left_images = [image1, image2, image3]
+            right_image = image4
+        else:
+            left_images = [image1, image2]
+            right_image = image3
         if match_image_size:
             w1 = image1.shape[2]
-            h2, w2 = image2.shape[1:3]
-            aspect_ratio = h2 / w2
-            target_w = w1
-            target_h = int(w1 * aspect_ratio)
-
-            image2 = common_upscale(
-                image2.movedim(-1, 1), target_w, target_h, "lanczos", "disabled"
-            ).movedim(1, -1)
+            for i, img in enumerate(left_images[1:], 1):
+                h, w = img.shape[1:3]
+                aspect_ratio = h / w
+                target_w = w1
+                target_h = int(w1 * aspect_ratio)
+                left_images[i] = common_upscale(
+                    img.movedim(-1, 1), target_w, target_h, upscale_method, "disabled"
+                ).movedim(1, -1)
         else:
-            image1, image2 = self.match_dimensions(image1, image2, "down", color_val)
-
-        image1, image2 = self.ensure_same_channels(image1, image2)
-
-        v_spacing = self.create_spacing(image1, image2, spacing_width, "down", color_val)
-
-        v_images = [image1, image2]
-        if v_spacing is not None:
-            v_images.insert(1, v_spacing)
-
-        left_column = torch.cat(v_images, dim=1)
-
+            for i in range(1, len(left_images)):
+                left_images[0], left_images[i] = self.match_dimensions(left_images[0], left_images[i], "down", color_val)
+        for i in range(1, len(left_images)):
+            left_images[0], left_images[i] = self.ensure_same_channels(left_images[0], left_images[i])
+        left_column_parts = [left_images[0]]
+        for i in range(1, len(left_images)):
+            spacing = self.create_spacing(left_images[i-1], left_images[i], spacing_width, "down", color_val)
+            if spacing is not None:
+                left_column_parts.append(spacing)
+            left_column_parts.append(left_images[i])
+        left_column = torch.cat(left_column_parts, dim=1)
         if match_image_size:
             h_left = left_column.shape[1]
-            h3, w3 = image3.shape[1:3]
-            aspect_ratio = w3 / h3
+            hr, wr = right_image.shape[1:3]
+            aspect_ratio = wr / hr
             target_h = h_left
             target_w = int(h_left * aspect_ratio)
-
-            image3 = common_upscale(
-                image3.movedim(-1, 1), target_w, target_h, "lanczos", "disabled"
+            right_image = common_upscale(
+                right_image.movedim(-1, 1), target_w, target_h, upscale_method, "disabled"
             ).movedim(1, -1)
         else:
-            left_column, image3 = self.match_dimensions(left_column, image3, "right", color_val)
-
-        left_column, image3 = self.ensure_same_channels(left_column, image3)
-
-        h_spacing = self.create_spacing(left_column, image3, spacing_width, "right", color_val)
-
-        h_images = [left_column, image3]
+            left_column, right_image = self.match_dimensions(left_column, right_image, "right", color_val)
+        left_column, right_image = self.ensure_same_channels(left_column, right_image)
+        h_spacing = self.create_spacing(left_column, right_image, spacing_width, "right", color_val)
+        h_images = [left_column]
         if h_spacing is not None:
-            h_images.insert(1, h_spacing)
-
+            h_images.append(h_spacing)
+        h_images.append(right_image)
         result = torch.cat(h_images, dim=2)
-
         return result

-    def stitch_two_images(self, image1, image2, direction, match_image_size, spacing_width, color_val):
+    def stitch_two_images(self, image1, image2, stitch_mode, match_image_size, spacing_width, color_val, upscale_method):
         if image2 is None:
             return image1
-
         if image1.shape[0] != image2.shape[0]:
             max_batch = max(image1.shape[0], image2.shape[0])
             if image1.shape[0] < max_batch:
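
The rewritten stitch_kontext_mode generalizes the old fixed three-image layout: every image except the last is rescaled to image1's width and stacked into a left column, then the remaining image is rescaled to the column's height and attached on the right. A torch-only sketch of that geometry, assuming BHWC float tensors and using F.interpolate in place of comfy's common_upscale (spacing and channel matching omitted):

import torch
import torch.nn.functional as F

def kontext_layout(left_images, right_image):
    # Rescale each left image to the first image's width, keeping aspect ratio.
    w = left_images[0].shape[2]
    column_parts = []
    for img in left_images:
        h_new = int(w * img.shape[1] / img.shape[2])
        x = F.interpolate(img.movedim(-1, 1), size=(h_new, w), mode="bilinear")
        column_parts.append(x.movedim(1, -1))
    column = torch.cat(column_parts, dim=1)  # stack vertically
    # Rescale the right image to the column height, then attach horizontally.
    h = column.shape[1]
    w_new = int(h * right_image.shape[2] / right_image.shape[1])
    r = F.interpolate(right_image.movedim(-1, 1), size=(h, w_new), mode="bilinear")
    return torch.cat([column, r.movedim(1, -1)], dim=2)

a = torch.rand(1, 64, 96, 3)
b = torch.rand(1, 48, 80, 3)
c = torch.rand(1, 100, 50, 3)
print(kontext_layout([a, b], c).shape)  # torch.Size([1, 121, 156, 3])
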
@@ -1419,72 +1411,99 @@ def stitch_two_images(self, image1, image2, direction, match_image_size, spacing
             image2 = torch.cat(
                 [image2, image2[-1:].repeat(max_batch - image2.shape[0], 1, 1, 1)]
             )
-
         if match_image_size:
             h1, w1 = image1.shape[1:3]
             h2, w2 = image2.shape[1:3]
             aspect_ratio = w2 / h2
-
-            if direction in ["left", "right"]:
+            if stitch_mode in ["left", "right"]:
                 target_h, target_w = h1, int(h1 * aspect_ratio)
             else:
                 target_w, target_h = w1, int(w1 / aspect_ratio)
-
             image2 = common_upscale(
-                image2.movedim(-1, 1), target_w, target_h, "lanczos", "disabled"
+                image2.movedim(-1, 1), target_w, target_h, upscale_method, "disabled"
             ).movedim(1, -1)
         else:
-            image1, image2 = self.match_dimensions(image1, image2, direction, color_val)
-
+            image1, image2 = self.match_dimensions(image1, image2, stitch_mode, color_val)
         image1, image2 = self.ensure_same_channels(image1, image2)
-
-        spacing = self.create_spacing(image1, image2, spacing_width, direction, color_val)
-
-        images = [image2, image1] if direction in ["left", "up"] else [image1, image2]
+        spacing = self.create_spacing(image1, image2, spacing_width, stitch_mode, color_val)
+        images = [image2, image1] if stitch_mode in ["left", "up"] else [image1, image2]
         if spacing is not None:
             images.insert(1, spacing)
-
-        concat_dim = 2 if direction in ["left", "right"] else 1
+        concat_dim = 2 if stitch_mode in ["left", "right"] else 1
         result = torch.cat(images, dim=concat_dim)
-
         return result

-    def stitch(self, image1, direction, match_image_size, max_width, max_height, spacing_width, background_color, image2=None, image3=None,):
+    def create_blank_like(self, reference_image, color_val):
+        batch, height, width, channels = reference_image.shape
+        result = torch.zeros((batch, height, width, channels), device=reference_image.device)
+        r, g, b = color_val
+        if channels >= 3:
+            result[..., 0] = r
+            result[..., 1] = g
+            result[..., 2] = b
+        if channels == 4:
+            result[..., 3] = 1.0
+        return result
+
+    def stitch_multi_mode(self, image1, image2, image3, image4, stitch_mode, match_image_size, spacing_width, color_val, upscale_method):
+        images = [image for image in [image1, image2, image3, image4] if image is not None]
+        if len(images) == 0:
+            return torch.zeros((1, 64, 64, 3))
+        if len(images) == 1:
+            return images[0]
+        current = images[0]
+        for next_img in images[1:]:
+            current = self.stitch_two_images(current, next_img, stitch_mode, match_image_size, spacing_width, color_val, upscale_method)
+        return current
+
+    def stitch_grid_2x2(self, image1, image2, image3, image4, match_image_size, spacing_width, color_val, upscale_method):
+        ref = image1
+        img2 = image2 if image2 is not None else self.create_blank_like(ref, color_val)
+        row1 = self.stitch_two_images(ref, img2, "right", match_image_size, spacing_width, color_val, upscale_method)
+        img3 = image3 if image3 is not None else self.create_blank_like(ref, color_val)
+        img4 = image4 if image4 is not None else self.create_blank_like(ref, color_val)
+        row2 = self.stitch_two_images(img3, img4, "right", match_image_size, spacing_width, color_val, upscale_method)
+        result = self.stitch_two_images(row1, row2, "down", match_image_size, spacing_width, color_val, upscale_method)
+        return result

+    def stitch(self, image1, stitch_mode, match_image_size, megapixels, max_width, max_height, upscale_method, spacing_width, background_color, image2=None, image3=None, image4=None,):
         if image1 is None:
             return (torch.zeros((1, 64, 64, 3)),)
-
         color_val = self.hex_to_rgb(background_color)
-
-        if direction == "kontext_mode":
-            result = self.stitch_kontext_mode(image1, image2, image3, match_image_size, spacing_width, color_val)
+        if stitch_mode == "kontext_mode":
+            result = self.stitch_kontext_mode(image1, image2, image3, match_image_size, spacing_width, color_val, upscale_method, image4=image4)
+        elif stitch_mode == "2x2":
+            result = self.stitch_grid_2x2(image1, image2, image3, image4, match_image_size, spacing_width, color_val, upscale_method)
         else:
-            result = self.stitch_two_images(image1, image2, direction, match_image_size, spacing_width, color_val)
-
-        if max_width > 0 or max_height > 0:
-            h, w = result.shape[1:3]
-            need_resize = False
-
+            result = self.stitch_multi_mode(image1, image2, image3, image4, stitch_mode, match_image_size, spacing_width, color_val, upscale_method)
+        h, w = result.shape[1:3]
+        need_resize = False
+        target_w, target_h = w, h
+        if megapixels > 0:
+            aspect_ratio = w / h
+            target_pixels = int(megapixels * 1024 * 1024)
+            target_h = int((target_pixels / aspect_ratio) ** 0.5)
+            target_w = int(aspect_ratio * target_h)
+            need_resize = True
+        elif max_width > 0 or max_height > 0:
             if max_width > 0 and w > max_width:
                 scale_factor = max_width / w
                 target_w = max_width
                 target_h = int(h * scale_factor)
                 need_resize = True
             else:
                 target_w, target_h = w, h
-
             if max_height > 0 and (target_h > max_height or (target_h == h and h > max_height)):
                 scale_factor = max_height / target_h
                 target_h = max_height
                 target_w = int(target_w * scale_factor)
                 need_resize = True
-
-        if need_resize:
-            result = common_upscale(
-                result.movedim(-1, 1), target_w, target_h, "lanczos", "disabled"
-            ).movedim(1, -1)
-
-        return (result,)
+        if need_resize:
+            result = common_upscale(
+                result.movedim(-1, 1), target_w, target_h, upscale_method, "disabled"
+            ).movedim(1, -1)
+        final_height, final_width = result.shape[1:3]
+        return (result, final_width, final_height)

 # Image Crop node
 class AILab_ImageCrop:
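
The new megapixels branch in stitch() picks target dimensions so that width x height lands near the requested pixel budget while preserving aspect ratio (1 MP here means 1024 * 1024 pixels). The same arithmetic as a standalone function, with a worked example:

def megapixel_target(w, h, megapixels):
    # Mirrors the megapixels > 0 branch of stitch() above.
    aspect_ratio = w / h
    target_pixels = int(megapixels * 1024 * 1024)
    target_h = int((target_pixels / aspect_ratio) ** 0.5)
    target_w = int(aspect_ratio * target_h)
    return target_w, target_h

print(megapixel_target(2048, 1024, 1.0))  # (1448, 724); 1448 * 724 = 1048352 pixels, ~1.0 MP
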
@@ -2156,4 +2175,4 @@ def resize(self, image, width, height, scale_by, upscale_method, resize_mode, pa
     "AILab_ImageCompare": "Image Compare (RMBG) 🖼️🖼️",
     "AILab_ColorInput": "Color Input (RMBG) 🎨",
     "AILab_ImageMaskResize": "Image Mask Resize (RMBG) 🖼️🎭"
-}
+}
