
Commit 2846939

update

1 parent 85fc267 · commit 2846939

2 files changed: +177 -25 lines changed

src/diffusers/models/transformers/transformer_hunyuan_video.py

Lines changed: 164 additions & 16 deletions
@@ -27,7 +27,6 @@
 from ..attention_processor import Attention, AttentionProcessor
 from ..cache_utils import CacheMixin
 from ..embeddings import (
-    CombinedTimestepGuidanceTextProjEmbeddings,
     CombinedTimestepTextProjEmbeddings,
     PixArtAlphaTextProjection,
     TimestepEmbedding,
@@ -179,6 +178,7 @@ def forward(
 class HunyuanVideoTokenReplaceAdaLayerNormZero(nn.Module):
     def __init__(self, embedding_dim: int, norm_type: str = "layer_norm", bias: bool = True):
         super().__init__()
+
         self.silu = nn.SiLU()
         self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias)
 
@@ -228,8 +228,53 @@ def forward(
         )
 
 
+class HunyuanVideoTokenReplaceAdaLayerNormZeroSingle(nn.Module):
+    def __init__(self, embedding_dim: int, norm_type: str = "layer_norm", bias: bool = True):
+        super().__init__()
+
+        self.silu = nn.SiLU()
+        self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias)
+
+        if norm_type == "layer_norm":
+            self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
+        else:
+            raise ValueError(
+                f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
+            )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        emb: torch.Tensor,
+        token_replace_emb: torch.Tensor,
+        first_frame_num_tokens: int,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        emb = self.linear(self.silu(emb))
+        token_replace_emb = self.linear(self.silu(token_replace_emb))
+
+        shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1)
+        tr_shift_msa, tr_scale_msa, tr_gate_msa = token_replace_emb.chunk(3, dim=1)
+
+        norm_hidden_states = self.norm(hidden_states)
+        hidden_states_zero = (
+            norm_hidden_states[:, :first_frame_num_tokens] * (1 + tr_scale_msa[:, None]) + tr_shift_msa[:, None]
+        )
+        hidden_states_orig = (
+            norm_hidden_states[:, first_frame_num_tokens:] * (1 + scale_msa[:, None]) + shift_msa[:, None]
+        )
+        hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
+
+        return hidden_states, gate_msa, tr_gate_msa
+
+
 class HunyuanVideoConditionEmbedding(nn.Module):
-    def __init__(self, embedding_dim: int, pooled_projection_dim: int, guidance_embeds: bool, image_condition_type: Optional[str] = None):
+    def __init__(
+        self,
+        embedding_dim: int,
+        pooled_projection_dim: int,
+        guidance_embeds: bool,
+        image_condition_type: Optional[str] = None,
+    ):
         super().__init__()
 
         self.image_condition_type = image_condition_type
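For reference, a minimal standalone sketch (not part of the commit) of the modulation that the new `HunyuanVideoTokenReplaceAdaLayerNormZeroSingle` performs: the first-frame tokens are shifted and scaled with the projection of `token_replace_emb`, the remaining tokens with the projection of the regular conditioning embedding. The tensor sizes below are illustrative assumptions, not the model's real dimensions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

dim, num_tokens, first_frame_num_tokens = 8, 6, 2
norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
linear = nn.Linear(dim, 3 * dim)

hidden_states = torch.randn(1, num_tokens, dim)
emb = torch.randn(1, dim)                # regular conditioning embedding
token_replace_emb = torch.randn(1, dim)  # conditioning for the clean first-frame tokens

shift, scale, gate = linear(F.silu(emb)).chunk(3, dim=1)
tr_shift, tr_scale, tr_gate = linear(F.silu(token_replace_emb)).chunk(3, dim=1)

x = norm(hidden_states)
# First-frame tokens get the token-replace modulation, the rest the regular one.
x_zero = x[:, :first_frame_num_tokens] * (1 + tr_scale[:, None]) + tr_shift[:, None]
x_orig = x[:, first_frame_num_tokens:] * (1 + scale[:, None]) + shift[:, None]
out = torch.cat([x_zero, x_orig], dim=1)
print(out.shape)  # torch.Size([1, 6, 8])
```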
@@ -242,7 +287,9 @@ def __init__(self, embedding_dim: int, pooled_projection_dim: int, guidance_embe
         if guidance_embeds:
             self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
 
-    def forward(self, timestep: torch.Tensor, pooled_projection: torch.Tensor, guidance: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
+    def forward(
+        self, timestep: torch.Tensor, pooled_projection: torch.Tensor, guidance: Optional[torch.Tensor] = None
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         timesteps_proj = self.time_proj(timestep)
         timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))  # (N, D)
         pooled_projections = self.text_embedder(pooled_projection)
@@ -480,6 +527,8 @@ def forward(
         temb: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        *args,
+        **kwargs,
     ) -> torch.Tensor:
         text_seq_length = encoder_hidden_states.shape[1]
         hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
@@ -558,6 +607,7 @@ def forward(
         temb: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        *args,
         **kwargs,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # 1. Input normalization
@@ -594,6 +644,86 @@ def forward(
         return hidden_states, encoder_hidden_states
 
 
+class HunyuanVideoTokenReplaceSingleTransformerBlock(nn.Module):
+    def __init__(
+        self,
+        num_attention_heads: int,
+        attention_head_dim: int,
+        mlp_ratio: float = 4.0,
+        qk_norm: str = "rms_norm",
+    ) -> None:
+        super().__init__()
+
+        hidden_size = num_attention_heads * attention_head_dim
+        mlp_dim = int(hidden_size * mlp_ratio)
+
+        self.attn = Attention(
+            query_dim=hidden_size,
+            cross_attention_dim=None,
+            dim_head=attention_head_dim,
+            heads=num_attention_heads,
+            out_dim=hidden_size,
+            bias=True,
+            processor=HunyuanVideoAttnProcessor2_0(),
+            qk_norm=qk_norm,
+            eps=1e-6,
+            pre_only=True,
+        )
+
+        self.norm = HunyuanVideoTokenReplaceAdaLayerNormZeroSingle(hidden_size, norm_type="layer_norm")
+        self.proj_mlp = nn.Linear(hidden_size, mlp_dim)
+        self.act_mlp = nn.GELU(approximate="tanh")
+        self.proj_out = nn.Linear(hidden_size + mlp_dim, hidden_size)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        temb: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        token_replace_emb: torch.Tensor = None,
+        num_tokens: int = None,
+    ) -> torch.Tensor:
+        text_seq_length = encoder_hidden_states.shape[1]
+        hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
+
+        residual = hidden_states
+
+        # 1. Input normalization
+        norm_hidden_states, gate, tr_gate = self.norm(hidden_states, temb, token_replace_emb, num_tokens)
+        mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
+
+        norm_hidden_states, norm_encoder_hidden_states = (
+            norm_hidden_states[:, :-text_seq_length, :],
+            norm_hidden_states[:, -text_seq_length:, :],
+        )
+
+        # 2. Attention
+        attn_output, context_attn_output = self.attn(
+            hidden_states=norm_hidden_states,
+            encoder_hidden_states=norm_encoder_hidden_states,
+            attention_mask=attention_mask,
+            image_rotary_emb=image_rotary_emb,
+        )
+        attn_output = torch.cat([attn_output, context_attn_output], dim=1)
+
+        # 3. Modulation and residual connection
+        hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
+
+        proj_output = self.proj_out(hidden_states)
+        hidden_states_zero = proj_output[:, :num_tokens] * tr_gate.unsqueeze(1)
+        hidden_states_orig = proj_output[:, num_tokens:] * gate.unsqueeze(1)
+        hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
+        hidden_states = hidden_states + residual
+
+        hidden_states, encoder_hidden_states = (
+            hidden_states[:, :-text_seq_length, :],
+            hidden_states[:, -text_seq_length:, :],
+        )
+        return hidden_states, encoder_hidden_states
+
+
 class HunyuanVideoTokenReplaceTransformerBlock(nn.Module):
     def __init__(
         self,
@@ -664,8 +794,8 @@ def forward(
         )
 
         # 3. Modulation and residual connection
-        hidden_states_zero = hidden_states[:, :num_tokens] + attn_output[:, :num_tokens] * tr_gate_msa
-        hidden_states_orig = hidden_states[:, num_tokens:] + attn_output[:, num_tokens:] * gate_msa
+        hidden_states_zero = hidden_states[:, :num_tokens] + attn_output[:, :num_tokens] * tr_gate_msa.unsqueeze(1)
+        hidden_states_orig = hidden_states[:, num_tokens:] + attn_output[:, num_tokens:] * gate_msa.unsqueeze(1)
         hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
         encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa.unsqueeze(1)
 
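The `.unsqueeze(1)` change in this hunk (and the matching feed-forward hunk below) fixes broadcasting: the gates produced by the AdaLN chunking are per-sample vectors of shape `(batch, dim)`, while the tensors they gate are `(batch, seq, dim)`. A small sketch with illustrative shapes, not the block's real dimensions:

```python
import torch

batch, seq, dim = 2, 5, 8
attn_output = torch.randn(batch, seq, dim)
gate_msa = torch.randn(batch, dim)  # per-sample gate from the AdaLN projection

# (batch, dim) -> (batch, 1, dim) so the gate broadcasts over the token axis;
# multiplying without the unsqueeze would either raise a shape error or scale
# along the wrong axis.
gated = attn_output * gate_msa.unsqueeze(1)
print(gated.shape)  # torch.Size([2, 5, 8])
```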
@@ -681,8 +811,8 @@ def forward(
         ff_output = self.ff(norm_hidden_states)
         context_ff_output = self.ff_context(norm_encoder_hidden_states)
 
-        hidden_states_zero = hidden_states[:, :num_tokens] + ff_output[:, :num_tokens] * tr_gate_mlp
-        hidden_states_orig = hidden_states[:, num_tokens:] + ff_output[:, num_tokens:] * gate_mlp
+        hidden_states_zero = hidden_states[:, :num_tokens] + ff_output[:, :num_tokens] * tr_gate_mlp.unsqueeze(1)
+        hidden_states_orig = hidden_states[:, num_tokens:] + ff_output[:, num_tokens:] * gate_mlp.unsqueeze(1)
         hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
         encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
 
@@ -802,14 +932,24 @@ def __init__(
         )
 
         # 4. Single stream transformer blocks
-        self.single_transformer_blocks = nn.ModuleList(
-            [
-                HunyuanVideoSingleTransformerBlock(
-                    num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
-                )
-                for _ in range(num_single_layers)
-            ]
-        )
+        if image_condition_type == "token_replace":
+            self.single_transformer_blocks = nn.ModuleList(
+                [
+                    HunyuanVideoTokenReplaceSingleTransformerBlock(
+                        num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
+                    )
+                    for _ in range(num_single_layers)
+                ]
+            )
+        else:
+            self.single_transformer_blocks = nn.ModuleList(
+                [
+                    HunyuanVideoSingleTransformerBlock(
+                        num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
+                    )
+                    for _ in range(num_single_layers)
+                ]
+            )
 
         # 5. Output projection
         self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6)
@@ -957,6 +1097,8 @@ def forward(
                     temb,
                     attention_mask,
                     image_rotary_emb,
+                    token_replace_emb,
+                    first_frame_num_tokens,
                 )
 
         else:
@@ -973,7 +1115,13 @@ def forward(
 
             for block in self.single_transformer_blocks:
                 hidden_states, encoder_hidden_states = block(
-                    hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
+                    hidden_states,
+                    encoder_hidden_states,
+                    temb,
+                    attention_mask,
+                    image_rotary_emb,
+                    token_replace_emb,
+                    first_frame_num_tokens,
                 )
 
         # 5. Output projection

src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py

Lines changed: 13 additions & 9 deletions
@@ -72,10 +72,10 @@
 
         >>> # If using hunyuanvideo-community/HunyuanVideo-I2V
         >>> output = pipe(image=image, prompt=prompt, guidance_scale=6.0).frames[0]
-
+
        >>> # If using hunyuanvideo-community/HunyuanVideo-I2V-33ch
        >>> output = pipe(image=image, prompt=prompt, guidance_scale=1.0, true_cfg_scale=1.0).frames[0]
-
+
        >>> export_to_video(output, "output.mp4", fps=15)
        ```
 """
@@ -506,15 +506,14 @@ def prepare_latents(
         image = image.unsqueeze(2)  # [B, C, 1, H, W]
         if isinstance(generator, list):
             image_latents = [
-                retrieve_latents(self.vae.encode(image[i].unsqueeze(0)), generator[i]) for i in range(batch_size)
+                retrieve_latents(self.vae.encode(image[i].unsqueeze(0)), generator[i], "argmax")
+                for i in range(batch_size)
             ]
         else:
-            image_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator) for img in image]
+            image_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator, "argmax") for img in image]
 
         image_latents = torch.cat(image_latents, dim=0).to(dtype) * self.vae_scaling_factor
-
-        if image_condition_type == "latent_concat":
-            image_latents = image_latents.repeat(1, 1, num_latent_frames, 1, 1)
+        image_latents = image_latents.repeat(1, 1, num_latent_frames, 1, 1)
 
         if latents is None:
             latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
@@ -524,6 +523,9 @@ def prepare_latents(
             t = torch.tensor([0.999]).to(device=device)
             latents = latents * t + image_latents * (1 - t)
 
+        if image_condition_type == "token_replace":
+            image_latents = image_latents[:, :, :1]
+
         return latents, image_latents
 
     def enable_vae_slicing(self):
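A hedged sketch (illustrative shapes only, not the pipeline's real latent dimensions) of how `prepare_latents` now lays out the conditioning latents after this hunk: the encoded first frame is repeated across all latent frames for the mixing step, and for `token_replace` conditioning only the first latent frame is kept afterwards.

```python
import torch

batch, channels, height, width = 1, 16, 8, 8
num_latent_frames = 5
image_condition_type = "token_replace"

# The encoded input image occupies a single latent frame.
first_frame_latents = torch.randn(batch, channels, 1, height, width)
image_latents = first_frame_latents.repeat(1, 1, num_latent_frames, 1, 1)

if image_condition_type == "token_replace":
    # Keep only the clean first latent frame as the conditioning slice.
    image_latents = image_latents[:, :, :1]

print(image_latents.shape)  # torch.Size([1, 16, 1, 8, 8])
```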
@@ -817,7 +819,9 @@ def __call__(
         # 6. Prepare guidance condition
         guidance = None
         if self.transformer.config.guidance_embeds:
-            guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0
+            guidance = (
+                torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0
+            )
 
         # 7. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
@@ -889,7 +893,7 @@ def __call__(
         self._current_timestep = None
 
         if not output_type == "latent":
-            latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
+            latents = latents.to(self.vae.dtype) / self.vae_scaling_factor
             video = self.vae.decode(latents, return_dict=False)[0]
             if image_condition_type == "latent_concat":
                 video = video[:, :, 4:, :, :]