Skip to content

Commit e4a0d34

Browse files
committed
Fix proper handling of the difference between the 1.7B and 14B HuMo models
1 parent 135ae2c commit e4a0d34

File tree

2 files changed

+17
-22
lines changed

2 files changed

+17
-22
lines changed

comfy/ldm/wan/model.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1510,7 +1510,7 @@ def __init__(self,
15101510
operations=None,
15111511
):
15121512

1513-
super().__init__(model_type='i2v', patch_size=patch_size, text_len=text_len, in_dim=36, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, wan_attn_block_class=WanAttentionBlockAudio, image_model=image_model, device=device, dtype=dtype, operations=operations)
1513+
super().__init__(model_type='i2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, wan_attn_block_class=WanAttentionBlockAudio, image_model=image_model, device=device, dtype=dtype, operations=operations)
15141514

15151515
self.audio_proj = AudioProjModel(seq_len=8, blocks=5, channels=1280, intermediate_dim=512, output_dim=1536, context_tokens=audio_token_num, dtype=dtype, device=device, operations=operations)
15161516

@@ -1539,12 +1539,6 @@ def forward_orig(
15391539
e0 = self.time_projection(e).unflatten(2, (6, self.dim))
15401540

15411541
if reference_latent is not None:
1542-
if reference_latent.shape[1] < 36:
1543-
padding_needed = 36 - reference_latent.shape[1]
1544-
padding = torch.zeros(reference_latent.shape[0], padding_needed, *reference_latent.shape[2:],
1545-
device=reference_latent.device, dtype=reference_latent.dtype)
1546-
reference_latent = torch.cat([padding, reference_latent], dim=1) # pad at beginning like c_concat
1547-
15481542
ref = self.patch_embedding(reference_latent.float()).to(x.dtype)
15491543
ref = ref.flatten(2).transpose(1, 2)
15501544
freqs_ref = self.rope_encode(reference_latent.shape[-3], reference_latent.shape[-2], reference_latent.shape[-1], t_start=time, device=x.device, dtype=x.dtype)

comfy/model_base.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1227,22 +1227,23 @@ def extra_conds(self, **kwargs):
12271227
if audio_embed is not None:
12281228
out['audio_embed'] = comfy.conds.CONDRegular(audio_embed)
12291229

1230-
if "c_concat" not in out or "concat_latent_image" in kwargs: # 1.7B model OR I2V mode
1231-
reference_latents = kwargs.get("reference_latents", None)
1232-
if reference_latents is not None:
1233-
out['reference_latent'] = comfy.conds.CONDRegular(self.process_latent_in(reference_latents[-1]))
1230+
reference_latents = kwargs.get("reference_latents", None)
1231+
1232+
if "c_concat" not in out and reference_latents is not None and reference_latents[0].shape[1] == 16: # 1.7B model
1233+
out['reference_latent'] = comfy.conds.CONDRegular(self.process_latent_in(reference_latents[-1]))
12341234
else:
1235-
noise_shape = list(noise.shape)
1236-
noise_shape[1] += 4
1237-
concat_latent = torch.zeros(noise_shape, device=noise.device, dtype=noise.dtype)
1238-
zero_vae_values_first = torch.tensor([0.8660, -0.4326, -0.0017, -0.4884, -0.5283, 0.9207, -0.9896, 0.4433, -0.5543, -0.0113, 0.5753, -0.6000, -0.8346, -0.3497, -0.1926, -0.6938]).view(1, 16, 1, 1, 1)
1239-
zero_vae_values_second = torch.tensor([1.0869, -1.2370, 0.0206, -0.4357, -0.6411, 2.0307, -1.5972, 1.2659, -0.8595, -0.4654, 0.9638, -1.6330, -1.4310, -0.1098, -0.3856, -1.4583]).view(1, 16, 1, 1, 1)
1240-
zero_vae_values = torch.tensor([0.8642, -1.8583, 0.1577, 0.1350, -0.3641, 2.5863, -1.9670, 1.6065, -1.0475, -0.8678, 1.1734, -1.8138, -1.5933, -0.7721, -0.3289, -1.3745]).view(1, 16, 1, 1, 1)
1241-
concat_latent[:, 4:] = zero_vae_values
1242-
concat_latent[:, 4:, :1] = zero_vae_values_first
1243-
concat_latent[:, 4:, 1:2] = zero_vae_values_second
1244-
out['c_concat'] = comfy.conds.CONDNoiseShape(concat_latent)
1245-
reference_latents = kwargs.get("reference_latents", None)
1235+
concat_latent_image = kwargs.get("concat_latent_image", None)
1236+
if concat_latent_image is None:
1237+
noise_shape = list(noise.shape)
1238+
noise_shape[1] += 4
1239+
concat_latent = torch.zeros(noise_shape, device=noise.device, dtype=noise.dtype)
1240+
zero_vae_values_first = torch.tensor([0.8660, -0.4326, -0.0017, -0.4884, -0.5283, 0.9207, -0.9896, 0.4433, -0.5543, -0.0113, 0.5753, -0.6000, -0.8346, -0.3497, -0.1926, -0.6938]).view(1, 16, 1, 1, 1)
1241+
zero_vae_values_second = torch.tensor([1.0869, -1.2370, 0.0206, -0.4357, -0.6411, 2.0307, -1.5972, 1.2659, -0.8595, -0.4654, 0.9638, -1.6330, -1.4310, -0.1098, -0.3856, -1.4583]).view(1, 16, 1, 1, 1)
1242+
zero_vae_values = torch.tensor([0.8642, -1.8583, 0.1577, 0.1350, -0.3641, 2.5863, -1.9670, 1.6065, -1.0475, -0.8678, 1.1734, -1.8138, -1.5933, -0.7721, -0.3289, -1.3745]).view(1, 16, 1, 1, 1)
1243+
concat_latent[:, 4:] = zero_vae_values
1244+
concat_latent[:, 4:, :1] = zero_vae_values_first
1245+
concat_latent[:, 4:, 1:2] = zero_vae_values_second
1246+
out['c_concat'] = comfy.conds.CONDNoiseShape(concat_latent)
12461247
if reference_latents is not None:
12471248
ref_latent = self.process_latent_in(reference_latents[-1])
12481249
ref_latent_shape = list(ref_latent.shape)

0 commit comments

Comments (0)