diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py b/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py index 58b811569403..349c0f797978 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py @@ -152,9 +152,19 @@ def __init__( # 1. Latent and condition embedders self.x_embedder = HunyuanVideoPatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim) + + # Framepack history projection embedder + self.clean_x_embedder = None + if has_clean_x_embedder: + self.clean_x_embedder = HunyuanVideoHistoryPatchEmbed(in_channels, inner_dim) + self.context_embedder = HunyuanVideoTokenRefiner( text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers ) + + # Framepack image-conditioning embedder + self.image_projection = FramepackClipVisionProjection(image_proj_dim, inner_dim) if has_image_proj else None + self.time_text_embed = HunyuanVideoConditionEmbedding( inner_dim, pooled_projection_dim, guidance_embeds, image_condition_type ) @@ -186,13 +196,6 @@ def __init__( self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6) self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels) - # Framepack specific modules - self.image_projection = FramepackClipVisionProjection(image_proj_dim, inner_dim) if has_image_proj else None - - self.clean_x_embedder = None - if has_clean_x_embedder: - self.clean_x_embedder = HunyuanVideoHistoryPatchEmbed(in_channels, inner_dim) - self.use_gradient_checkpointing = False def forward(