Skip to content

Commit 135ae2c

Browse files
committed
Allow HuMo to work with embedded image for I2V
1 parent 7eca956 commit 135ae2c

File tree

3 files changed: +10 −4 lines changed

comfy/ldm/wan/model.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1510,7 +1510,7 @@ def __init__(self,
                  operations=None,
                  ):

-        super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, wan_attn_block_class=WanAttentionBlockAudio, image_model=image_model, device=device, dtype=dtype, operations=operations)
+        super().__init__(model_type='i2v', patch_size=patch_size, text_len=text_len, in_dim=36, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, wan_attn_block_class=WanAttentionBlockAudio, image_model=image_model, device=device, dtype=dtype, operations=operations)

         self.audio_proj = AudioProjModel(seq_len=8, blocks=5, channels=1280, intermediate_dim=512, output_dim=1536, context_tokens=audio_token_num, dtype=dtype, device=device, operations=operations)

@@ -1539,6 +1539,12 @@ def forward_orig(
         e0 = self.time_projection(e).unflatten(2, (6, self.dim))

         if reference_latent is not None:
+            if reference_latent.shape[1] < 36:
+                padding_needed = 36 - reference_latent.shape[1]
+                padding = torch.zeros(reference_latent.shape[0], padding_needed, *reference_latent.shape[2:],
+                                      device=reference_latent.device, dtype=reference_latent.dtype)
+                reference_latent = torch.cat([padding, reference_latent], dim=1)  # pad at beginning like c_concat
+
             ref = self.patch_embedding(reference_latent.float()).to(x.dtype)
             ref = ref.flatten(2).transpose(1, 2)
             freqs_ref = self.rope_encode(reference_latent.shape[-3], reference_latent.shape[-2], reference_latent.shape[-1], t_start=time, device=x.device, dtype=x.dtype)
@@ -1548,7 +1554,7 @@ def forward_orig(

         # context
         context = self.text_embedding(context)
-        context_img_len = None
+        context_img_len = 0

         if audio_embed is not None:
             if reference_latent is not None:

comfy/model_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1227,7 +1227,7 @@ def extra_conds(self, **kwargs):
         if audio_embed is not None:
             out['audio_embed'] = comfy.conds.CONDRegular(audio_embed)

-        if "c_concat" not in out: # 1.7B model
+        if "c_concat" not in out or "concat_latent_image" in kwargs: # 1.7B model OR I2V mode
             reference_latents = kwargs.get("reference_latents", None)
             if reference_latents is not None:
                 out['reference_latent'] = comfy.conds.CONDRegular(self.process_latent_in(reference_latents[-1]))

comfy/supported_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1080,7 +1080,7 @@ class WAN21_HuMo(WAN21_T2V):
     }

     def get_model(self, state_dict, prefix="", device=None):
-        out = model_base.WAN21_HuMo(self, image_to_video=False, device=device)
+        out = model_base.WAN21_HuMo(self, image_to_video=True, device=device)
         return out

 class WAN22_S2V(WAN21_T2V):

0 commit comments

Comments (0)