Skip to content

Commit bfe31d0

Browse files
authored
IC-LoRA: support small grid (#12074)
1 parent 2129e7d commit bfe31d0

File tree

1 file changed

+17
-4
lines changed

1 file changed

+17
-4
lines changed

comfy_extras/nodes_lt.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -223,24 +223,37 @@ def get_latent_index(cls, cond, latent_length, guide_length, frame_idx, scale_fa
223223
return frame_idx, latent_idx
224224

225225
@classmethod
226-
def add_keyframe_index(cls, cond, frame_idx, guiding_latent, scale_factors):
226+
def add_keyframe_index(cls, cond, frame_idx, guiding_latent, scale_factors, latent_downscale_factor=1):
    """Register the pixel-space coordinates of a guide latent as keyframe indices.

    The guide latent is patchified, its latent coordinates are converted to
    pixel coordinates, shifted to ``frame_idx`` along the first coordinate
    axis, and appended to any ``keyframe_idxs`` already stored in ``cond``.

    Layout of ``keyframe_idxs`` as used here:
    (batch, spatial_dim [t, h, w], token_id, [start, end]).

    Args:
        cond: Conditioning to read the existing keyframe indices from and to
            write the updated ones to.
        frame_idx: Frame position at which the guide is placed; added to the
            first (temporal) coordinate axis.
        guiding_latent: Latent handed to ``cls.PATCHIFIER.patchify`` to obtain
            the per-token latent coordinates.
        scale_factors: Latent-to-pixel scale factors; entries ``[1:]`` are
            applied to the h and w axes.
        latent_downscale_factor: Downscale factor of a small-grid IC-LoRA
            guide. The default of 1 leaves the coordinates unchanged.

    Returns:
        The result of ``node_helpers.conditioning_set_values`` with the
        ``"keyframe_idxs"`` entry set to the merged indices.
    """
    existing_idxs, _ = get_keyframe_idxs(cond)
    _, latent_coords = cls.PATCHIFIER.patchify(guiding_latent)

    # The causal fix is only needed when the new latents are placed at index 0.
    placed_at_start = frame_idx == 0
    pixel_coords = latent_to_pixel_coords(latent_coords, scale_factors, causal_fix=placed_at_start)
    pixel_coords[:, 0] += frame_idx

    # Small-grid IC-LoRA: after dilation the small grid occupies the same size
    # and position as the large grid, but every token covers a larger image
    # patch. Only the END coordinate of the h and w axes is pushed outwards
    # (start and t stay untouched) so that RoPE sees the correct midpoint of
    # each token.
    hw_scale = torch.tensor(scale_factors[1:], device=pixel_coords.device)
    end_shift = ((latent_downscale_factor - 1) * hw_scale).view(1, -1, 1, 1)
    pixel_coords[:, 1:, :, 1:] += end_shift.to(pixel_coords.dtype)

    # Append along the token axis (dim 2) when the cond already has keyframes.
    if existing_idxs is None:
        merged_idxs = pixel_coords
    else:
        merged_idxs = torch.cat([existing_idxs, pixel_coords], dim=2)

    return node_helpers.conditioning_set_values(cond, {"keyframe_idxs": merged_idxs})
236249

237250
@classmethod
238-
def append_keyframe(cls, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors, guide_mask=None, in_channels=128):
251+
def append_keyframe(cls, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors, guide_mask=None, in_channels=128, latent_downscale_factor=1):
239252
if latent_image.shape[1] != in_channels or guiding_latent.shape[1] != in_channels:
240253
raise ValueError("Adding guide to a combined AV latent is not supported.")
241254

242-
positive = cls.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors)
243-
negative = cls.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors)
255+
positive = cls.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors, latent_downscale_factor)
256+
negative = cls.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors, latent_downscale_factor)
244257

245258
if guide_mask is not None:
246259
target_h = max(noise_mask.shape[3], guide_mask.shape[3])

0 commit comments

Comments
 (0)