Commit a9bd457

cleanup a bit

1 parent 1e80f7c · commit a9bd457

1 file changed (+99, −115)

src/diffusers/models/transformers/transformer_hunyuan_video.py

Lines changed: 99 additions & 115 deletions
@@ -13,6 +13,7 @@
 # limitations under the License.

 import math
+from functools import partial
 from typing import Dict, List, Optional, Tuple, Union

 import torch
@@ -21,8 +22,9 @@
 from einops import rearrange

 from ...configuration_utils import ConfigMixin, register_to_config
+from ...utils import is_torch_version
 from ..attention import FeedForward
-from ..attention_processor import Attention
+from ..attention_processor import Attention, AttentionProcessor
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle, RMSNorm
@@ -514,13 +516,13 @@ def __init__(
         self,
         num_attention_heads: int,
         attention_head_dim: int,
-        mlp_width_ratio: float = 4.0,
+        mlp_ratio: float = 4.0,
         qk_norm: str = "rms_norm",
     ):
         super().__init__()

         hidden_size = num_attention_heads * attention_head_dim
-        mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
+        mlp_hidden_dim = int(hidden_size * mlp_ratio)

         self.hidden_size = hidden_size
         self.heads_num = num_attention_heads
@@ -549,7 +551,6 @@ def forward(
         hidden_states: torch.Tensor,
         encoder_hidden_states: torch.Tensor,
         temb: torch.Tensor,
-        txt_len: int,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
     ) -> torch.Tensor:
         text_seq_length = encoder_hidden_states.shape[1]
@@ -565,25 +566,6 @@ def forward(
             norm_hidden_states[:, -text_seq_length:, :],
         )

-        # qkv, mlp = torch.split(self.linear1(norm_hidden_states), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
-        # q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
-
-        # # Apply QK-Norm if needed.
-        # q = self.q_norm(q).to(v)
-        # k = self.k_norm(k).to(v)
-
-        # if image_rotary_emb is not None:
-        #     img_q, txt_q = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :]
-        #     img_k, txt_k = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :]
-        #     img_qq, img_kk = apply_rotary_emb(img_q, img_k, image_rotary_emb, head_first=False)
-        #     img_q, img_k = img_qq, img_kk
-        #     q = torch.cat((img_q, txt_q), dim=1)
-        #     k = torch.cat((img_k, txt_k), dim=1)
-
-        # attn = attention(q, k, v)
-        # output = self.linear2(torch.cat((attn, self.act_mlp(mlp)), 2))
-        # output = hidden_states + output * gate.unsqueeze(1)
-
         attn_output, context_attn_output = self.attn(
             hidden_states=norm_hidden_states,
             encoder_hidden_states=norm_encoder_hidden_states,
@@ -607,7 +589,7 @@ def __init__(
         self,
         hidden_size: int,
         heads_num: int,
-        mlp_width_ratio: float,
+        mlp_ratio: float,
        qk_norm: str = "rms_norm",
     ):
         super().__init__()
@@ -629,10 +611,10 @@ def __init__(
         self.txt_attn_proj = nn.Linear(hidden_size, hidden_size)

         self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
-        self.ff = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="gelu-approximate")
+        self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")

         self.norm2_context = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
-        self.ff_context = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="gelu-approximate")
+        self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")

     def forward(
         self,
@@ -690,70 +672,23 @@ def forward(


 class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin):
-    """
-    HunyuanVideo Transformer backbone
-
-    Inherited from ModelMixin and ConfigMixin for compatibility with diffusers' sampler StableDiffusionPipeline.
-
-    Reference: [1] Flux.1: https://github.com/black-forest-labs/flux [2] MMDiT: http://arxiv.org/abs/2403.03206
-
-    Parameters ---------- args: argparse.Namespace
-        The arguments parsed by argparse.
-    patch_size: list
-        The size of the patch.
-    in_channels: int
-        The number of input channels.
-    out_channels: int
-        The number of output channels.
-    hidden_size: int
-        The hidden size of the transformer backbone.
-    heads_num: int
-        The number of attention heads.
-    mlp_width_ratio: float
-        The ratio of the hidden size of the MLP in the transformer block.
-    mlp_act_type: str
-        The activation function of the MLP in the transformer block.
-    depth_double_blocks: int
-        The number of transformer blocks in the double blocks.
-    depth_single_blocks: int
-        The number of transformer blocks in the single blocks.
-    rope_dim_list: list
-        The dimension of the rotary embedding for t, h, w.
-    qkv_bias: bool
-        Whether to use bias in the qkv linear layer.
-    qk_norm: bool
-        Whether to use qk norm.
-    qk_norm_type: str
-        The type of qk norm.
-    guidance_embed: bool
-        Whether to use guidance embedding for distillation.
-    text_projection: str
-        The type of the text projection, default is single_refiner.
-    use_attention_mask: bool
-        Whether to use attention mask for text encoder.
-    dtype: torch.dtype
-        The dtype of the model.
-    device: torch.device
-        The device of the model.
-    """
-
     @register_to_config
     def __init__(
         self,
-        patch_size: int = 2,
-        patch_size_t: int = 1,
         in_channels: int = 16,
         out_channels: int = 16,
         num_attention_heads: int = 24,
         attention_head_dim: int = 128,
-        mlp_width_ratio: float = 4.0,
-        mm_double_blocks_depth: int = 20,
-        mm_single_blocks_depth: int = 40,
+        num_layers: int = 20,
+        num_single_layers: int = 40,
+        mlp_ratio: float = 4.0,
+        patch_size: int = 2,
+        patch_size_t: int = 1,
         rope_dim_list: List[int] = [16, 56, 56],
         qk_norm: str = "rms_norm",
         guidance_embed: bool = True,
-        text_states_dim: int = 4096,
-        text_states_dim_2: int = 768,
+        text_embed_dim: int = 4096,
+        text_embed_dim_2: int = 768,
     ) -> None:
         super().__init__()

@@ -762,51 +697,106 @@ def __init__(
         self.guidance_embed = guidance_embed
         self.rope_dim_list = rope_dim_list

-        if sum(rope_dim_list) != attention_head_dim:
-            raise ValueError(f"Got {rope_dim_list} but expected positional dim {attention_head_dim}")
-
         # image projection
         self.img_in = PatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim)

         # text projection
-        self.txt_in = SingleTokenRefiner(text_states_dim, inner_dim, num_attention_heads, depth=2)
+        self.txt_in = SingleTokenRefiner(text_embed_dim, inner_dim, num_attention_heads, depth=2)

         # time modulation
         self.time_in = TimestepEmbedder(inner_dim, nn.SiLU)

         # text modulation
-        self.vector_in = MLPEmbedder(text_states_dim_2, inner_dim)
+        self.vector_in = MLPEmbedder(text_embed_dim_2, inner_dim)

         # guidance modulation
         self.guidance_in = TimestepEmbedder(inner_dim, nn.SiLU)

         self.transformer_blocks = nn.ModuleList(
             [
-                HunyuanVideoTransformerBlock(
-                    inner_dim,
-                    num_attention_heads,
-                    mlp_width_ratio=mlp_width_ratio,
-                    qk_norm=qk_norm,
-                )
-                for _ in range(mm_double_blocks_depth)
+                HunyuanVideoTransformerBlock(inner_dim, num_attention_heads, mlp_ratio=mlp_ratio, qk_norm=qk_norm)
+                for _ in range(num_layers)
             ]
         )

         self.single_transformer_blocks = nn.ModuleList(
             [
                 HunyuanVideoSingleTransformerBlock(
-                    num_attention_heads,
-                    attention_head_dim,
-                    mlp_width_ratio=mlp_width_ratio,
-                    qk_norm=qk_norm,
+                    num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
                 )
-                for _ in range(mm_single_blocks_depth)
+                for _ in range(num_single_layers)
             ]
         )

         self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6)
         self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels)

+        self.gradient_checkpointing = False
+
+    @property
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+    def attn_processors(self) -> Dict[str, AttentionProcessor]:
+        r"""
+        Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model with
+            indexed by its weight name.
+        """
+        # set recursively
+        processors = {}
+
+        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+            if hasattr(module, "get_processor"):
+                processors[f"{name}.processor"] = module.get_processor()
+
+            for sub_name, child in module.named_children():
+                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+            return processors
+
+        for name, module in self.named_children():
+            fn_recursive_add_processors(name, module, processors)
+
+        return processors
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+        r"""
+        Sets the attention processor to use to compute attention.
+
+        Parameters:
+            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+                The instantiated processor class or a dictionary of processor classes that will be set as the processor
+                for **all** `Attention` layers.
+
+                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+                processor. This is strongly recommended when setting trainable attention processors.
+
+        """
+        count = len(self.attn_processors.keys())
+
+        if isinstance(processor, dict) and len(processor) != count:
+            raise ValueError(
+                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+            )
+
+        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+            if hasattr(module, "set_processor"):
+                if not isinstance(processor, dict):
+                    module.set_processor(processor)
+                else:
+                    module.set_processor(processor.pop(f"{name}.processor"))
+
+            for sub_name, child in module.named_children():
+                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+        for name, module in self.named_children():
+            fn_recursive_attn_processor(name, module, processor)
+
+    def _set_gradient_checkpointing(self, module, value=False):
+        if hasattr(module, "gradient_checkpointing"):
+            module.gradient_checkpointing = value
+
     def forward(
         self,
         hidden_states: torch.Tensor,
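The `attn_processors` property and `set_attn_processor` method added in the hunk above mirror the existing UNet helpers. A minimal, hypothetical usage sketch (the `swap_processors` helper and the choice of `AttnProcessor2_0` are illustrative assumptions, not part of this commit):

```python
# Hypothetical sketch of the processor API added above; AttnProcessor2_0 is the
# stock SDPA processor that ships with diffusers.
from diffusers.models.attention_processor import AttnProcessor2_0


def swap_processors(model) -> None:
    # Enumerate every attention processor, keyed by the owning module's weight path.
    for name, processor in model.attn_processors.items():
        print(name, type(processor).__name__)

    # Set one processor instance for all Attention layers; passing a dict keyed by
    # the same paths would target individual layers instead.
    model.set_attn_processor(AttnProcessor2_0())
```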
@@ -843,29 +833,23 @@ def forward(
         hidden_states = self.img_in(hidden_states)
         encoder_hidden_states = self.txt_in(encoder_hidden_states, timestep, encoder_attention_mask)

-        txt_seq_len = encoder_hidden_states.shape[1]
+        use_reentrant = is_torch_version(">=", "1.11.0")
+        block_forward = (
+            partial(torch.utils.checkpoint.checkpoint, use_reentrant=use_reentrant)
+            if torch.is_grad_enabled() and self.gradient_checkpointing
+            else lambda x: x
+        )

         freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None
         for _, block in enumerate(self.transformer_blocks):
-            double_block_args = [
-                hidden_states,
-                encoder_hidden_states,
-                temb,
-                freqs_cis,
-            ]
-
-            hidden_states, encoder_hidden_states = block(*double_block_args)
+            hidden_states, encoder_hidden_states = block_forward(block)(
+                hidden_states, encoder_hidden_states, temb, freqs_cis
+            )

         for block in self.single_transformer_blocks:
-            single_block_args = [
-                hidden_states,
-                encoder_hidden_states,
-                temb,
-                txt_seq_len,
-                (freqs_cos, freqs_sin),
-            ]
-
-            hidden_states, encoder_hidden_states = block(*single_block_args)
+            hidden_states, encoder_hidden_states = block_forward(block)(
+                hidden_states, encoder_hidden_states, temb, freqs_cis
+            )

         hidden_states = self.norm_out(hidden_states, temb)
         hidden_states = self.proj_out(hidden_states)
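Taken together, the renamed constructor arguments and the new `gradient_checkpointing` flag can be exercised roughly as below. This is a hedged sketch, not part of the commit: the tiny layer counts and dimensions are made-up values chosen only to keep the example small.

```python
# Sketch only: small, assumed config values; not the sizes used by released checkpoints.
from diffusers.models.transformers.transformer_hunyuan_video import HunyuanVideoTransformer3DModel

model = HunyuanVideoTransformer3DModel(
    in_channels=16,
    out_channels=16,
    num_attention_heads=4,
    attention_head_dim=32,
    num_layers=1,               # previously mm_double_blocks_depth
    num_single_layers=1,        # previously mm_single_blocks_depth
    mlp_ratio=4.0,              # previously mlp_width_ratio
    rope_dim_list=[8, 12, 12],  # should still sum to attention_head_dim
    text_embed_dim=4096,        # previously text_states_dim
    text_embed_dim_2=768,       # previously text_states_dim_2
)

# Flip the flag added in this commit; forward() then wraps every block in
# torch.utils.checkpoint.checkpoint while gradients are enabled. (ModelMixin's
# enable_gradient_checkpointing() is the public route when the class opts in.)
model.gradient_checkpointing = True
```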
