
Commit e559ae8

Update transformer_hunyuan_video.py
1 parent 01780c3 commit e559ae8

File tree

1 file changed: +141 −0 lines changed

src/diffusers/models/transformers/transformer_hunyuan_video.py

Lines changed: 141 additions & 0 deletions
@@ -136,6 +136,103 @@ def __call__(
 
         return hidden_states, encoder_hidden_states
 
+class FusedHunyuanVideoAttnProcessor2_0:
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "FusedHunyuanVideoAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if attn.add_q_proj is None and encoder_hidden_states is not None:
+            hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
+
+        # 1. QKV projections
+        qkv = attn.to_qkv(hidden_states)
+        split_size = qkv.shape[-1] // 3
+        query, key, value = torch.split(qkv, split_size, dim=-1)
+
+        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+        key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+        value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+
+        # 2. QK normalization
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # 3. Rotational positional embeddings applied to latent stream
+        if image_rotary_emb is not None:
+            from ..embeddings import apply_rotary_emb
+
+            if attn.add_q_proj is None and encoder_hidden_states is not None:
+                query = torch.cat(
+                    [
+                        apply_rotary_emb(query[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
+                        query[:, :, -encoder_hidden_states.shape[1] :],
+                    ],
+                    dim=2,
+                )
+                key = torch.cat(
+                    [
+                        apply_rotary_emb(key[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
+                        key[:, :, -encoder_hidden_states.shape[1] :],
+                    ],
+                    dim=2,
+                )
+            else:
+                query = apply_rotary_emb(query, image_rotary_emb)
+                key = apply_rotary_emb(key, image_rotary_emb)
+
+        # 4. Encoder condition QKV projection and normalization
+        if attn.add_q_proj is not None and encoder_hidden_states is not None:
+            encoder_query = attn.add_q_proj(encoder_hidden_states)
+            encoder_key = attn.add_k_proj(encoder_hidden_states)
+            encoder_value = attn.add_v_proj(encoder_hidden_states)
+
+            encoder_query = encoder_query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+            encoder_key = encoder_key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+            encoder_value = encoder_value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+
+            if attn.norm_added_q is not None:
+                encoder_query = attn.norm_added_q(encoder_query)
+            if attn.norm_added_k is not None:
+                encoder_key = attn.norm_added_k(encoder_key)
+
+            query = torch.cat([query, encoder_query], dim=2)
+            key = torch.cat([key, encoder_key], dim=2)
+            value = torch.cat([value, encoder_value], dim=2)
+
+        # 5. Attention
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # 6. Output projection
+        if encoder_hidden_states is not None:
+            hidden_states, encoder_hidden_states = (
+                hidden_states[:, : -encoder_hidden_states.shape[1]],
+                hidden_states[:, -encoder_hidden_states.shape[1] :],
+            )
+
+            if getattr(attn, "to_out", None) is not None:
+                hidden_states = attn.to_out[0](hidden_states)
+                hidden_states = attn.to_out[1](hidden_states)
+
+            if getattr(attn, "to_add_out", None) is not None:
+                encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+        return hidden_states, encoder_hidden_states
 
 class HunyuanVideoPatchEmbed(nn.Module):
     def __init__(
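Aside: the processor above reads from `attn.to_qkv`, which only exists after `Attention.fuse_projections(fuse=True)` (called by `fuse_qkv_projections()` further down in this commit) has concatenated the separate `to_q`/`to_k`/`to_v` weights into one linear layer. A minimal, self-contained sketch of that fusion idea follows; the `fuse_qkv` helper and the 64-dim shapes are illustrative assumptions, not this file's API:

```python
import torch
import torch.nn as nn


def fuse_qkv(to_q: nn.Linear, to_k: nn.Linear, to_v: nn.Linear) -> nn.Linear:
    # Stack the three projection weights along the output dimension so a
    # single matmul produces [q | k | v] in one pass over hidden_states.
    out_features = to_q.out_features + to_k.out_features + to_v.out_features
    fused = nn.Linear(to_q.in_features, out_features, bias=to_q.bias is not None)
    with torch.no_grad():
        fused.weight.copy_(torch.cat([to_q.weight, to_k.weight, to_v.weight], dim=0))
        if to_q.bias is not None:
            fused.bias.copy_(torch.cat([to_q.bias, to_k.bias, to_v.bias], dim=0))
    return fused


# Sanity check: splitting the fused output recovers the separate projections,
# mirroring the torch.split(qkv, split_size, dim=-1) call in the processor.
to_q, to_k, to_v = nn.Linear(64, 64), nn.Linear(64, 64), nn.Linear(64, 64)
x = torch.randn(2, 10, 64)
qkv = fuse_qkv(to_q, to_k, to_v)(x)
query, key, value = torch.split(qkv, qkv.shape[-1] // 3, dim=-1)
assert torch.allclose(query, to_q(x), atol=1e-5)
assert torch.allclose(value, to_v(x), atol=1e-5)
```

The payoff is one large matmul instead of three smaller ones per attention layer, which tends to use the GPU better at inference time.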
@@ -214,6 +311,10 @@ def forward(
         )
 
         gate_msa, gate_mlp = self.norm_out(temb)
+        # QKV fusion fix
+        if isinstance(attn_output, tuple):
+            attn_output = attn_output[0]
+
         hidden_states = hidden_states + attn_output * gate_msa
 
         ff_output = self.ff(self.norm2(hidden_states))
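Why the guard is needed: `FusedHunyuanVideoAttnProcessor2_0.__call__` always ends with `return hidden_states, encoder_hidden_states`, so `self.attn(...)` now hands back a 2-tuple (with `None` as the second element when no encoder stream is passed), whereas this block previously received a bare tensor. A standalone illustration of the mismatch, with toy stand-ins that are purely illustrative:

```python
import torch


def unfused_attn(x):
    return x  # bare tensor, the old behavior


def fused_attn(x, context=None):
    return x, context  # always a 2-tuple, even when context is None


attn_output = fused_attn(torch.randn(1, 4, 8))
# The guard from this commit normalizes both cases to a tensor:
if isinstance(attn_output, tuple):
    attn_output = attn_output[0]
assert torch.is_tensor(attn_output)
```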
@@ -604,6 +705,46 @@ def __init__(
 
         self.gradient_checkpointing = False
 
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedHunyuanVideoAttnProcessor2_0
+    def fuse_qkv_projections(self):
+        """
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+        are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+        """
+        self.original_attn_processors = None
+
+        for _, attn_processor in self.attn_processors.items():
+            if "Added" in str(attn_processor.__class__.__name__):
+                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+        self.original_attn_processors = self.attn_processors
+
+        for module in self.modules():
+            if isinstance(module, Attention):
+                module.fuse_projections(fuse=True)
+
+        self.set_attn_processor(FusedHunyuanVideoAttnProcessor2_0())
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+    def unfuse_qkv_projections(self):
+        """Disables the fused QKV projection if enabled.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+
+        """
+        if self.original_attn_processors is not None:
+            self.set_attn_processor(self.original_attn_processors)
+
     @property
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
     def attn_processors(self) -> Dict[str, AttentionProcessor]:
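For context, a hedged usage sketch of the new methods on the top-level model; the checkpoint id is an illustrative assumption, not part of this commit:

```python
import torch
from diffusers import HunyuanVideoTransformer3DModel

# Illustrative checkpoint; any HunyuanVideo transformer weights would do.
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Fuse to_q/to_k/to_v into one to_qkv matmul and route every Attention
# module through FusedHunyuanVideoAttnProcessor2_0.
transformer.fuse_qkv_projections()

# ... run the transformer / pipeline as usual ...

# Restore the original per-projection attention processors.
transformer.unfuse_qkv_projections()
```

Note that `fuse_qkv_projections()` stashes the current processors in `self.original_attn_processors`, which is exactly what `unfuse_qkv_projections()` restores.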
