Skip to content

Commit 65ad392

Browse files
feat(nodes): add node to prep images for FLUX Kontext
1 parent 56d75e1 commit 65ad392

File tree

1 file changed

+93
-0
lines changed

1 file changed

+93
-0
lines changed

invokeai/app/invocations/image.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1347,3 +1347,96 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
13471347

13481348
image_dto = context.images.save(image=target_image)
13491349
return ImageOutput.build(image_dto)
1350+
1351+
1352+
@invocation(
1353+
"flux_kontext_image_prep",
1354+
title="FLUX Kontext Image Prep",
1355+
tags=["image", "concatenate", "flux", "kontext"],
1356+
category="image",
1357+
version="1.0.0",
1358+
)
1359+
class FluxKontextConcatenateImagesInvocation(BaseInvocation, WithMetadata, WithBoard):
1360+
"""Prepares an image or images for use with FLUX Kontext. The first/single image is resized to the nearest
1361+
preferred Kontext resolution. All other images are concatenated horizontally, maintaining their aspect ratio."""
1362+
1363+
images: list[ImageField] = InputField(
1364+
description="The images to concatenate",
1365+
min_length=1,
1366+
max_length=10,
1367+
)
1368+
1369+
use_preferred_resolution: bool = InputField(
1370+
default=True, description="Use FLUX preferred resolutions for the first image"
1371+
)
1372+
1373+
def invoke(self, context: InvocationContext) -> ImageOutput:
1374+
from invokeai.backend.flux.util import PREFERED_KONTEXT_RESOLUTIONS
1375+
1376+
# Step 1: Load all images
1377+
pil_images = []
1378+
for image_field in self.images:
1379+
image = context.images.get_pil(image_field.image_name, mode="RGBA")
1380+
pil_images.append(image)
1381+
1382+
# Step 2: Determine target resolution for the first image
1383+
first_image = pil_images[0]
1384+
width, height = first_image.size
1385+
1386+
if self.use_preferred_resolution:
1387+
aspect_ratio = width / height
1388+
1389+
# Find the closest preferred resolution for the first image
1390+
_, target_width, target_height = min(
1391+
((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS), key=lambda x: x[0]
1392+
)
1393+
1394+
# Apply BFL's scaling formula
1395+
scaled_height = 2 * int(target_height / 16)
1396+
final_height = 8 * scaled_height # This will be consistent for all images
1397+
scaled_width = 2 * int(target_width / 16)
1398+
first_width = 8 * scaled_width
1399+
else:
1400+
# Use original dimensions of first image, ensuring divisibility by 16
1401+
final_height = 16 * (height // 16)
1402+
first_width = 16 * (width // 16)
1403+
# Ensure minimum dimensions
1404+
if final_height < 16:
1405+
final_height = 16
1406+
if first_width < 16:
1407+
first_width = 16
1408+
1409+
# Step 3: Process and resize all images with consistent height
1410+
processed_images = []
1411+
total_width = 0
1412+
1413+
for i, image in enumerate(pil_images):
1414+
if i == 0:
1415+
# First image uses the calculated dimensions
1416+
final_width = first_width
1417+
else:
1418+
# Subsequent images maintain aspect ratio with the same height
1419+
img_aspect_ratio = image.width / image.height
1420+
# Calculate width that maintains aspect ratio at the target height
1421+
calculated_width = int(final_height * img_aspect_ratio)
1422+
# Ensure width is divisible by 16 for proper VAE encoding
1423+
final_width = 16 * (calculated_width // 16)
1424+
# Ensure minimum width
1425+
if final_width < 16:
1426+
final_width = 16
1427+
1428+
# Resize image to calculated dimensions
1429+
resized_image = image.resize((final_width, final_height), Image.Resampling.LANCZOS)
1430+
processed_images.append(resized_image)
1431+
total_width += final_width
1432+
1433+
# Step 4: Concatenate images horizontally
1434+
concatenated_image = Image.new("RGB", (total_width, final_height))
1435+
x_offset = 0
1436+
for img in processed_images:
1437+
concatenated_image.paste(img, (x_offset, 0))
1438+
x_offset += img.width
1439+
1440+
# Save the concatenated image
1441+
image_dto = context.images.save(image=concatenated_image)
1442+
return ImageOutput.build(image_dto)

0 commit comments

Comments
 (0)