@@ -106,8 +106,8 @@ def _prepare_kontext(self) -> tuple[torch.Tensor, torch.Tensor]:

         # Track cumulative dimensions for spatial tiling
         # These track the running extent of the virtual canvas in latent space
-        h = 0  # Running height extent
-        w = 0  # Running width extent
+        canvas_h = 0  # Running canvas height
+        canvas_w = 0  # Running canvas width

         vae_info = self._context.models.load(self._vae_field.vae)

@@ -132,11 +132,11 @@ def _prepare_kontext(self) -> tuple[torch.Tensor, torch.Tensor]:
            # Continue with VAE encoding
            # Don't sample from the distribution for reference images - use the mean (matching ComfyUI)
            # Estimate working memory for encode operation (50% of decode memory requirements)
-            h = image_tensor.shape[-2]
-            w = image_tensor.shape[-1]
+            img_h = image_tensor.shape[-2]
+            img_w = image_tensor.shape[-1]
            element_size = next(vae_info.model.parameters()).element_size()
            scaling_constant = 1100  # 50% of decode scaling constant (2200)
-            estimated_working_memory = int(h * w * element_size * scaling_constant)
+            estimated_working_memory = int(img_h * img_w * element_size * scaling_constant)

            with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
                assert isinstance(vae, AutoEncoder)
@@ -161,21 +161,35 @@ def _prepare_kontext(self) -> tuple[torch.Tensor, torch.Tensor]:
            kontext_latents_packed = pack(kontext_latents_unpacked).to(self._device, self._dtype)

            # Determine spatial offsets for this reference image
-            # - Compare the potential new canvas dimensions if we add the image vertically vs horizontally
-            # - Choose the placement that results in a more square-like canvas
            h_offset = 0
            w_offset = 0

            if idx > 0:  # First image starts at (0, 0)
-                # Check which placement would result in better canvas dimensions
-                # If adding to height would make the canvas taller than wide, tile horizontally
-                # Otherwise, tile vertically
-                if latent_height + h > latent_width + w:
+                # Calculate potential canvas dimensions for each tiling option
+                # Option 1: Tile vertically (below existing content)
+                potential_h_vertical = canvas_h + latent_height
+                potential_w_vertical = max(canvas_w, latent_width)
+
+                # Option 2: Tile horizontally (to the right of existing content)
+                potential_h_horizontal = max(canvas_h, latent_height)
+                potential_w_horizontal = canvas_w + latent_width
+
+                # Choose arrangement that minimizes the maximum dimension
+                # This keeps the canvas closer to square, optimizing attention computation
+                if potential_h_vertical > potential_w_horizontal:
                    # Tile horizontally (to the right of existing images)
-                    w_offset = w
+                    w_offset = canvas_w
+                    canvas_w = canvas_w + latent_width
+                    canvas_h = max(canvas_h, latent_height)
                else:
                    # Tile vertically (below existing images)
-                    h_offset = h
+                    h_offset = canvas_h
+                    canvas_h = canvas_h + latent_height
+                    canvas_w = max(canvas_w, latent_width)
+            else:
+                # First image - just set canvas dimensions
+                canvas_h = latent_height
+                canvas_w = latent_width

            # Generate IDs with both index offset and spatial offsets
            kontext_ids = generate_img_ids_with_offset(
@@ -189,11 +203,6 @@ def _prepare_kontext(self) -> tuple[torch.Tensor, torch.Tensor]:
                w_offset=w_offset,
            )

-            # Update cumulative dimensions
-            # Track the maximum extent of the virtual canvas after placing this image
-            h = max(h, latent_height + h_offset)
-            w = max(w, latent_width + w_offset)
-
            all_latents.append(kontext_latents_packed)
            all_ids.append(kontext_ids)

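For context on the working-memory hunk above: the estimate is a plain byte-count heuristic, pixel area times the VAE parameter element size times an empirical scaling constant (1100 for encode, half of the 2200 used for decode). A minimal sketch of the arithmetic, assuming a hypothetical 1024x1024 reference image and fp16 weights (element_size = 2); these inputs are illustrative, not taken from the commit:

import torch

image_tensor = torch.zeros(1, 3, 1024, 1024)  # hypothetical reference image
element_size = 2  # assumes fp16 VAE weights; the commit reads this off the loaded model

img_h = image_tensor.shape[-2]
img_w = image_tensor.shape[-1]
scaling_constant = 1100  # 50% of the decode scaling constant (2200)
estimated_working_memory = int(img_h * img_w * element_size * scaling_constant)

print(estimated_working_memory)  # 1024 * 1024 * 2 * 1100 = 2306867200 bytes (~2.1 GiB)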
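The core of the change is the placement rule: instead of patching the running h/w extents after ID generation, the canvas extents are updated inline, and the vertical-vs-horizontal decision compares the full candidate canvas sizes. A self-contained sketch of that greedy rule, where place_on_canvas is a hypothetical helper (the commit inlines this logic in its loop over reference images):

def place_on_canvas(sizes: list[tuple[int, int]]) -> list[tuple[int, int]]:
    """Return an (h_offset, w_offset) per (latent_height, latent_width)."""
    offsets = []
    canvas_h = 0  # running canvas height
    canvas_w = 0  # running canvas width
    for idx, (latent_height, latent_width) in enumerate(sizes):
        h_offset = 0
        w_offset = 0
        if idx > 0:
            # Candidate extents if tiled below vs. to the right of existing content
            potential_h_vertical = canvas_h + latent_height
            potential_w_horizontal = canvas_w + latent_width
            if potential_h_vertical > potential_w_horizontal:
                # Tiling below would grow taller than tiling right grows wide,
                # so place to the right to keep the canvas closer to square
                w_offset = canvas_w
                canvas_w = canvas_w + latent_width
                canvas_h = max(canvas_h, latent_height)
            else:
                h_offset = canvas_h
                canvas_h = canvas_h + latent_height
                canvas_w = max(canvas_w, latent_width)
        else:
            # First image starts at (0, 0) and defines the initial canvas
            canvas_h = latent_height
            canvas_w = latent_width
        offsets.append((h_offset, w_offset))
    return offsets

# Three equal-sized references: the second tiles below, the third to the right
print(place_on_canvas([(64, 64), (64, 64), (64, 64)]))
# [(0, 0), (64, 0), (0, 64)]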