@@ -228,8 +228,8 @@ def forward(
228228 )
229229
230230
231- class HunyuanVideoTimestepTextProjEmbeddings (nn .Module ):
232- def __init__ (self , embedding_dim : int , pooled_projection_dim : int , image_condition_type : Optional [str ] = None ):
231+ class HunyuanVideoConditionEmbedding (nn .Module ):
232+ def __init__ (self , embedding_dim : int , pooled_projection_dim : int , guidance_embeds : bool , image_condition_type : Optional [str ] = None ):
233233 super ().__init__ ()
234234
235235 self .image_condition_type = image_condition_type
@@ -238,7 +238,11 @@ def __init__(self, embedding_dim: int, pooled_projection_dim: int, image_conditi
238238 self .timestep_embedder = TimestepEmbedding (in_channels = 256 , time_embed_dim = embedding_dim )
239239 self .text_embedder = PixArtAlphaTextProjection (pooled_projection_dim , embedding_dim , act_fn = "silu" )
240240
241- def forward (self , timestep : torch .Tensor , pooled_projection : torch .Tensor ) -> Tuple [torch .Tensor , torch .Tensor ]:
241+ self .guidance_embedder = None
242+ if guidance_embeds :
243+ self .guidance_embedder = TimestepEmbedding (in_channels = 256 , time_embed_dim = embedding_dim )
244+
245+ def forward (self , timestep : torch .Tensor , pooled_projection : torch .Tensor , guidance : Optional [torch .Tensor ] = None ) -> Tuple [torch .Tensor , torch .Tensor ]:
242246 timesteps_proj = self .time_proj (timestep )
243247 timesteps_emb = self .timestep_embedder (timesteps_proj .to (dtype = pooled_projection .dtype )) # (N, D)
244248 pooled_projections = self .text_embedder (pooled_projection )
@@ -248,8 +252,13 @@ def forward(self, timestep: torch.Tensor, pooled_projection: torch.Tensor) -> Tu
248252 if self .image_condition_type == "token_replace" :
249253 token_replace_timestep = torch .zeros_like (timestep )
250254 token_replace_proj = self .time_proj (token_replace_timestep )
251- token_replace_emb = self .timestep_embedder (token_replace_proj )
252- token_replace_emb = token_replace_emb + conditioning
255+ token_replace_emb = self .timestep_embedder (token_replace_proj .to (dtype = pooled_projection .dtype ))
256+ token_replace_emb = token_replace_emb + pooled_projections
257+
258+ if self .guidance_embedder is not None :
259+ guidance_proj = self .time_proj (guidance )
260+ guidance_emb = self .guidance_embedder (guidance_proj .to (dtype = pooled_projection .dtype ))
261+ conditioning = conditioning + guidance_emb
253262
254263 return conditioning , token_replace_emb
255264
@@ -665,7 +674,7 @@ def forward(
665674
666675 hidden_states_zero = norm_hidden_states [:, :num_tokens ] * (1 + tr_scale_mlp [:, None ]) + tr_shift_mlp [:, None ]
667676 hidden_states_orig = norm_hidden_states [:, num_tokens :] * (1 + scale_mlp [:, None ]) + shift_mlp [:, None ]
668- hidden_states = torch .cat ([hidden_states_zero , hidden_states_orig ], dim = 1 )
677+ norm_hidden_states = torch .cat ([hidden_states_zero , hidden_states_orig ], dim = 1 )
669678 norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp [:, None ]) + c_shift_mlp [:, None ]
670679
671680 # 4. Feed-forward
@@ -717,6 +726,10 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
717726 The value of theta to use in the RoPE layer.
718727 rope_axes_dim (`Tuple[int]`, defaults to `(16, 56, 56)`):
719728 The dimensions of the axes to use in the RoPE layer.
729+ image_condition_type (`str`, *optional*, defaults to `None`):
730+ The type of image conditioning to use. If `None`, no image conditioning is used. If `latent_concat`, the
731+ image is concatenated to the latent stream. If `token_replace`, the image replaces the first-frame tokens
732+ in the latent stream, and a second conditioning embedding, computed at timestep zero, is produced as well.
720733 """
721734
722735 _supports_gradient_checkpointing = True
@@ -761,12 +774,9 @@ def __init__(
761774 text_embed_dim , num_attention_heads , attention_head_dim , num_layers = num_refiner_layers
762775 )
763776
764- if guidance_embeds :
765- self .time_text_embed = CombinedTimestepGuidanceTextProjEmbeddings (inner_dim , pooled_projection_dim )
766- else :
767- self .time_text_embed = HunyuanVideoTimestepTextProjEmbeddings (
768- inner_dim , pooled_projection_dim , image_condition_type
769- )
777+ self .time_text_embed = HunyuanVideoConditionEmbedding (
778+ inner_dim , pooled_projection_dim , guidance_embeds , image_condition_type
779+ )
770780
771781 # 2. RoPE
772782 self .rope = HunyuanVideoRotaryPosEmbed (patch_size , patch_size_t , rope_axes_dim , rope_theta )
@@ -904,10 +914,7 @@ def forward(
904914 image_rotary_emb = self .rope (hidden_states )
905915
906916 # 2. Conditional embeddings
907- if self .config .guidance_embeds :
908- temb = self .time_text_embed (timestep , guidance , pooled_projections )
909- else :
910- temb , token_replace_emb = self .time_text_embed (timestep , pooled_projections )
917+ temb , token_replace_emb = self .time_text_embed (timestep , pooled_projections , guidance )
911918
912919 hidden_states = self .x_embedder (hidden_states )
913920 encoder_hidden_states = self .context_embedder (encoder_hidden_states , timestep , encoder_attention_mask )
0 commit comments