diff --git a/docs/source/en/api/pipelines/hunyuan_video.md b/docs/source/en/api/pipelines/hunyuan_video.md
index f8039902976e..5d068c8b6ef8 100644
--- a/docs/source/en/api/pipelines/hunyuan_video.md
+++ b/docs/source/en/api/pipelines/hunyuan_video.md
@@ -50,7 +50,8 @@ The following models are available for the image-to-video pipeline:
 | Model name | Description |
 |:---|:---|
 | [`Skywork/SkyReels-V1-Hunyuan-I2V`](https://huggingface.co/Skywork/SkyReels-V1-Hunyuan-I2V) | Skywork's custom finetune of HunyuanVideo (de-distilled). Performs best with `97x544x960` resolution. Performs best at `97x544x960` resolution, `guidance_scale=1.0`, `true_cfg_scale=6.0` and a negative prompt. |
-| [`hunyuanvideo-community/HunyuanVideo-I2V`](https://huggingface.co/hunyuanvideo-community/HunyuanVideo-I2V) | Tecent's official HunyuanVideo I2V model. Performs best at resolutions of 480, 720, 960, 1280. A higher `shift` value when initializing the scheduler is recommended (good values are between 7 and 20) |
+| [`hunyuanvideo-community/HunyuanVideo-I2V-33ch`](https://huggingface.co/hunyuanvideo-community/HunyuanVideo-I2V) | Tencent's official HunyuanVideo 33-channel I2V model. Performs best at resolutions of 480, 720, 960, 1280. A higher `shift` value when initializing the scheduler is recommended (good values are between 7 and 20). |
+| [`hunyuanvideo-community/HunyuanVideo-I2V`](https://huggingface.co/hunyuanvideo-community/HunyuanVideo-I2V) | Tencent's official HunyuanVideo 16-channel I2V model. Performs best at resolutions of 480, 720, 960, 1280. A higher `shift` value when initializing the scheduler is recommended (good values are between 7 and 20). |

 ## Quantization

diff --git a/scripts/convert_hunyuan_video_to_diffusers.py b/scripts/convert_hunyuan_video_to_diffusers.py
index ca6ec152f66f..c84809d7f68a 100644
--- a/scripts/convert_hunyuan_video_to_diffusers.py
+++ b/scripts/convert_hunyuan_video_to_diffusers.py
@@ -160,8 +160,9 @@ def remap_single_transformer_blocks_(key, state_dict):
         "pooled_projection_dim": 768,
         "rope_theta": 256.0,
         "rope_axes_dim": (16, 56, 56),
+        "image_condition_type": None,
     },
-    "HYVideo-T/2-I2V": {
+    "HYVideo-T/2-I2V-33ch": {
         "in_channels": 16 * 2 + 1,
         "out_channels": 16,
         "num_attention_heads": 24,
@@ -178,6 +179,26 @@ def remap_single_transformer_blocks_(key, state_dict):
         "pooled_projection_dim": 768,
         "rope_theta": 256.0,
         "rope_axes_dim": (16, 56, 56),
+        "image_condition_type": "latent_concat",
+    },
+    "HYVideo-T/2-I2V-16ch": {
+        "in_channels": 16,
+        "out_channels": 16,
+        "num_attention_heads": 24,
+        "attention_head_dim": 128,
+        "num_layers": 20,
+        "num_single_layers": 40,
+        "num_refiner_layers": 2,
+        "mlp_ratio": 4.0,
+        "patch_size": 2,
+        "patch_size_t": 1,
+        "qk_norm": "rms_norm",
+        "guidance_embeds": True,
+        "text_embed_dim": 4096,
+        "pooled_projection_dim": 768,
+        "rope_theta": 256.0,
+        "rope_axes_dim": (16, 56, 56),
+        "image_condition_type": "token_replace",
     },
 }

diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py
index bb0cef057992..36f914f0b5c1 100644
--- a/src/diffusers/models/transformers/transformer_hunyuan_video.py
+++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py
@@ -27,13 +27,15 @@
 from ..attention_processor import Attention, AttentionProcessor
 from ..cache_utils import CacheMixin
 from ..embeddings import (
-    CombinedTimestepGuidanceTextProjEmbeddings,
     CombinedTimestepTextProjEmbeddings,
+    PixArtAlphaTextProjection,
+    TimestepEmbedding,
+    Timesteps,
get_1d_rotary_pos_embed, ) from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin -from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle +from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle, FP32LayerNorm logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -173,6 +175,141 @@ def forward( return gate_msa, gate_mlp +class HunyuanVideoTokenReplaceAdaLayerNormZero(nn.Module): + def __init__(self, embedding_dim: int, norm_type: str = "layer_norm", bias: bool = True): + super().__init__() + + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias) + + if norm_type == "layer_norm": + self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6) + elif norm_type == "fp32_layer_norm": + self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=False, bias=False) + else: + raise ValueError( + f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'." + ) + + def forward( + self, + hidden_states: torch.Tensor, + emb: torch.Tensor, + token_replace_emb: torch.Tensor, + first_frame_num_tokens: int, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + emb = self.linear(self.silu(emb)) + token_replace_emb = self.linear(self.silu(token_replace_emb)) + + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1) + tr_shift_msa, tr_scale_msa, tr_gate_msa, tr_shift_mlp, tr_scale_mlp, tr_gate_mlp = token_replace_emb.chunk( + 6, dim=1 + ) + + norm_hidden_states = self.norm(hidden_states) + hidden_states_zero = ( + norm_hidden_states[:, :first_frame_num_tokens] * (1 + tr_scale_msa[:, None]) + tr_shift_msa[:, None] + ) + hidden_states_orig = ( + norm_hidden_states[:, first_frame_num_tokens:] * (1 + scale_msa[:, None]) + shift_msa[:, None] + ) + hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1) + + return ( + hidden_states, + gate_msa, + shift_mlp, + scale_mlp, + gate_mlp, + tr_gate_msa, + tr_shift_mlp, + tr_scale_mlp, + tr_gate_mlp, + ) + + +class HunyuanVideoTokenReplaceAdaLayerNormZeroSingle(nn.Module): + def __init__(self, embedding_dim: int, norm_type: str = "layer_norm", bias: bool = True): + super().__init__() + + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias) + + if norm_type == "layer_norm": + self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6) + else: + raise ValueError( + f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'." 
+ ) + + def forward( + self, + hidden_states: torch.Tensor, + emb: torch.Tensor, + token_replace_emb: torch.Tensor, + first_frame_num_tokens: int, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + emb = self.linear(self.silu(emb)) + token_replace_emb = self.linear(self.silu(token_replace_emb)) + + shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1) + tr_shift_msa, tr_scale_msa, tr_gate_msa = token_replace_emb.chunk(3, dim=1) + + norm_hidden_states = self.norm(hidden_states) + hidden_states_zero = ( + norm_hidden_states[:, :first_frame_num_tokens] * (1 + tr_scale_msa[:, None]) + tr_shift_msa[:, None] + ) + hidden_states_orig = ( + norm_hidden_states[:, first_frame_num_tokens:] * (1 + scale_msa[:, None]) + shift_msa[:, None] + ) + hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1) + + return hidden_states, gate_msa, tr_gate_msa + + +class HunyuanVideoConditionEmbedding(nn.Module): + def __init__( + self, + embedding_dim: int, + pooled_projection_dim: int, + guidance_embeds: bool, + image_condition_type: Optional[str] = None, + ): + super().__init__() + + self.image_condition_type = image_condition_type + + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu") + + self.guidance_embedder = None + if guidance_embeds: + self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + + def forward( + self, timestep: torch.Tensor, pooled_projection: torch.Tensor, guidance: Optional[torch.Tensor] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D) + pooled_projections = self.text_embedder(pooled_projection) + conditioning = timesteps_emb + pooled_projections + + token_replace_emb = None + if self.image_condition_type == "token_replace": + token_replace_timestep = torch.zeros_like(timestep) + token_replace_proj = self.time_proj(token_replace_timestep) + token_replace_emb = self.timestep_embedder(token_replace_proj.to(dtype=pooled_projection.dtype)) + token_replace_emb = token_replace_emb + pooled_projections + + if self.guidance_embedder is not None: + guidance_proj = self.time_proj(guidance) + guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype)) + conditioning = conditioning + guidance_emb + + return conditioning, token_replace_emb + + class HunyuanVideoIndividualTokenRefinerBlock(nn.Module): def __init__( self, @@ -390,6 +527,8 @@ def forward( temb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + *args, + **kwargs, ) -> torch.Tensor: text_seq_length = encoder_hidden_states.shape[1] hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1) @@ -468,6 +607,8 @@ def forward( temb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + *args, + **kwargs, ) -> Tuple[torch.Tensor, torch.Tensor]: # 1. 
Input normalization norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) @@ -503,6 +644,181 @@ def forward( return hidden_states, encoder_hidden_states +class HunyuanVideoTokenReplaceSingleTransformerBlock(nn.Module): + def __init__( + self, + num_attention_heads: int, + attention_head_dim: int, + mlp_ratio: float = 4.0, + qk_norm: str = "rms_norm", + ) -> None: + super().__init__() + + hidden_size = num_attention_heads * attention_head_dim + mlp_dim = int(hidden_size * mlp_ratio) + + self.attn = Attention( + query_dim=hidden_size, + cross_attention_dim=None, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=hidden_size, + bias=True, + processor=HunyuanVideoAttnProcessor2_0(), + qk_norm=qk_norm, + eps=1e-6, + pre_only=True, + ) + + self.norm = HunyuanVideoTokenReplaceAdaLayerNormZeroSingle(hidden_size, norm_type="layer_norm") + self.proj_mlp = nn.Linear(hidden_size, mlp_dim) + self.act_mlp = nn.GELU(approximate="tanh") + self.proj_out = nn.Linear(hidden_size + mlp_dim, hidden_size) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + token_replace_emb: torch.Tensor = None, + num_tokens: int = None, + ) -> torch.Tensor: + text_seq_length = encoder_hidden_states.shape[1] + hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1) + + residual = hidden_states + + # 1. Input normalization + norm_hidden_states, gate, tr_gate = self.norm(hidden_states, temb, token_replace_emb, num_tokens) + mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) + + norm_hidden_states, norm_encoder_hidden_states = ( + norm_hidden_states[:, :-text_seq_length, :], + norm_hidden_states[:, -text_seq_length:, :], + ) + + # 2. Attention + attn_output, context_attn_output = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + attention_mask=attention_mask, + image_rotary_emb=image_rotary_emb, + ) + attn_output = torch.cat([attn_output, context_attn_output], dim=1) + + # 3. 
Modulation and residual connection + hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2) + + proj_output = self.proj_out(hidden_states) + hidden_states_zero = proj_output[:, :num_tokens] * tr_gate.unsqueeze(1) + hidden_states_orig = proj_output[:, num_tokens:] * gate.unsqueeze(1) + hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1) + hidden_states = hidden_states + residual + + hidden_states, encoder_hidden_states = ( + hidden_states[:, :-text_seq_length, :], + hidden_states[:, -text_seq_length:, :], + ) + return hidden_states, encoder_hidden_states + + +class HunyuanVideoTokenReplaceTransformerBlock(nn.Module): + def __init__( + self, + num_attention_heads: int, + attention_head_dim: int, + mlp_ratio: float, + qk_norm: str = "rms_norm", + ) -> None: + super().__init__() + + hidden_size = num_attention_heads * attention_head_dim + + self.norm1 = HunyuanVideoTokenReplaceAdaLayerNormZero(hidden_size, norm_type="layer_norm") + self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm") + + self.attn = Attention( + query_dim=hidden_size, + cross_attention_dim=None, + added_kv_proj_dim=hidden_size, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=hidden_size, + context_pre_only=False, + bias=True, + processor=HunyuanVideoAttnProcessor2_0(), + qk_norm=qk_norm, + eps=1e-6, + ) + + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate") + + self.norm2_context = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate") + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + token_replace_emb: torch.Tensor = None, + num_tokens: int = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # 1. Input normalization + ( + norm_hidden_states, + gate_msa, + shift_mlp, + scale_mlp, + gate_mlp, + tr_gate_msa, + tr_shift_mlp, + tr_scale_mlp, + tr_gate_mlp, + ) = self.norm1(hidden_states, temb, token_replace_emb, num_tokens) + norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( + encoder_hidden_states, emb=temb + ) + + # 2. Joint attention + attn_output, context_attn_output = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + attention_mask=attention_mask, + image_rotary_emb=freqs_cis, + ) + + # 3. 
Modulation and residual connection + hidden_states_zero = hidden_states[:, :num_tokens] + attn_output[:, :num_tokens] * tr_gate_msa.unsqueeze(1) + hidden_states_orig = hidden_states[:, num_tokens:] + attn_output[:, num_tokens:] * gate_msa.unsqueeze(1) + hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1) + encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa.unsqueeze(1) + + norm_hidden_states = self.norm2(hidden_states) + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + + hidden_states_zero = norm_hidden_states[:, :num_tokens] * (1 + tr_scale_mlp[:, None]) + tr_shift_mlp[:, None] + hidden_states_orig = norm_hidden_states[:, num_tokens:] * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + norm_hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1) + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + + # 4. Feed-forward + ff_output = self.ff(norm_hidden_states) + context_ff_output = self.ff_context(norm_encoder_hidden_states) + + hidden_states_zero = hidden_states[:, :num_tokens] + ff_output[:, :num_tokens] * tr_gate_mlp.unsqueeze(1) + hidden_states_orig = hidden_states[:, num_tokens:] + ff_output[:, num_tokens:] * gate_mlp.unsqueeze(1) + hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1) + encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output + + return hidden_states, encoder_hidden_states + + class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin): r""" A Transformer model for video-like data used in [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo). @@ -540,6 +856,10 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, The value of theta to use in the RoPE layer. rope_axes_dim (`Tuple[int]`, defaults to `(16, 56, 56)`): The dimensions of the axes to use in the RoPE layer. + image_condition_type (`str`, *optional*, defaults to `None`): + The type of image conditioning to use. If `None`, no image conditioning is used. If `latent_concat`, the + image is concatenated to the latent stream. If `token_replace`, the image is used to replace first-frame + tokens in the latent stream and apply conditioning. """ _supports_gradient_checkpointing = True @@ -570,9 +890,16 @@ def __init__( pooled_projection_dim: int = 768, rope_theta: float = 256.0, rope_axes_dim: Tuple[int] = (16, 56, 56), + image_condition_type: Optional[str] = None, ) -> None: super().__init__() + supported_image_condition_types = ["latent_concat", "token_replace"] + if image_condition_type is not None and image_condition_type not in supported_image_condition_types: + raise ValueError( + f"Invalid `image_condition_type` ({image_condition_type}). Supported ones are: {supported_image_condition_types}" + ) + inner_dim = num_attention_heads * attention_head_dim out_channels = out_channels or in_channels @@ -582,33 +909,52 @@ def __init__( text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers ) - if guidance_embeds: - self.time_text_embed = CombinedTimestepGuidanceTextProjEmbeddings(inner_dim, pooled_projection_dim) - else: - self.time_text_embed = CombinedTimestepTextProjEmbeddings(inner_dim, pooled_projection_dim) + self.time_text_embed = HunyuanVideoConditionEmbedding( + inner_dim, pooled_projection_dim, guidance_embeds, image_condition_type + ) # 2. 
RoPE self.rope = HunyuanVideoRotaryPosEmbed(patch_size, patch_size_t, rope_axes_dim, rope_theta) # 3. Dual stream transformer blocks - self.transformer_blocks = nn.ModuleList( - [ - HunyuanVideoTransformerBlock( - num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm - ) - for _ in range(num_layers) - ] - ) + if image_condition_type == "token_replace": + self.transformer_blocks = nn.ModuleList( + [ + HunyuanVideoTokenReplaceTransformerBlock( + num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm + ) + for _ in range(num_layers) + ] + ) + else: + self.transformer_blocks = nn.ModuleList( + [ + HunyuanVideoTransformerBlock( + num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm + ) + for _ in range(num_layers) + ] + ) # 4. Single stream transformer blocks - self.single_transformer_blocks = nn.ModuleList( - [ - HunyuanVideoSingleTransformerBlock( - num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm - ) - for _ in range(num_single_layers) - ] - ) + if image_condition_type == "token_replace": + self.single_transformer_blocks = nn.ModuleList( + [ + HunyuanVideoTokenReplaceSingleTransformerBlock( + num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm + ) + for _ in range(num_single_layers) + ] + ) + else: + self.single_transformer_blocks = nn.ModuleList( + [ + HunyuanVideoSingleTransformerBlock( + num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm + ) + for _ in range(num_single_layers) + ] + ) # 5. Output projection self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6) @@ -707,15 +1053,13 @@ def forward( post_patch_num_frames = num_frames // p_t post_patch_height = height // p post_patch_width = width // p + first_frame_num_tokens = 1 * post_patch_height * post_patch_width # 1. RoPE image_rotary_emb = self.rope(hidden_states) # 2. Conditional embeddings - if self.config.guidance_embeds: - temb = self.time_text_embed(timestep, guidance, pooled_projections) - else: - temb = self.time_text_embed(timestep, pooled_projections) + temb, token_replace_emb = self.time_text_embed(timestep, pooled_projections, guidance) hidden_states = self.x_embedder(hidden_states) encoder_hidden_states = self.context_embedder(encoder_hidden_states, timestep, encoder_attention_mask) @@ -746,6 +1090,8 @@ def forward( temb, attention_mask, image_rotary_emb, + token_replace_emb, + first_frame_num_tokens, ) for block in self.single_transformer_blocks: @@ -756,17 +1102,31 @@ def forward( temb, attention_mask, image_rotary_emb, + token_replace_emb, + first_frame_num_tokens, ) else: for block in self.transformer_blocks: hidden_states, encoder_hidden_states = block( - hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb + hidden_states, + encoder_hidden_states, + temb, + attention_mask, + image_rotary_emb, + token_replace_emb, + first_frame_num_tokens, ) for block in self.single_transformer_blocks: hidden_states, encoder_hidden_states = block( - hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb + hidden_states, + encoder_hidden_states, + temb, + attention_mask, + image_rotary_emb, + token_replace_emb, + first_frame_num_tokens, ) # 5. 
Output projection diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py index 5a600dda4326..774b72e6c7c1 100644 --- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py @@ -54,6 +54,7 @@ >>> from diffusers import HunyuanVideoImageToVideoPipeline, HunyuanVideoTransformer3DModel >>> from diffusers.utils import load_image, export_to_video + >>> # Available checkpoints: hunyuanvideo-community/HunyuanVideo-I2V, hunyuanvideo-community/HunyuanVideo-I2V-33ch >>> model_id = "hunyuanvideo-community/HunyuanVideo-I2V" >>> transformer = HunyuanVideoTransformer3DModel.from_pretrained( ... model_id, subfolder="transformer", torch_dtype=torch.bfloat16 @@ -69,7 +70,12 @@ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/guitar-man.png" ... ) - >>> output = pipe(image=image, prompt=prompt).frames[0] + >>> # If using hunyuanvideo-community/HunyuanVideo-I2V + >>> output = pipe(image=image, prompt=prompt, guidance_scale=6.0).frames[0] + + >>> # If using hunyuanvideo-community/HunyuanVideo-I2V-33ch + >>> output = pipe(image=image, prompt=prompt, guidance_scale=1.0, true_cfg_scale=1.0).frames[0] + >>> export_to_video(output, "output.mp4", fps=15) ``` """ @@ -399,7 +405,8 @@ def encode_prompt( device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, max_sequence_length: int = 256, - ): + image_embed_interleave: int = 2, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: if prompt_embeds is None: prompt_embeds, prompt_attention_mask = self._get_llama_prompt_embeds( image, @@ -409,6 +416,7 @@ def encode_prompt( device=device, dtype=dtype, max_sequence_length=max_sequence_length, + image_embed_interleave=image_embed_interleave, ) if pooled_prompt_embeds is None: @@ -433,6 +441,8 @@ def check_inputs( prompt_embeds=None, callback_on_step_end_tensor_inputs=None, prompt_template=None, + true_cfg_scale=1.0, + guidance_scale=1.0, ): if height % 16 != 0 or width % 16 != 0: raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") @@ -471,6 +481,13 @@ def check_inputs( f"`prompt_template` has to contain a key `template` but only found {prompt_template.keys()}" ) + if true_cfg_scale > 1.0 and guidance_scale > 1.0: + logger.warning( + "Both `true_cfg_scale` and `guidance_scale` are greater than 1.0. This will result in both " + "classifier-free guidance and embedded-guidance to be applied. This is not recommended " + "as it may lead to higher memory usage, slower inference and potentially worse results." 
+ ) + def prepare_latents( self, image: torch.Tensor, @@ -483,6 +500,7 @@ def prepare_latents( device: Optional[torch.device] = None, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, + image_condition_type: str = "latent_concat", ) -> torch.Tensor: if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( @@ -497,10 +515,11 @@ def prepare_latents( image = image.unsqueeze(2) # [B, C, 1, H, W] if isinstance(generator, list): image_latents = [ - retrieve_latents(self.vae.encode(image[i].unsqueeze(0)), generator[i]) for i in range(batch_size) + retrieve_latents(self.vae.encode(image[i].unsqueeze(0)), generator[i], "argmax") + for i in range(batch_size) ] else: - image_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator) for img in image] + image_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator, "argmax") for img in image] image_latents = torch.cat(image_latents, dim=0).to(dtype) * self.vae_scaling_factor image_latents = image_latents.repeat(1, 1, num_latent_frames, 1, 1) @@ -513,6 +532,9 @@ def prepare_latents( t = torch.tensor([0.999]).to(device=device) latents = latents * t + image_latents * (1 - t) + if image_condition_type == "token_replace": + image_latents = image_latents[:, :, :1] + return latents, image_latents def enable_vae_slicing(self): @@ -598,6 +620,7 @@ def __call__( callback_on_step_end_tensor_inputs: List[str] = ["latents"], prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE, max_sequence_length: int = 256, + image_embed_interleave: Optional[int] = None, ): r""" The call function to the pipeline for generation. @@ -704,12 +727,22 @@ def __call__( prompt_embeds, callback_on_step_end_tensor_inputs, prompt_template, + true_cfg_scale, + guidance_scale, ) + image_condition_type = self.transformer.config.image_condition_type has_neg_prompt = negative_prompt is not None or ( negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None ) do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + image_embed_interleave = ( + image_embed_interleave + if image_embed_interleave is not None + else ( + 2 if image_condition_type == "latent_concat" else 4 if image_condition_type == "token_replace" else 1 + ) + ) self._guidance_scale = guidance_scale self._attention_kwargs = attention_kwargs @@ -729,7 +762,12 @@ def __call__( # 3. Prepare latent variables vae_dtype = self.vae.dtype image_tensor = self.video_processor.preprocess(image, height, width).to(device, vae_dtype) - num_channels_latents = (self.transformer.config.in_channels - 1) // 2 + + if image_condition_type == "latent_concat": + num_channels_latents = (self.transformer.config.in_channels - 1) // 2 + elif image_condition_type == "token_replace": + num_channels_latents = self.transformer.config.in_channels + latents, image_latents = self.prepare_latents( image_tensor, batch_size * num_videos_per_prompt, @@ -741,10 +779,12 @@ def __call__( device, generator, latents, + image_condition_type, ) - image_latents[:, :, 1:] = 0 - mask = image_latents.new_ones(image_latents.shape[0], 1, *image_latents.shape[2:]) - mask[:, :, 1:] = 0 + if image_condition_type == "latent_concat": + image_latents[:, :, 1:] = 0 + mask = image_latents.new_ones(image_latents.shape[0], 1, *image_latents.shape[2:]) + mask[:, :, 1:] = 0 # 4. 
Encode input prompt
         transformer_dtype = self.transformer.dtype
@@ -759,6 +799,7 @@ def __call__(
             prompt_attention_mask=prompt_attention_mask,
             device=device,
             max_sequence_length=max_sequence_length,
+            image_embed_interleave=image_embed_interleave,
         )
         prompt_embeds = prompt_embeds.to(transformer_dtype)
         prompt_attention_mask = prompt_attention_mask.to(transformer_dtype)
@@ -782,10 +823,17 @@ def __call__(
             negative_prompt_attention_mask = negative_prompt_attention_mask.to(transformer_dtype)
             negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(transformer_dtype)

-        # 4. Prepare timesteps
+        # 5. Prepare timesteps
         sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas
         timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)

+        # 6. Prepare guidance condition
+        guidance = None
+        if self.transformer.config.guidance_embeds:
+            guidance = (
+                torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0
+            )
+
         # 7. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         self._num_timesteps = len(timesteps)
@@ -796,16 +844,21 @@ def __call__(
                     continue

                 self._current_timestep = t
-                latent_model_input = torch.cat([latents, image_latents, mask], dim=1).to(transformer_dtype)

                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latents.shape[0]).to(latents.dtype)

+                if image_condition_type == "latent_concat":
+                    latent_model_input = torch.cat([latents, image_latents, mask], dim=1).to(transformer_dtype)
+                elif image_condition_type == "token_replace":
+                    latent_model_input = torch.cat([image_latents, latents[:, :, 1:]], dim=2).to(transformer_dtype)
+
                 noise_pred = self.transformer(
                     hidden_states=latent_model_input,
                     timestep=timestep,
                     encoder_hidden_states=prompt_embeds,
                     encoder_attention_mask=prompt_attention_mask,
                     pooled_projections=pooled_prompt_embeds,
+                    guidance=guidance,
                     attention_kwargs=attention_kwargs,
                     return_dict=False,
                 )[0]
@@ -817,13 +870,20 @@ def __call__(
                         encoder_hidden_states=negative_prompt_embeds,
                         encoder_attention_mask=negative_prompt_attention_mask,
                         pooled_projections=negative_pooled_prompt_embeds,
+                        guidance=guidance,
                         attention_kwargs=attention_kwargs,
                         return_dict=False,
                     )[0]
                     noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)

                 # compute the previous noisy sample x_t -> x_t-1
-                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+                if image_condition_type == "latent_concat":
+                    latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+                elif image_condition_type == "token_replace":
+                    latents = self.scheduler.step(
+                        noise_pred[:, :, 1:], t, latents[:, :, 1:], return_dict=False
+                    )[0]
+                    latents = torch.cat([image_latents, latents], dim=2)

                 if callback_on_step_end is not None:
                     callback_kwargs = {}
@@ -844,12 +904,16 @@ def __call__(
         self._current_timestep = None

         if not output_type == "latent":
-            latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
+            latents = latents.to(self.vae.dtype) / self.vae_scaling_factor
             video = self.vae.decode(latents, return_dict=False)[0]
-            video = video[:, :, 4:, :, :]
+            if image_condition_type == "latent_concat":
+                video = video[:, :, 4:, :, :]
             video = self.video_processor.postprocess_video(video, output_type=output_type)
         else:
-            video = latents[:, :, 1:, :, :]
+            if image_condition_type == "latent_concat":
+                video = latents[:, :, 1:, :, :]
+            else:
+                video = latents
+
         # 
Offload all models self.maybe_free_model_hooks() diff --git a/tests/models/transformers/test_models_transformer_hunyuan_video.py b/tests/models/transformers/test_models_transformer_hunyuan_video.py index 2b81dc876433..495131ad6fd8 100644 --- a/tests/models/transformers/test_models_transformer_hunyuan_video.py +++ b/tests/models/transformers/test_models_transformer_hunyuan_video.py @@ -80,6 +80,7 @@ def prepare_init_args_and_inputs_for_common(self): "text_embed_dim": 16, "pooled_projection_dim": 8, "rope_axes_dim": (2, 4, 4), + "image_condition_type": None, } inputs_dict = self.dummy_input return init_dict, inputs_dict @@ -144,6 +145,7 @@ def prepare_init_args_and_inputs_for_common(self): "text_embed_dim": 16, "pooled_projection_dim": 8, "rope_axes_dim": (2, 4, 4), + "image_condition_type": None, } inputs_dict = self.dummy_input return init_dict, inputs_dict @@ -209,6 +211,75 @@ def prepare_init_args_and_inputs_for_common(self): "text_embed_dim": 16, "pooled_projection_dim": 8, "rope_axes_dim": (2, 4, 4), + "image_condition_type": "latent_concat", + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_output(self): + super().test_output(expected_output_shape=(1, *self.output_shape)) + + def test_gradient_checkpointing_is_applied(self): + expected_set = {"HunyuanVideoTransformer3DModel"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + + +class HunyuanVideoTokenReplaceImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase): + model_class = HunyuanVideoTransformer3DModel + main_input_name = "hidden_states" + uses_custom_attn_processor = True + + @property + def dummy_input(self): + batch_size = 1 + num_channels = 2 + num_frames = 1 + height = 16 + width = 16 + text_encoder_embedding_dim = 16 + pooled_projection_dim = 8 + sequence_length = 12 + + hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device) + timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device) + encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device) + pooled_projections = torch.randn((batch_size, pooled_projection_dim)).to(torch_device) + encoder_attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device) + guidance = torch.randint(0, 1000, size=(batch_size,)).to(torch_device, dtype=torch.float32) + + return { + "hidden_states": hidden_states, + "timestep": timestep, + "encoder_hidden_states": encoder_hidden_states, + "pooled_projections": pooled_projections, + "encoder_attention_mask": encoder_attention_mask, + "guidance": guidance, + } + + @property + def input_shape(self): + return (8, 1, 16, 16) + + @property + def output_shape(self): + return (4, 1, 16, 16) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "in_channels": 2, + "out_channels": 4, + "num_attention_heads": 2, + "attention_head_dim": 10, + "num_layers": 1, + "num_single_layers": 1, + "num_refiner_layers": 1, + "patch_size": 1, + "patch_size_t": 1, + "guidance_embeds": True, + "text_embed_dim": 16, + "pooled_projection_dim": 8, + "rope_axes_dim": (2, 4, 4), + "image_condition_type": "token_replace", } inputs_dict = self.dummy_input return init_dict, inputs_dict diff --git a/tests/pipelines/hunyuan_video/test_hunyuan_image2video.py b/tests/pipelines/hunyuan_video/test_hunyuan_image2video.py index c18e5c0ad8fb..5802bde87a61 100644 --- a/tests/pipelines/hunyuan_video/test_hunyuan_image2video.py +++ 
b/tests/pipelines/hunyuan_video/test_hunyuan_image2video.py @@ -83,6 +83,7 @@ def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1): text_embed_dim=16, pooled_projection_dim=8, rope_axes_dim=(2, 4, 4), + image_condition_type="latent_concat", ) torch.manual_seed(0)
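For readers of this diff, here is a minimal end-to-end usage sketch that ties the pieces together: it loads one of the two I2V checkpoints from the updated docs table, raises the scheduler `shift` as the table recommends, and passes the guidance settings shown in the updated docstring. The use of `FlowMatchEulerDiscreteScheduler.from_config`, the concrete `shift` value, the dtypes, and the prompt text are illustrative assumptions, not part of this change.

```python
import torch
from diffusers import (
    FlowMatchEulerDiscreteScheduler,
    HunyuanVideoImageToVideoPipeline,
    HunyuanVideoTransformer3DModel,
)
from diffusers.utils import export_to_video, load_image

# 16-channel token_replace checkpoint; swap in "hunyuanvideo-community/HunyuanVideo-I2V-33ch"
# for the 33-channel latent_concat variant.
model_id = "hunyuanvideo-community/HunyuanVideo-I2V"

transformer = HunyuanVideoTransformer3DModel.from_pretrained(
    model_id, subfolder="transformer", torch_dtype=torch.bfloat16
)
pipe = HunyuanVideoImageToVideoPipeline.from_pretrained(
    model_id, transformer=transformer, torch_dtype=torch.float16
)

# The docs recommend a higher flow shift (roughly 7-20) for these checkpoints; 7.0 is an assumed example value.
pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(pipe.scheduler.config, shift=7.0)
pipe.to("cuda")

image = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/guitar-man.png"
)
prompt = "A man plays an electric guitar on stage"  # any text prompt

# 16-channel model: rely on embedded guidance.
output = pipe(image=image, prompt=prompt, guidance_scale=6.0).frames[0]
# 33-channel model instead: pipe(image=image, prompt=prompt, guidance_scale=1.0, true_cfg_scale=1.0),
# as in the docstring example above.
export_to_video(output, "output.mp4", fps=15)
```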
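The heart of the new `token_replace` blocks is that the clean first-frame tokens are modulated with a timestep-zero embedding while the remaining video tokens use the current timestep's embedding. Below is a standalone sketch of that split; it mirrors the arithmetic in `HunyuanVideoTokenReplaceAdaLayerNormZero` above, but the shapes and tensors are illustrative, not the library class.

```python
import torch

batch, dim = 2, 8
first_frame_num_tokens = 4   # post_patch_height * post_patch_width for a single frame
num_video_tokens = 16

norm_hidden_states = torch.randn(batch, num_video_tokens, dim)

# Modulation parameters derived from the current timestep embedding (emb) ...
shift_msa, scale_msa = torch.randn(batch, dim), torch.randn(batch, dim)
# ... and from the timestep-0 embedding used for the conditioning frame (token_replace_emb).
tr_shift_msa, tr_scale_msa = torch.randn(batch, dim), torch.randn(batch, dim)

# First-frame tokens get the t=0 modulation, the remaining tokens get the regular one.
hidden_states_zero = (
    norm_hidden_states[:, :first_frame_num_tokens] * (1 + tr_scale_msa[:, None]) + tr_shift_msa[:, None]
)
hidden_states_orig = (
    norm_hidden_states[:, first_frame_num_tokens:] * (1 + scale_msa[:, None]) + shift_msa[:, None]
)
hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
```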
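On the pipeline side, `token_replace` keeps the encoded first-frame latent clean: it is concatenated in front of the noisy latents for the model input, only frames `1:` are stepped, and the clean frame is re-attached afterwards. A schematic of that loop body, with dummy tensors standing in for the transformer and scheduler calls (shapes are illustrative only):

```python
import torch

# Illustrative latent shapes: [batch, channels, frames, height, width].
latents = torch.randn(1, 16, 9, 34, 60)        # noisy video latents
image_latents = torch.randn(1, 16, 1, 34, 60)  # clean first-frame latent from the VAE

# token_replace: feed the clean first frame plus the noisy remaining frames to the model.
latent_model_input = torch.cat([image_latents, latents[:, :, 1:]], dim=2)

# ... noise_pred = transformer(hidden_states=latent_model_input, ...) in the real pipeline ...
noise_pred = torch.randn_like(latent_model_input)

# Only frames 1: are denoised; the conditioning frame is re-attached afterwards.
stepped = latents[:, :, 1:] - 0.1 * noise_pred[:, :, 1:]  # stands in for scheduler.step(...)
latents = torch.cat([image_latents, stepped], dim=2)
```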