Skip to content

Commit c6eff71

Browse files
fix(backend): bug in kontext canvas dimension tracking when concating in latent space
We weren't tracking the canvas dimensions properly which coudl result in FLUX not "seeing" ref images after the first very well
1 parent 6ea4c47 commit c6eff71

File tree

1 file changed

+27
-18
lines changed

1 file changed

+27
-18
lines changed

invokeai/backend/flux/extensions/kontext_extension.py

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,8 @@ def _prepare_kontext(self) -> tuple[torch.Tensor, torch.Tensor]:
106106

107107
# Track cumulative dimensions for spatial tiling
108108
# These track the running extent of the virtual canvas in latent space
109-
h = 0 # Running height extent
110-
w = 0 # Running width extent
109+
canvas_h = 0 # Running canvas height
110+
canvas_w = 0 # Running canvas width
111111

112112
vae_info = self._context.models.load(self._vae_field.vae)
113113

@@ -132,11 +132,11 @@ def _prepare_kontext(self) -> tuple[torch.Tensor, torch.Tensor]:
132132
# Continue with VAE encoding
133133
# Don't sample from the distribution for reference images - use the mean (matching ComfyUI)
134134
# Estimate working memory for encode operation (50% of decode memory requirements)
135-
h = image_tensor.shape[-2]
136-
w = image_tensor.shape[-1]
135+
img_h = image_tensor.shape[-2]
136+
img_w = image_tensor.shape[-1]
137137
element_size = next(vae_info.model.parameters()).element_size()
138138
scaling_constant = 1100 # 50% of decode scaling constant (2200)
139-
estimated_working_memory = int(h * w * element_size * scaling_constant)
139+
estimated_working_memory = int(img_h * img_w * element_size * scaling_constant)
140140

141141
with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
142142
assert isinstance(vae, AutoEncoder)
@@ -161,21 +161,35 @@ def _prepare_kontext(self) -> tuple[torch.Tensor, torch.Tensor]:
161161
kontext_latents_packed = pack(kontext_latents_unpacked).to(self._device, self._dtype)
162162

163163
# Determine spatial offsets for this reference image
164-
# - Compare the potential new canvas dimensions if we add the image vertically vs horizontally
165-
# - Choose the placement that results in a more square-like canvas
166164
h_offset = 0
167165
w_offset = 0
168166

169167
if idx > 0: # First image starts at (0, 0)
170-
# Check which placement would result in better canvas dimensions
171-
# If adding to height would make the canvas taller than wide, tile horizontally
172-
# Otherwise, tile vertically
173-
if latent_height + h > latent_width + w:
168+
# Calculate potential canvas dimensions for each tiling option
169+
# Option 1: Tile vertically (below existing content)
170+
potential_h_vertical = canvas_h + latent_height
171+
potential_w_vertical = max(canvas_w, latent_width)
172+
173+
# Option 2: Tile horizontally (to the right of existing content)
174+
potential_h_horizontal = max(canvas_h, latent_height)
175+
potential_w_horizontal = canvas_w + latent_width
176+
177+
# Choose arrangement that minimizes the maximum dimension
178+
# This keeps the canvas closer to square, optimizing attention computation
179+
if potential_h_vertical > potential_w_horizontal:
174180
# Tile horizontally (to the right of existing images)
175-
w_offset = w
181+
w_offset = canvas_w
182+
canvas_w = canvas_w + latent_width
183+
canvas_h = max(canvas_h, latent_height)
176184
else:
177185
# Tile vertically (below existing images)
178-
h_offset = h
186+
h_offset = canvas_h
187+
canvas_h = canvas_h + latent_height
188+
canvas_w = max(canvas_w, latent_width)
189+
else:
190+
# First image - just set canvas dimensions
191+
canvas_h = latent_height
192+
canvas_w = latent_width
179193

180194
# Generate IDs with both index offset and spatial offsets
181195
kontext_ids = generate_img_ids_with_offset(
@@ -189,11 +203,6 @@ def _prepare_kontext(self) -> tuple[torch.Tensor, torch.Tensor]:
189203
w_offset=w_offset,
190204
)
191205

192-
# Update cumulative dimensions
193-
# Track the maximum extent of the virtual canvas after placing this image
194-
h = max(h, latent_height + h_offset)
195-
w = max(w, latent_width + w_offset)
196-
197206
all_latents.append(kontext_latents_packed)
198207
all_ids.append(kontext_ids)
199208

0 commit comments

Comments
 (0)