# See the License for the specific language governing permissions and
# limitations under the License.

-import math
-from functools import partial
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from ...utils import is_torch_version
from ..attention import FeedForward
from ..attention_processor import Attention, AttentionProcessor
-from ..embeddings import get_1d_rotary_pos_embed, get_timestep_embedding
+from ..embeddings import (
+    CombinedTimestepGuidanceTextProjEmbeddings,
+    CombinedTimestepTextProjEmbeddings,
+    get_1d_rotary_pos_embed,
+)
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
@@ -123,19 +125,6 @@ def __call__(
        return hidden_states, encoder_hidden_states


-class MLPEmbedder(nn.Module):
-    """copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py"""
-
-    def __init__(self, in_dim: int, hidden_dim: int):
-        super().__init__()
-        self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
-        self.silu = nn.SiLU()
-        self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return self.out_layer(self.silu(self.in_layer(x)))
-
-
class PatchEmbed(nn.Module):
    def __init__(
        self,
@@ -154,49 +143,21 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return hidden_states


-class TextProjection(nn.Module):
-    def __init__(self, in_channels, hidden_size, act_layer):
+class HunyuanVideoAdaNorm(nn.Module):
+    def __init__(self, in_features: int, out_features: Optional[int] = None) -> None:
        super().__init__()
-        self.linear_1 = nn.Linear(in_features=in_channels, out_features=hidden_size, bias=True)
-        self.act_1 = act_layer()
-        self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True)
-
-    def forward(self, caption):
-        hidden_states = self.linear_1(caption)
-        hidden_states = self.act_1(hidden_states)
-        hidden_states = self.linear_2(hidden_states)
-        return hidden_states
-

-class TimestepEmbedder(nn.Module):
-    """
-    Embeds scalar timesteps into vector representations.
-    """
-
-    def __init__(
-        self,
-        hidden_size,
-        act_layer,
-        frequency_embedding_size=256,
-        max_period=10000,
-        out_size=None,
-    ):
-        super().__init__()
-        self.frequency_embedding_size = frequency_embedding_size
-        self.max_period = max_period
-        if out_size is None:
-            out_size = hidden_size
-
-        self.mlp = nn.Sequential(
-            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
-            act_layer(),
-            nn.Linear(hidden_size, out_size, bias=True),
-        )
+        out_features = out_features or 2 * in_features
+        self.linear = nn.Linear(in_features, out_features)
+        self.nonlinearity = nn.SiLU()

-    def forward(self, t):
-        t_freq = get_timestep_embedding(t, self.frequency_embedding_size, flip_sin_to_cos=True, max_period=self.max_period, downscale_freq_shift=0).type(self.mlp[0].weight.dtype)
-        t_emb = self.mlp(t_freq)
-        return t_emb
+    def forward(
+        self, temb: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        temb = self.linear(self.nonlinearity(temb))
+        gate_msa, gate_mlp = temb.chunk(2, dim=1)
+        gate_msa, gate_mlp = gate_msa.unsqueeze(1), gate_mlp.unsqueeze(1)
+        return gate_msa, gate_mlp


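For reference, a minimal usage sketch of the new HunyuanVideoAdaNorm (illustrative only, not part of the diff): it projects a time embedding into per-branch gates that broadcast over the token dimension, exactly as the refiner block below consumes them.

# Hedged sketch; shapes are made up for illustration.
norm = HunyuanVideoAdaNorm(in_features=128)   # linear projects 128 -> 2 * 128
temb = torch.randn(2, 128)                    # (batch, in_features)
gate_msa, gate_mlp = norm(temb)               # each has shape (batch, 1, 128)
tokens = torch.randn(2, 16, 128)              # (batch, seq_len, hidden)
attn_out = torch.randn_like(tokens)           # stand-in for an attention output
tokens = tokens + attn_out * gate_msa         # gated residual, as in IndividualTokenRefinerBlock
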
class IndividualTokenRefinerBlock(nn.Module):
@@ -224,29 +185,27 @@ def __init__(
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
        self.mlp = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="silu", dropout=mlp_drop_rate)

-        self.adaLN_modulation = nn.Sequential(
-            nn.SiLU(),
-            nn.Linear(hidden_size, 2 * hidden_size, bias=True),
-        )
+        self.norm_out = HunyuanVideoAdaNorm(hidden_size, 2 * hidden_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
-        gate_msa, gate_mlp = self.adaLN_modulation(temb).chunk(2, dim=1)
-
        norm_hidden_states = self.norm1(hidden_states)

        attn_output = self.attn(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=None,
            attention_mask=attention_mask,
        )
-        hidden_states = hidden_states + attn_output * gate_msa.unsqueeze(1)

-        hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) * gate_mlp.unsqueeze(1)
+        gate_msa, gate_mlp = self.norm_out(temb)
+        hidden_states = hidden_states + attn_output * gate_msa
+
+        ff_output = self.mlp(self.norm2(hidden_states))
+        hidden_states = hidden_states + ff_output * gate_mlp

        return hidden_states

@@ -313,10 +272,10 @@ def __init__(

        hidden_size = num_attention_heads * attention_head_dim

-        self.input_embedder = nn.Linear(in_channels, hidden_size, bias=True)
-        self.time_embed = TimestepEmbedder(hidden_size, nn.SiLU)
-        self.context_embed = TextProjection(in_channels, hidden_size, nn.SiLU)
-
+        self.time_text_embed = CombinedTimestepTextProjEmbeddings(
+            embedding_dim=hidden_size, pooled_projection_dim=in_channels
+        )
+        self.proj_in = nn.Linear(in_channels, hidden_size, bias=True)
        self.token_refiner = IndividualTokenRefiner(
            num_attention_heads=num_attention_heads,
            attention_head_dim=attention_head_dim,
@@ -332,21 +291,17 @@ def forward(
        timestep: torch.LongTensor,
        attention_mask: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
-        original_dtype = hidden_states.dtype
-        temb = self.time_embed(timestep)
-
        if attention_mask is None:
            pooled_projections = hidden_states.mean(dim=1)
        else:
+            original_dtype = hidden_states.dtype
            mask_float = attention_mask.float().unsqueeze(-1)
            pooled_projections = (hidden_states * mask_float).sum(dim=1) / mask_float.sum(dim=1)
            pooled_projections = pooled_projections.to(original_dtype)

-        pooled_projections = self.context_embed(pooled_projections)
-        emb = temb + pooled_projections
-
-        hidden_states = self.input_embedder(hidden_states)
-        hidden_states = self.token_refiner(hidden_states, emb, attention_mask)
+        temb = self.time_text_embed(timestep, pooled_projections)
+        hidden_states = self.proj_in(hidden_states)
+        hidden_states = self.token_refiner(hidden_states, temb, attention_mask)

        return hidden_states

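Aside: the refactored forward above pools the prompt tokens with a masked mean before handing them to time_text_embed. A self-contained sketch of that pooling (illustrative shapes, not part of the diff):

# Hedged sketch of the masked mean used above.
hidden_states = torch.randn(2, 6, 16)               # (batch, seq_len, channels)
attention_mask = torch.tensor([[1, 1, 1, 0, 0, 0],
                               [1, 1, 1, 1, 1, 0]])
mask_float = attention_mask.float().unsqueeze(-1)   # (batch, seq_len, 1)
# only unmasked tokens contribute to the pooled prompt embedding
pooled = (hidden_states * mask_float).sum(dim=1) / mask_float.sum(dim=1)  # (batch, channels)
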
@@ -561,14 +516,7 @@ def __init__(
            text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers
        )

-        # time modulation
-        self.time_in = TimestepEmbedder(inner_dim, nn.SiLU)
-
-        # text modulation
-        self.vector_in = MLPEmbedder(text_embed_dim_2, inner_dim)
-
-        # guidance modulation
-        self.guidance_in = TimestepEmbedder(inner_dim, nn.SiLU)
+        self.time_text_embed = CombinedTimestepGuidanceTextProjEmbeddings(inner_dim, text_embed_dim_2)

        # 3. RoPE
        self.rope = HunyuanVideoRotaryPosEmbed(patch_size, patch_size_t, rope_dim_list, rope_theta)
@@ -679,30 +627,55 @@ def forward(

        image_rotary_emb = self.rope(hidden_states)

-        temb = self.time_in(timestep)
-        temb = temb + self.vector_in(encoder_hidden_states_2)
-        temb = temb + self.guidance_in(guidance)
+        temb = self.time_text_embed(timestep, guidance, encoder_hidden_states_2)

        # Embed image and text.
        hidden_states = self.img_in(hidden_states)
        encoder_hidden_states = self.txt_in(encoder_hidden_states, timestep, encoder_attention_mask)

-        use_reentrant = is_torch_version(">=", "1.11.0")
-        block_forward = (
-            partial(torch.utils.checkpoint.checkpoint, use_reentrant=use_reentrant)
-            if torch.is_grad_enabled() and self.gradient_checkpointing
-            else lambda x: x
-        )
+        if torch.is_grad_enabled() and self.gradient_checkpointing:

-        for _, block in enumerate(self.transformer_blocks):
-            hidden_states, encoder_hidden_states = block_forward(block)(
-                hidden_states, encoder_hidden_states, temb, image_rotary_emb
-            )
+            def create_custom_forward(module, return_dict=None):
+                def custom_forward(*inputs):
+                    if return_dict is not None:
+                        return module(*inputs, return_dict=return_dict)
+                    else:
+                        return module(*inputs)

-        for block in self.single_transformer_blocks:
-            hidden_states, encoder_hidden_states = block_forward(block)(
-                hidden_states, encoder_hidden_states, temb, image_rotary_emb
-            )
+                return custom_forward
+
+            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+
+            for block in self.transformer_blocks:
+                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    hidden_states,
+                    encoder_hidden_states,
+                    temb,
+                    image_rotary_emb,
+                    **ckpt_kwargs,
+                )
+
+            for block in self.single_transformer_blocks:
+                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    hidden_states,
+                    encoder_hidden_states,
+                    temb,
+                    image_rotary_emb,
+                    **ckpt_kwargs,
+                )
+
+        else:
+            for block in self.transformer_blocks:
+                hidden_states, encoder_hidden_states = block(
+                    hidden_states, encoder_hidden_states, temb, image_rotary_emb
+                )
+
+            for block in self.single_transformer_blocks:
+                hidden_states, encoder_hidden_states = block(
+                    hidden_states, encoder_hidden_states, temb, image_rotary_emb
+                )

        hidden_states = self.norm_out(hidden_states, temb)
        hidden_states = self.proj_out(hidden_states)
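For context, the new gradient-checkpointing branch wraps each block with create_custom_forward and calls the non-reentrant checkpoint API (use_reentrant=False, available from torch 1.11), replacing the deleted partial-based wrapper. A standalone sketch of that pattern with a toy module (illustrative only, not part of the diff):

# Hedged sketch of non-reentrant activation checkpointing.
import torch
import torch.nn as nn

block = nn.Linear(16, 16)

def create_custom_forward(module):
    def custom_forward(*inputs):
        return module(*inputs)
    return custom_forward

x = torch.randn(2, 16, requires_grad=True)
# activations inside the block are recomputed during backward instead of being stored
y = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, use_reentrant=False)
y.sum().backward()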