
Commit f637479

refactor double transformer block attention
1 parent ca4c81e commit f637479

2 files changed (+84, -94 lines)


scripts/convert_hunyuan_video_to_diffusers.py

Lines changed: 24 additions & 0 deletions
@@ -14,6 +14,22 @@ def remap_norm_scale_shift_(key, state_dict):
     state_dict[key.replace("final_layer.adaLN_modulation.1", "norm_out.linear")] = new_weight


+def remap_img_attn_qkv_(key, state_dict):
+    weight = state_dict.pop(key)
+    to_q, to_k, to_v = weight.chunk(3, dim=0)
+    state_dict[key.replace("img_attn_qkv", "attn.to_q")] = to_q
+    state_dict[key.replace("img_attn_qkv", "attn.to_k")] = to_k
+    state_dict[key.replace("img_attn_qkv", "attn.to_v")] = to_v
+
+
+def remap_txt_attn_qkv_(key, state_dict):
+    weight = state_dict.pop(key)
+    to_q, to_k, to_v = weight.chunk(3, dim=0)
+    state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = to_q
+    state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = to_k
+    state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = to_v
+
+
 def remap_single_transformer_blocks_(key, state_dict):
     hidden_size = 3072

@@ -53,6 +69,12 @@ def remap_single_transformer_blocks_(key, state_dict):
     # "vector_in.in_layer": "time_text_embed.text_embedder.linear_1",
     # "vector_in.out_layer": "time_text_embed.text_embedder.linear_2",
     "double_blocks": "transformer_blocks",
+    "img_attn_q_norm": "attn.norm_q",
+    "img_attn_k_norm": "attn.norm_k",
+    "img_attn_proj": "attn.to_out.0",
+    "txt_attn_q_norm": "attn.norm_added_q",
+    "txt_attn_k_norm": "attn.norm_added_k",
+    "txt_attn_proj": "attn.to_add_out",
     "img_mod.linear": "norm1.linear",
     "img_norm1": "norm1.norm",
     "img_norm2": "norm2",
@@ -71,6 +93,8 @@ def remap_single_transformer_blocks_(key, state_dict):

 TRANSFORMER_SPECIAL_KEYS_REMAP = {
     "final_layer.adaLN_modulation.1": remap_norm_scale_shift_,
+    "img_attn_qkv": remap_img_attn_qkv_,
+    "txt_attn_qkv": remap_txt_attn_qkv_,
     "single_blocks": remap_single_transformer_blocks_,
 }
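Note on the conversion change above: the original checkpoint stores each double block's Q, K and V projections as one fused linear weight, while the refactored attention module expects separate per-projection weights. A minimal, self-contained sketch of the splitting idea; the tiny hidden size and toy state dict are illustrative only, not taken from the script:

import torch


def remap_img_attn_qkv_(key, state_dict):
    # Split a fused [3 * hidden, hidden] QKV weight into separate Q/K/V weights.
    weight = state_dict.pop(key)
    to_q, to_k, to_v = weight.chunk(3, dim=0)  # fused rows are ordered [Q; K; V]
    state_dict[key.replace("img_attn_qkv", "attn.to_q")] = to_q
    state_dict[key.replace("img_attn_qkv", "attn.to_k")] = to_k
    state_dict[key.replace("img_attn_qkv", "attn.to_v")] = to_v


hidden_size = 8  # toy value for the sketch; the script itself uses 3072
state_dict = {"double_blocks.0.img_attn_qkv.weight": torch.randn(3 * hidden_size, hidden_size)}
remap_img_attn_qkv_("double_blocks.0.img_attn_qkv.weight", state_dict)
print(sorted(state_dict))
# ['double_blocks.0.attn.to_k.weight', 'double_blocks.0.attn.to_q.weight', 'double_blocks.0.attn.to_v.weight']
# A separate key-rename pass then maps "double_blocks" to "transformer_blocks".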

src/diffusers/models/transformers/transformer_hunyuan_video.py

Lines changed: 60 additions & 94 deletions
@@ -360,13 +360,13 @@ class IndividualTokenRefinerBlock(nn.Module):
     def __init__(
         self,
         hidden_size,
-        heads_num,
+        num_attention_heads: int,
         mlp_width_ratio: str = 4.0,
         mlp_drop_rate: float = 0.0,
         qkv_bias: bool = True,
-    ):
+    ) -> None:
         super().__init__()
-        self.heads_num = heads_num
+        self.heads_num = num_attention_heads

         self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
         self.self_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias)
@@ -383,25 +383,25 @@ def __init__(

     def forward(
         self,
-        x: torch.Tensor,
-        c: torch.Tensor,
-        attn_mask: torch.Tensor = None,
-    ):
-        gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)
+        hidden_states: torch.Tensor,
+        temb: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        gate_msa, gate_mlp = self.adaLN_modulation(temb).chunk(2, dim=1)

-        norm_x = self.norm1(x)
+        norm_x = self.norm1(hidden_states)
         qkv = self.self_attn_qkv(norm_x)
         q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)

         # Self-Attention
-        attn = attention(q, k, v, attn_mask=attn_mask)
+        attn = attention(q, k, v, attn_mask=attention_mask)

-        x = x + self.self_attn_proj(attn) * gate_msa.unsqueeze(1)
+        hidden_states = hidden_states + self.self_attn_proj(attn) * gate_msa.unsqueeze(1)

         # FFN Layer
-        x = x + self.mlp(self.norm2(x)) * gate_mlp.unsqueeze(1)
+        hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) * gate_mlp.unsqueeze(1)

-        return x
+        return hidden_states


 class IndividualTokenRefiner(nn.Module):
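For readers following the renames in this hunk: the gating pattern itself is unchanged; adaLN_modulation produces per-sample gates that are broadcast over the sequence dimension. A shape-only sketch, with an assumed SiLU+Linear modulation layer and toy sizes:

import torch
import torch.nn as nn

hidden_size, batch, seq = 8, 2, 5
adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size))  # assumed layout
temb = torch.randn(batch, hidden_size)
gate_msa, gate_mlp = adaLN_modulation(temb).chunk(2, dim=1)  # each [batch, hidden_size]
hidden_states = torch.randn(batch, seq, hidden_size)
attn_out = torch.randn(batch, seq, hidden_size)
# unsqueeze(1) broadcasts the per-sample gate over every token in the sequence
hidden_states = hidden_states + attn_out * gate_msa.unsqueeze(1)
print(hidden_states.shape)  # torch.Size([2, 5, 8])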
@@ -419,7 +419,7 @@ def __init__(
             [
                 IndividualTokenRefinerBlock(
                     hidden_size=hidden_size,
-                    heads_num=heads_num,
+                    num_attention_heads=heads_num,
                     mlp_width_ratio=mlp_width_ratio,
                     mlp_drop_rate=mlp_drop_rate,
                     qkv_bias=qkv_bias,
@@ -430,41 +430,34 @@ def __init__(

     def forward(
         self,
-        x: torch.Tensor,
-        c: torch.LongTensor,
-        mask: Optional[torch.Tensor] = None,
+        hidden_states: torch.Tensor,
+        temb: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
     ):
         self_attn_mask = None
-        if mask is not None:
-            batch_size = mask.shape[0]
-            seq_len = mask.shape[1]
-            mask = mask.to(x.device).bool()
-            # batch_size x 1 x seq_len x seq_len
-            self_attn_mask_1 = mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
-            # batch_size x 1 x seq_len x seq_len
+        if attention_mask is not None:
+            batch_size = attention_mask.shape[0]
+            seq_len = attention_mask.shape[1]
+            attention_mask = attention_mask.to(hidden_states.device).bool()
+            self_attn_mask_1 = attention_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
             self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
-            # batch_size x 1 x seq_len x seq_len, 1 for broadcasting of heads_num
             self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
-            # avoids self-attention weight being NaN for padding tokens
             self_attn_mask[:, :, :, 0] = True

         for block in self.blocks:
-            x = block(x, c, self_attn_mask)
-        return x
+            hidden_states = block(hidden_states, temb, self_attn_mask)
+
+        return hidden_states


 class SingleTokenRefiner(nn.Module):
-    """
-    A single token refiner block for llm text embedding refine.
-    """
-
     def __init__(
         self,
-        in_channels,
-        hidden_size,
-        num_attention_heads,
-        depth,
-        mlp_width_ratio: float = 4.0,
+        in_channels: int,
+        hidden_size: int,
+        num_attention_heads: int,
+        depth: int,
+        mlp_ratio: float = 4.0,
         mlp_drop_rate: float = 0.0,
         qkv_bias: bool = True,
     ):
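The mask construction kept by this hunk (now keyed off attention_mask) builds a pairwise boolean mask from a 1D padding mask via an outer AND, and forces column 0 to True so fully padded queries never produce NaN attention weights. A small worked example:

import torch

attention_mask = torch.tensor([[1, 1, 0, 0]])  # batch=1, seq_len=4; last two tokens are padding
batch_size, seq_len = attention_mask.shape
attention_mask = attention_mask.bool()
self_attn_mask_1 = attention_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
self_attn_mask[:, :, :, 0] = True  # padded queries can still attend to token 0
print(self_attn_mask[0, 0])
# tensor([[ True,  True, False, False],
#         [ True,  True, False, False],
#         [ True, False, False, False],
#         [ True, False, False, False]])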
@@ -481,7 +474,7 @@ def __init__(
             hidden_size=hidden_size,
             heads_num=num_attention_heads,
             depth=depth,
-            mlp_width_ratio=mlp_width_ratio,
+            mlp_width_ratio=mlp_ratio,
             mlp_drop_rate=mlp_drop_rate,
             qkv_bias=qkv_bias,
         )
@@ -587,28 +580,31 @@ def forward(
 class HunyuanVideoTransformerBlock(nn.Module):
     def __init__(
         self,
-        hidden_size: int,
-        heads_num: int,
+        num_attention_heads: int,
+        attention_head_dim: int,
         mlp_ratio: float,
         qk_norm: str = "rms_norm",
-    ):
+    ) -> None:
         super().__init__()

-        self.heads_num = heads_num
-        head_dim = hidden_size // heads_num
+        hidden_size = num_attention_heads * attention_head_dim

         self.norm1 = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
         self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm")

-        self.img_attn_qkv = nn.Linear(hidden_size, hidden_size * 3)
-        self.img_attn_q_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6)
-        self.img_attn_k_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6)
-        self.img_attn_proj = nn.Linear(hidden_size, hidden_size)
-
-        self.txt_attn_qkv = nn.Linear(hidden_size, hidden_size * 3)
-        self.txt_attn_q_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6)
-        self.txt_attn_k_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6)
-        self.txt_attn_proj = nn.Linear(hidden_size, hidden_size)
+        self.attn = Attention(
+            query_dim=hidden_size,
+            cross_attention_dim=None,
+            added_kv_proj_dim=hidden_size,
+            dim_head=attention_head_dim,
+            heads=num_attention_heads,
+            out_dim=hidden_size,
+            context_pre_only=False,
+            bias=True,
+            processor=HunyuanVideoAttnProcessor2_0(),
+            qk_norm=qk_norm,
+            eps=1e-6,
+        )

         self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
         self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
@@ -627,35 +623,15 @@ def forward(
         norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
             encoder_hidden_states, emb=temb
         )
-
-        img_qkv = self.img_attn_qkv(norm_hidden_states)
-        img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
-        # Apply QK-Norm if needed
-        img_q = self.img_attn_q_norm(img_q).to(img_v)
-        img_k = self.img_attn_k_norm(img_k).to(img_v)
-
-        # Apply RoPE if needed.
-        if freqs_cis is not None:
-            img_qq, img_kk = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
-            assert (
-                img_qq.shape == img_q.shape and img_kk.shape == img_k.shape
-            ), f"img_kk: {img_qq.shape}, img_q: {img_q.shape}, img_kk: {img_kk.shape}, img_k: {img_k.shape}"
-            img_q, img_k = img_qq, img_kk
-
-        txt_qkv = self.txt_attn_qkv(norm_encoder_hidden_states)
-        txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
-        txt_q = self.txt_attn_q_norm(txt_q).to(txt_v)
-        txt_k = self.txt_attn_k_norm(txt_k).to(txt_v)
-
-        q = torch.cat((img_q, txt_q), dim=1)
-        k = torch.cat((img_k, txt_k), dim=1)
-        v = torch.cat((img_v, txt_v), dim=1)
-        attn = attention(q, k, v)
-
-        img_attn, txt_attn = attn[:, : hidden_states.shape[1]], attn[:, hidden_states.shape[1] :]
-
-        hidden_states = hidden_states + self.img_attn_proj(img_attn) * gate_msa.unsqueeze(1)
-        encoder_hidden_states = encoder_hidden_states + self.txt_attn_proj(txt_attn) * c_gate_msa.unsqueeze(1)
+
+        img_attn, txt_attn = self.attn(
+            hidden_states=norm_hidden_states,
+            encoder_hidden_states=norm_encoder_hidden_states,
+            image_rotary_emb=freqs_cis,
+        )
+
+        hidden_states = hidden_states + img_attn * gate_msa.unsqueeze(1)
+        encoder_hidden_states = encoder_hidden_states + txt_attn * c_gate_msa.unsqueeze(1)

         norm_hidden_states = self.norm2(hidden_states)
         norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
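The hand-rolled image/text attention removed above is replaced by a single self.attn call whose processor returns the image and text streams separately. A conceptual sketch only (not the actual HunyuanVideoAttnProcessor2_0 implementation) of the joint-attention behaviour the block now relies on, mirroring the removed code path:

import torch
import torch.nn.functional as F


def joint_attention_sketch(img_q, img_k, img_v, txt_q, txt_k, txt_v):
    # All tensors: [batch, heads, tokens, head_dim]; image and text tokens attend jointly.
    query = torch.cat([img_q, txt_q], dim=2)
    key = torch.cat([img_k, txt_k], dim=2)
    value = torch.cat([img_v, txt_v], dim=2)
    out = F.scaled_dot_product_attention(query, key, value)
    img_len = img_q.shape[2]
    # Split the joint output back into the image and text streams, as the block gates them separately.
    return out[:, :, :img_len], out[:, :, img_len:]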
@@ -686,15 +662,14 @@ def __init__(
         patch_size_t: int = 1,
         rope_dim_list: List[int] = [16, 56, 56],
         qk_norm: str = "rms_norm",
-        guidance_embed: bool = True,
+        guidance_embeds: bool = True,
         text_embed_dim: int = 4096,
         text_embed_dim_2: int = 768,
     ) -> None:
         super().__init__()

         inner_dim = num_attention_heads * attention_head_dim
         out_channels = out_channels or in_channels
-        self.guidance_embed = guidance_embed
         self.rope_dim_list = rope_dim_list

         # image projection
@@ -714,7 +689,7 @@ def __init__(

         self.transformer_blocks = nn.ModuleList(
             [
-                HunyuanVideoTransformerBlock(inner_dim, num_attention_heads, mlp_ratio=mlp_ratio, qk_norm=qk_norm)
+                HunyuanVideoTransformerBlock(num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm)
                 for _ in range(num_layers)
             ]
         )
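The block is now constructed from (num_attention_heads, attention_head_dim) instead of (hidden_size, heads_num), with the hidden size derived inside the block. A quick sanity check; the values below are illustrative, chosen so the product reproduces the 3072 hidden size hard-coded in the conversion script, not read from a model config:

num_attention_heads = 24
attention_head_dim = 128
hidden_size = num_attention_heads * attention_head_dim
assert hidden_size == 3072
# old call: HunyuanVideoTransformerBlock(inner_dim, num_attention_heads, mlp_ratio=..., qk_norm=...)
# new call: HunyuanVideoTransformerBlock(num_attention_heads, attention_head_dim, mlp_ratio=..., qk_norm=...)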
@@ -816,18 +791,9 @@ def forward(
         post_patch_height = height // p
         post_patch_width = width // p

-        # Prepare modulation vectors.
         temb = self.time_in(timestep)
-
-        # text modulation
         temb = temb + self.vector_in(encoder_hidden_states_2)
-
-        # guidance modulation
-        if self.guidance_embed:
-            if guidance is None:
-                raise ValueError("Didn't get guidance strength for guidance distilled model.")
-
-            temb = temb + self.guidance_in(guidance)
+        temb = temb + self.guidance_in(guidance)

         # Embed image and text.
         hidden_states = self.img_in(hidden_states)
