 # limitations under the License.
 
 import math
+from functools import partial
 from typing import Dict, List, Optional, Tuple, Union
 
 import torch
 from einops import rearrange
 
 from ...configuration_utils import ConfigMixin, register_to_config
+from ...utils import is_torch_version
 from ..attention import FeedForward
-from ..attention_processor import Attention
+from ..attention_processor import Attention, AttentionProcessor
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle, RMSNorm
@@ -514,13 +516,13 @@ def __init__(
         self,
         num_attention_heads: int,
         attention_head_dim: int,
-        mlp_width_ratio: float = 4.0,
+        mlp_ratio: float = 4.0,
         qk_norm: str = "rms_norm",
     ):
         super().__init__()
 
         hidden_size = num_attention_heads * attention_head_dim
-        mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
+        mlp_hidden_dim = int(hidden_size * mlp_ratio)
 
         self.hidden_size = hidden_size
         self.heads_num = num_attention_heads
@@ -549,7 +551,6 @@ def forward(
         hidden_states: torch.Tensor,
         encoder_hidden_states: torch.Tensor,
         temb: torch.Tensor,
-        txt_len: int,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
     ) -> torch.Tensor:
         text_seq_length = encoder_hidden_states.shape[1]
@@ -565,25 +566,6 @@ def forward(
             norm_hidden_states[:, -text_seq_length:, :],
         )
 
-        # qkv, mlp = torch.split(self.linear1(norm_hidden_states), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
-        # q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
-
-        # # Apply QK-Norm if needed.
-        # q = self.q_norm(q).to(v)
-        # k = self.k_norm(k).to(v)
-
-        # if image_rotary_emb is not None:
-        #     img_q, txt_q = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :]
-        #     img_k, txt_k = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :]
-        #     img_qq, img_kk = apply_rotary_emb(img_q, img_k, image_rotary_emb, head_first=False)
-        #     img_q, img_k = img_qq, img_kk
-        #     q = torch.cat((img_q, txt_q), dim=1)
-        #     k = torch.cat((img_k, txt_k), dim=1)
-
-        # attn = attention(q, k, v)
-        # output = self.linear2(torch.cat((attn, self.act_mlp(mlp)), 2))
-        # output = hidden_states + output * gate.unsqueeze(1)
-
         attn_output, context_attn_output = self.attn(
             hidden_states=norm_hidden_states,
             encoder_hidden_states=norm_encoder_hidden_states,
@@ -607,7 +589,7 @@ def __init__(
         self,
         hidden_size: int,
         heads_num: int,
-        mlp_width_ratio: float,
+        mlp_ratio: float,
         qk_norm: str = "rms_norm",
     ):
         super().__init__()
@@ -629,10 +611,10 @@ def __init__(
         self.txt_attn_proj = nn.Linear(hidden_size, hidden_size)
 
         self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
-        self.ff = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="gelu-approximate")
+        self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
 
         self.norm2_context = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
-        self.ff_context = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="gelu-approximate")
+        self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
 
     def forward(
         self,
@@ -690,70 +672,23 @@ def forward(
 
 
 class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin):
-    """
-    HunyuanVideo Transformer backbone
-
-    Inherited from ModelMixin and ConfigMixin for compatibility with diffusers' sampler StableDiffusionPipeline.
-
-    Reference: [1] Flux.1: https://github.com/black-forest-labs/flux [2] MMDiT: http://arxiv.org/abs/2403.03206
-
-    Parameters ---------- args: argparse.Namespace
-        The arguments parsed by argparse.
-    patch_size: list
-        The size of the patch.
-    in_channels: int
-        The number of input channels.
-    out_channels: int
-        The number of output channels.
-    hidden_size: int
-        The hidden size of the transformer backbone.
-    heads_num: int
-        The number of attention heads.
-    mlp_width_ratio: float
-        The ratio of the hidden size of the MLP in the transformer block.
-    mlp_act_type: str
-        The activation function of the MLP in the transformer block.
-    depth_double_blocks: int
-        The number of transformer blocks in the double blocks.
-    depth_single_blocks: int
-        The number of transformer blocks in the single blocks.
-    rope_dim_list: list
-        The dimension of the rotary embedding for t, h, w.
-    qkv_bias: bool
-        Whether to use bias in the qkv linear layer.
-    qk_norm: bool
-        Whether to use qk norm.
-    qk_norm_type: str
-        The type of qk norm.
-    guidance_embed: bool
-        Whether to use guidance embedding for distillation.
-    text_projection: str
-        The type of the text projection, default is single_refiner.
-    use_attention_mask: bool
-        Whether to use attention mask for text encoder.
-    dtype: torch.dtype
-        The dtype of the model.
-    device: torch.device
-        The device of the model.
-    """
-
     @register_to_config
     def __init__(
         self,
-        patch_size: int = 2,
-        patch_size_t: int = 1,
         in_channels: int = 16,
         out_channels: int = 16,
         num_attention_heads: int = 24,
         attention_head_dim: int = 128,
-        mlp_width_ratio: float = 4.0,
-        mm_double_blocks_depth: int = 20,
-        mm_single_blocks_depth: int = 40,
+        num_layers: int = 20,
+        num_single_layers: int = 40,
+        mlp_ratio: float = 4.0,
+        patch_size: int = 2,
+        patch_size_t: int = 1,
         rope_dim_list: List[int] = [16, 56, 56],
         qk_norm: str = "rms_norm",
         guidance_embed: bool = True,
-        text_states_dim: int = 4096,
-        text_states_dim_2: int = 768,
+        text_embed_dim: int = 4096,
+        text_embed_dim_2: int = 768,
     ) -> None:
         super().__init__()
 
@@ -762,51 +697,106 @@ def __init__(
         self.guidance_embed = guidance_embed
         self.rope_dim_list = rope_dim_list
 
-        if sum(rope_dim_list) != attention_head_dim:
-            raise ValueError(f"Got {rope_dim_list} but expected positional dim {attention_head_dim}")
-
         # image projection
         self.img_in = PatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim)
 
         # text projection
-        self.txt_in = SingleTokenRefiner(text_states_dim, inner_dim, num_attention_heads, depth=2)
+        self.txt_in = SingleTokenRefiner(text_embed_dim, inner_dim, num_attention_heads, depth=2)
 
         # time modulation
         self.time_in = TimestepEmbedder(inner_dim, nn.SiLU)
 
         # text modulation
-        self.vector_in = MLPEmbedder(text_states_dim_2, inner_dim)
+        self.vector_in = MLPEmbedder(text_embed_dim_2, inner_dim)
 
         # guidance modulation
        self.guidance_in = TimestepEmbedder(inner_dim, nn.SiLU)
 
         self.transformer_blocks = nn.ModuleList(
             [
-                HunyuanVideoTransformerBlock(
-                    inner_dim,
-                    num_attention_heads,
-                    mlp_width_ratio=mlp_width_ratio,
-                    qk_norm=qk_norm,
-                )
-                for _ in range(mm_double_blocks_depth)
+                HunyuanVideoTransformerBlock(inner_dim, num_attention_heads, mlp_ratio=mlp_ratio, qk_norm=qk_norm)
+                for _ in range(num_layers)
             ]
         )
 
         self.single_transformer_blocks = nn.ModuleList(
             [
                 HunyuanVideoSingleTransformerBlock(
-                    num_attention_heads,
-                    attention_head_dim,
-                    mlp_width_ratio=mlp_width_ratio,
-                    qk_norm=qk_norm,
+                    num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
                 )
-                for _ in range(mm_single_blocks_depth)
+                for _ in range(num_single_layers)
             ]
         )
 
         self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6)
         self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels)
 
+        self.gradient_checkpointing = False
+
+    @property
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+    def attn_processors(self) -> Dict[str, AttentionProcessor]:
+        r"""
+        Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model with
+            indexed by its weight name.
+        """
+        # set recursively
+        processors = {}
+
+        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+            if hasattr(module, "get_processor"):
+                processors[f"{name}.processor"] = module.get_processor()
+
+            for sub_name, child in module.named_children():
+                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+            return processors
+
+        for name, module in self.named_children():
+            fn_recursive_add_processors(name, module, processors)
+
+        return processors
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+        r"""
+        Sets the attention processor to use to compute attention.
+
+        Parameters:
+            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+                The instantiated processor class or a dictionary of processor classes that will be set as the processor
+                for **all** `Attention` layers.
+
+                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+                processor. This is strongly recommended when setting trainable attention processors.
+
+        """
+        count = len(self.attn_processors.keys())
+
+        if isinstance(processor, dict) and len(processor) != count:
+            raise ValueError(
+                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+            )
+
+        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+            if hasattr(module, "set_processor"):
+                if not isinstance(processor, dict):
+                    module.set_processor(processor)
+                else:
+                    module.set_processor(processor.pop(f"{name}.processor"))
+
+            for sub_name, child in module.named_children():
+                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+        for name, module in self.named_children():
+            fn_recursive_attn_processor(name, module, processor)
+
+    def _set_gradient_checkpointing(self, module, value=False):
+        if hasattr(module, "gradient_checkpointing"):
+            module.gradient_checkpointing = value
+
     def forward(
         self,
         hidden_states: torch.Tensor,
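
The hunk above wires in the standard diffusers attention-processor plumbing, copied verbatim from UNet2DConditionModel: attn_processors walks named_children() and collects every submodule exposing get_processor() under a dotted "{name}.processor" key, while set_attn_processor pushes either a single processor instance or a key-matched dict back down the same tree. A minimal round-trip sketch, assuming `transformer` is an already-constructed HunyuanVideoTransformer3DModel (see the construction example after the last hunk), could look like this:

# Sketch only: `transformer` is assumed to be an already-built
# HunyuanVideoTransformer3DModel instance.
procs = transformer.attn_processors
# Keys are dotted module paths ending in ".processor", one entry per module
# that exposes get_processor().
print(len(procs), list(procs.keys())[:1])

# A single processor instance would be broadcast to every layer; a dict must
# carry exactly one entry per attention layer, so restoring the snapshot taken
# above round-trips cleanly.
transformer.set_attn_processor(procs)

Passing a dict with the wrong number of entries raises the ValueError shown in the hunk, which is why snapshot-and-restore is the safest bulk operation.
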
@@ -843,29 +833,23 @@ def forward(
         hidden_states = self.img_in(hidden_states)
         encoder_hidden_states = self.txt_in(encoder_hidden_states, timestep, encoder_attention_mask)
 
-        txt_seq_len = encoder_hidden_states.shape[1]
+        use_reentrant = is_torch_version(">=", "1.11.0")
+        block_forward = (
+            partial(torch.utils.checkpoint.checkpoint, use_reentrant=use_reentrant)
+            if torch.is_grad_enabled() and self.gradient_checkpointing
+            else lambda block, *args: block(*args)
+        )
 
         freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None
         for _, block in enumerate(self.transformer_blocks):
-            double_block_args = [
-                hidden_states,
-                encoder_hidden_states,
-                temb,
-                freqs_cis,
-            ]
-
-            hidden_states, encoder_hidden_states = block(*double_block_args)
+            hidden_states, encoder_hidden_states = block_forward(
+                block, hidden_states, encoder_hidden_states, temb, freqs_cis
+            )
 
         for block in self.single_transformer_blocks:
-            single_block_args = [
-                hidden_states,
-                encoder_hidden_states,
-                temb,
-                txt_seq_len,
-                (freqs_cos, freqs_sin),
-            ]
-
-            hidden_states, encoder_hidden_states = block(*single_block_args)
+            hidden_states, encoder_hidden_states = block_forward(
+                block, hidden_states, encoder_hidden_states, temb, freqs_cis
+            )
 
         hidden_states = self.norm_out(hidden_states, temb)
         hidden_states = self.proj_out(hidden_states)
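
Taken together, the constructor refactor renames mlp_width_ratio to mlp_ratio, mm_double_blocks_depth/mm_single_blocks_depth to num_layers/num_single_layers, and text_states_dim/text_states_dim_2 to text_embed_dim/text_embed_dim_2, while the new forward routes both block stacks through torch.utils.checkpoint whenever self.gradient_checkpointing is set. A minimal construction sketch follows; the import path, the tiny layer counts, and the rope_dim_list override are illustrative assumptions rather than values taken from this commit.

# Assumed import path: the relative imports in the diff only show that the file
# lives somewhere under diffusers/models/, so this module name is a guess.
from diffusers.models.transformers.transformer_hunyuan_video import HunyuanVideoTransformer3DModel

# Tiny sizes for a smoke test (the real defaults are 24 heads of dim 128 with
# 20 double and 40 single blocks). rope_dim_list is picked to sum to
# attention_head_dim, the invariant the removed ValueError used to enforce.
transformer = HunyuanVideoTransformer3DModel(
    in_channels=16,
    out_channels=16,
    num_attention_heads=2,
    attention_head_dim=8,
    num_layers=1,
    num_single_layers=1,
    mlp_ratio=4.0,
    rope_dim_list=[2, 3, 3],
    text_embed_dim=4096,
    text_embed_dim_2=768,
)

# forward() wraps every double and single block in torch.utils.checkpoint once this
# flag is on; ModelMixin.enable_gradient_checkpointing() is the usual way to flip it.
transformer.gradient_checkpointing = True

Keeping the checkpoint wrapper behind a single block_forward callable means the eager and checkpointed paths share one call site per loop instead of duplicating the argument lists.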