From 4d4be9acd3798e145847ceec59f373c00ff501aa Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 5 Dec 2024 20:53:14 +0100 Subject: [PATCH 01/58] copy transformer --- src/diffusers/__init__.py | 2 + src/diffusers/models/__init__.py | 2 + src/diffusers/models/transformers/__init__.py | 1 + .../transformers/transformer_hunyuan_video.py | 1470 +++++++++++++++++ src/diffusers/utils/dummy_pt_objects.py | 15 + 5 files changed, 1490 insertions(+) create mode 100644 src/diffusers/models/transformers/transformer_hunyuan_video.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index db46dc1d8801..caefc093ed64 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -99,6 +99,7 @@ "HunyuanDiT2DControlNetModel", "HunyuanDiT2DModel", "HunyuanDiT2DMultiControlNetModel", + "HunyuanVideoTransformer3DModel", "I2VGenXLUNet", "Kandinsky3UNet", "LatteTransformer3DModel", @@ -591,6 +592,7 @@ HunyuanDiT2DControlNetModel, HunyuanDiT2DModel, HunyuanDiT2DMultiControlNetModel, + HunyuanVideoTransformer3DModel, I2VGenXLUNet, Kandinsky3UNet, LatteTransformer3DModel, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 65e2418ac794..fdb09fd4ebb5 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -63,6 +63,7 @@ _import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"] _import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"] _import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"] + _import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"] _import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"] _import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"] _import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"] @@ -122,6 +123,7 @@ DualTransformer2DModel, FluxTransformer2DModel, HunyuanDiT2DModel, + HunyuanVideoTransformer3DModel, LatteTransformer3DModel, LuminaNextDiT2DModel, MochiTransformer3DModel, diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index a2c087d708a4..2bdade63d3d2 100644 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -17,6 +17,7 @@ from .transformer_allegro import AllegroTransformer3DModel from .transformer_cogview3plus import CogView3PlusTransformer2DModel from .transformer_flux import FluxTransformer2DModel + from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel from .transformer_mochi import MochiTransformer3DModel from .transformer_sd3 import SD3Transformer2DModel from .transformer_temporal import TransformerTemporalModel diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py new file mode 100644 index 000000000000..8c9c26f9cc03 --- /dev/null +++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py @@ -0,0 +1,1470 @@ +import collections.abc +import itertools +import math +from functools import partial +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models import ModelMixin + + +try: + from flash_attn.flash_attn_interface import flash_attn_varlen_func +except ImportError: + flash_attn_varlen_func = None + + +MEMORY_LAYOUT = { + "flash": ( + lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]), + lambda x: x, + ), + "torch": ( + lambda x: x.transpose(1, 2), + lambda x: x.transpose(1, 2), + ), + "vanilla": ( + lambda x: x.transpose(1, 2), + lambda x: x.transpose(1, 2), + ), +} + + +def get_cu_seqlens(text_mask, img_len): + """Calculate cu_seqlens_q, cu_seqlens_kv using text_mask and img_len + + Args: + text_mask (torch.Tensor): the mask of text + img_len (int): the length of image + + Returns: + torch.Tensor: the calculated cu_seqlens for flash attention + """ + batch_size = text_mask.shape[0] + text_len = text_mask.sum(dim=1) + max_len = text_mask.shape[1] + img_len + + cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device="cuda") + + for i in range(batch_size): + s = text_len[i] + img_len + s1 = i * max_len + s + s2 = (i + 1) * max_len + cu_seqlens[2 * i + 1] = s1 + cu_seqlens[2 * i + 2] = s2 + + return cu_seqlens + + +def attention( + q, + k, + v, + mode="torch", + drop_rate=0, + attn_mask=None, + causal=False, + cu_seqlens_q=None, + cu_seqlens_kv=None, + max_seqlen_q=None, + max_seqlen_kv=None, + batch_size=1, +): + """ + Perform QKV self attention. + + Args: + q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads. + k (torch.Tensor): Key tensor with shape [b, s1, a, d] + v (torch.Tensor): Value tensor with shape [b, s1, a, d] + mode (str): Attention mode. Choose from 'self_flash', 'cross_flash', 'torch', and 'vanilla'. + drop_rate (float): Dropout rate in attention map. (default: 0) + attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla). + (default: None) + causal (bool): Whether to use causal attention. (default: False) + cu_seqlens_q (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch, + used to index into q. + cu_seqlens_kv (torch.Tensor): + dtype torch.int32. The cumulative sequence lengths of the sequences in the batch, used to index into kv. + max_seqlen_q (int): The maximum sequence length in the batch of q. + max_seqlen_kv (int): The maximum sequence length in the batch of k and v. + + Returns: + torch.Tensor: Output tensor after self attention with shape [b, s, ad] + """ + pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode] + q = pre_attn_layout(q) + k = pre_attn_layout(k) + v = pre_attn_layout(v) + + if mode == "torch": + if attn_mask is not None and attn_mask.dtype != torch.bool: + attn_mask = attn_mask.to(q.dtype) + x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal) + elif mode == "flash": + x = flash_attn_varlen_func( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv, + ) + # x with shape [(bxs), a, d] + x = x.view(batch_size, max_seqlen_q, x.shape[-2], x.shape[-1]) # reshape x to [b, s, a, d] + elif mode == "vanilla": + scale_factor = 1 / math.sqrt(q.size(-1)) + + b, a, s, _ = q.shape + s1 = k.size(2) + attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device) + if causal: + # Only applied to self attention + assert attn_mask is None, "Causal mask and attn_mask cannot be used together" + temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(diagonal=0) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(q.dtype) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) + else: + attn_bias += attn_mask + + # TODO: Maybe force q and k to be float32 to avoid numerical overflow + attn = (q @ k.transpose(-2, -1)) * scale_factor + attn += attn_bias + attn = attn.softmax(dim=-1) + attn = torch.dropout(attn, p=drop_rate, train=True) + x = attn @ v + else: + raise NotImplementedError(f"Unsupported attention mode: {mode}") + + x = post_attn_layout(x) + b, s, a, d = x.shape + out = x.reshape(b, s, -1) + return out + + +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): + x = tuple(x) + if len(x) == 1: + x = tuple(itertools.repeat(x[0], n)) + return x + return tuple(itertools.repeat(x, n)) + + return parse + + +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) + + +def get_activation_layer(act_type): + """get activation layer + + Args: + act_type (str): the activation type + + Returns: + torch.nn.functional: the activation layer + """ + if act_type == "gelu": + return lambda: nn.GELU() + elif act_type == "gelu_tanh": + # Approximate `tanh` requires torch >= 1.13 + return lambda: nn.GELU(approximate="tanh") + elif act_type == "relu": + return nn.ReLU + elif act_type == "silu": + return nn.SiLU + else: + raise ValueError(f"Unknown activation type: {act_type}") + + +def reshape_for_broadcast( + freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], + x: torch.Tensor, + head_first=False, +): + """ + Reshape frequency tensor for broadcasting it with another tensor. + + This function reshapes the frequency tensor to have the same shape as the target tensor 'x' for the purpose of + broadcasting the frequency tensor during element-wise operations. + + Notes: + When using FlashMHAModified, head_first should be False. When using Attention, head_first should be True. + + Args: + freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Frequency tensor to be reshaped. + x (torch.Tensor): Target tensor for broadcasting compatibility. + head_first (bool): head dimension first (except batch dim) or not. + + Returns: + torch.Tensor: Reshaped frequency tensor. + + Raises: + AssertionError: If the frequency tensor doesn't match the expected shape. AssertionError: If the target tensor + 'x' doesn't have the expected number of dimensions. + """ + ndim = x.ndim + assert 0 <= 1 < ndim + + if isinstance(freqs_cis, tuple): + # freqs_cis: (cos, sin) in real space + if head_first: + assert freqs_cis[0].shape == ( + x.shape[-2], + x.shape[-1], + ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}" + shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + else: + assert freqs_cis[0].shape == ( + x.shape[1], + x.shape[-1], + ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}" + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape) + else: + # freqs_cis: values in complex space + if head_first: + assert freqs_cis.shape == ( + x.shape[-2], + x.shape[-1], + ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}" + shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + else: + assert freqs_cis.shape == ( + x.shape[1], + x.shape[-1], + ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}" + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def rotate_half(x): + x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2] + return torch.stack([-x_imag, x_real], dim=-1).flatten(3) + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], + head_first: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary embeddings to input tensors using the given frequency tensor. + + This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided frequency + tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor is reshaped for + broadcasting compatibility. The resulting tensors contain rotary embeddings and are returned as real tensors. + + Args: + xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D] + xk (torch.Tensor): Key tensor to apply rotary embeddings. [B, S, H, D] + freqs_cis (torch.Tensor or tuple): Precomputed frequency tensor for complex exponential. + head_first (bool): head dimension first (except batch dim) or not. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. + + """ + xk_out = None + if isinstance(freqs_cis, tuple): + cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) # [S, D] + cos, sin = cos.to(xq.device), sin.to(xq.device) + # real * cos - imag * sin + # imag * cos + real * sin + xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq) + xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk) + else: + # view_as_complex will pack [..., D/2, 2](real) to [..., D/2](complex) + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # [B, S, H, D//2] + freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(xq.device) # [S, D//2] --> [1, S, 1, D//2] + # (real, imag) * (cos, sin) = (real * cos - imag * sin, imag * cos + real * sin) + # view_as_real will expand [..., D/2](complex) to [..., D/2, 2](real) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) # [B, S, H, D//2] + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk) + + return xq_out, xk_out + + +class RMSNorm(nn.Module): + def __init__( + self, + dim: int, + elementwise_affine=True, + eps: float = 1e-6, + device=None, + dtype=None, + ): + """ + Initialize the RMSNorm normalization layer. + + Args: + dim (int): The dimension of the input tensor. + eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. + + Attributes: + eps (float): A small value added to the denominator for numerical stability. + weight (nn.Parameter): Learnable scaling parameter. + + """ + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.eps = eps + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs)) + + def _norm(self, x): + """ + Apply the RMSNorm normalization to the input tensor. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The normalized tensor. + + """ + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + """ + Forward pass through the RMSNorm layer. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The output tensor after applying RMSNorm. + + """ + output = self._norm(x.float()).type_as(x) + if hasattr(self, "weight"): + output = output * self.weight + return output + + +def get_norm_layer(norm_layer): + """ + Get the normalization layer. + + Args: + norm_layer (str): The type of normalization layer. + + Returns: + norm_layer (nn.Module): The normalization layer. + """ + if norm_layer == "layer": + return nn.LayerNorm + elif norm_layer == "rms": + return RMSNorm + else: + raise NotImplementedError(f"Norm layer {norm_layer} is not implemented") + + +def modulate(x, shift=None, scale=None): + """modulate by shift and scale + + Args: + x (torch.Tensor): input tensor. + shift (torch.Tensor, optional): shift tensor. Defaults to None. + scale (torch.Tensor, optional): scale tensor. Defaults to None. + + Returns: + torch.Tensor: the output tensor after modulate. + """ + if scale is None and shift is None: + return x + elif shift is None: + return x * (1 + scale.unsqueeze(1)) + elif scale is None: + return x + shift.unsqueeze(1) + else: + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +def apply_gate(x, gate=None, tanh=False): + """AI is creating summary for apply_gate + + Args: + x (torch.Tensor): input tensor. + gate (torch.Tensor, optional): gate tensor. Defaults to None. + tanh (bool, optional): whether to use tanh function. Defaults to False. + + Returns: + torch.Tensor: the output tensor after apply gate. + """ + if gate is None: + return x + if tanh: + return x * gate.unsqueeze(1).tanh() + else: + return x * gate.unsqueeze(1) + + +class ModulateDiT(nn.Module): + """Modulation layer for DiT.""" + + def __init__( + self, + hidden_size: int, + factor: int, + act_layer: Callable, + dtype=None, + device=None, + ): + factory_kwargs = {"dtype": dtype, "device": device} + super().__init__() + self.act = act_layer() + self.linear = nn.Linear(hidden_size, factor * hidden_size, bias=True, **factory_kwargs) + # Zero-initialize the modulation + nn.init.zeros_(self.linear.weight) + nn.init.zeros_(self.linear.bias) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(self.act(x)) + + +class MLP(nn.Module): + """MLP as used in Vision Transformer, MLP-Mixer and related networks""" + + def __init__( + self, + in_channels, + hidden_channels=None, + out_features=None, + act_layer=nn.GELU, + norm_layer=None, + bias=True, + drop=0.0, + use_conv=False, + device=None, + dtype=None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + out_features = out_features or in_channels + hidden_channels = hidden_channels or in_channels + bias = to_2tuple(bias) + drop_probs = to_2tuple(drop) + linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear + + self.fc1 = linear_layer(in_channels, hidden_channels, bias=bias[0], **factory_kwargs) + self.act = act_layer() + self.drop1 = nn.Dropout(drop_probs[0]) + self.norm = norm_layer(hidden_channels, **factory_kwargs) if norm_layer is not None else nn.Identity() + self.fc2 = linear_layer(hidden_channels, out_features, bias=bias[1], **factory_kwargs) + self.drop2 = nn.Dropout(drop_probs[1]) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop1(x) + x = self.norm(x) + x = self.fc2(x) + x = self.drop2(x) + return x + + +# +class MLPEmbedder(nn.Module): + """copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py""" + + def __init__(self, in_dim: int, hidden_dim: int, device=None, dtype=None): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True, **factory_kwargs) + self.silu = nn.SiLU() + self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True, **factory_kwargs) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.out_layer(self.silu(self.in_layer(x))) + + +class FinalLayer(nn.Module): + """The final layer of DiT.""" + + def __init__(self, hidden_size, patch_size, out_channels, act_layer, device=None, dtype=None): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + + # Just use LayerNorm for the final layer + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) + if isinstance(patch_size, int): + self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, **factory_kwargs) + else: + self.linear = nn.Linear( + hidden_size, + patch_size[0] * patch_size[1] * patch_size[2] * out_channels, + bias=True, + ) + nn.init.zeros_(self.linear.weight) + nn.init.zeros_(self.linear.bias) + + # Here we don't distinguish between the modulate types. Just use the simple one. + self.adaLN_modulation = nn.Sequential( + act_layer(), + nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs), + ) + # Zero-initialize the modulation + nn.init.zeros_(self.adaLN_modulation[1].weight) + nn.init.zeros_(self.adaLN_modulation[1].bias) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) + x = modulate(self.norm_final(x), shift=shift, scale=scale) + x = self.linear(x) + return x + + +class PatchEmbed(nn.Module): + """2D Image to Patch Embedding + + Image to Patch Embedding using Conv2d + + A convolution based approach to patchifying a 2D image w/ embedding projection. + + Based on the impl in https://github.com/google-research/vision_transformer + + Hacked together by / Copyright 2020 Ross Wightman + + Remove the _assert function in forward function to be compatible with multi-resolution images. + """ + + def __init__( + self, + patch_size=16, + in_chans=3, + embed_dim=768, + norm_layer=None, + flatten=True, + bias=True, + dtype=None, + device=None, + ): + factory_kwargs = {"dtype": dtype, "device": device} + super().__init__() + patch_size = to_2tuple(patch_size) + self.patch_size = patch_size + self.flatten = flatten + + self.proj = nn.Conv3d( + in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, **factory_kwargs + ) + nn.init.xavier_uniform_(self.proj.weight.view(self.proj.weight.size(0), -1)) + if bias: + nn.init.zeros_(self.proj.bias) + + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x): + x = self.proj(x) + if self.flatten: + x = x.flatten(2).transpose(1, 2) # BCHW -> BNC + x = self.norm(x) + return x + + +class TextProjection(nn.Module): + """ + Projects text embeddings. Also handles dropout for classifier-free guidance. + + Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py + """ + + def __init__(self, in_channels, hidden_size, act_layer, dtype=None, device=None): + factory_kwargs = {"dtype": dtype, "device": device} + super().__init__() + self.linear_1 = nn.Linear(in_features=in_channels, out_features=hidden_size, bias=True, **factory_kwargs) + self.act_1 = act_layer() + self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True, **factory_kwargs) + + def forward(self, caption): + hidden_states = self.linear_1(caption) + hidden_states = self.act_1(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + + Args: + t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional. + dim (int): the dimension of the output. + max_period (int): controls the minimum frequency of the embeddings. + + Returns: + embedding (torch.Tensor): An (N, D) Tensor of positional embeddings. + + .. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py + """ + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( + device=t.device + ) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + + def __init__( + self, + hidden_size, + act_layer, + frequency_embedding_size=256, + max_period=10000, + out_size=None, + dtype=None, + device=None, + ): + factory_kwargs = {"dtype": dtype, "device": device} + super().__init__() + self.frequency_embedding_size = frequency_embedding_size + self.max_period = max_period + if out_size is None: + out_size = hidden_size + + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True, **factory_kwargs), + act_layer(), + nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs), + ) + nn.init.normal_(self.mlp[0].weight, std=0.02) + nn.init.normal_(self.mlp[2].weight, std=0.02) + + def forward(self, t): + t_freq = timestep_embedding(t, self.frequency_embedding_size, self.max_period).type(self.mlp[0].weight.dtype) + t_emb = self.mlp(t_freq) + return t_emb + + +class IndividualTokenRefinerBlock(nn.Module): + def __init__( + self, + hidden_size, + heads_num, + mlp_width_ratio: str = 4.0, + mlp_drop_rate: float = 0.0, + act_type: str = "silu", + qk_norm: bool = False, + qk_norm_type: str = "layer", + qkv_bias: bool = True, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.heads_num = heads_num + head_dim = hidden_size // heads_num + mlp_hidden_dim = int(hidden_size * mlp_width_ratio) + + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs) + self.self_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs) + qk_norm_layer = get_norm_layer(qk_norm_type) + self.self_attn_q_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() + ) + self.self_attn_k_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() + ) + self.self_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs) + + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs) + act_layer = get_activation_layer(act_type) + self.mlp = MLP( + in_channels=hidden_size, + hidden_channels=mlp_hidden_dim, + act_layer=act_layer, + drop=mlp_drop_rate, + **factory_kwargs, + ) + + self.adaLN_modulation = nn.Sequential( + act_layer(), + nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs), + ) + # Zero-initialize the modulation + nn.init.zeros_(self.adaLN_modulation[1].weight) + nn.init.zeros_(self.adaLN_modulation[1].bias) + + def forward( + self, + x: torch.Tensor, + c: torch.Tensor, # timestep_aware_representations + context_aware_representations + attn_mask: torch.Tensor = None, + ): + gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1) + + norm_x = self.norm1(x) + qkv = self.self_attn_qkv(norm_x) + q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num) + # Apply QK-Norm if needed + q = self.self_attn_q_norm(q).to(v) + k = self.self_attn_k_norm(k).to(v) + + # Self-Attention + attn = attention(q, k, v, mode="torch", attn_mask=attn_mask) + + x = x + apply_gate(self.self_attn_proj(attn), gate_msa) + + # FFN Layer + x = x + apply_gate(self.mlp(self.norm2(x)), gate_mlp) + + return x + + +class IndividualTokenRefiner(nn.Module): + def __init__( + self, + hidden_size, + heads_num, + depth, + mlp_width_ratio: float = 4.0, + mlp_drop_rate: float = 0.0, + act_type: str = "silu", + qk_norm: bool = False, + qk_norm_type: str = "layer", + qkv_bias: bool = True, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.blocks = nn.ModuleList( + [ + IndividualTokenRefinerBlock( + hidden_size=hidden_size, + heads_num=heads_num, + mlp_width_ratio=mlp_width_ratio, + mlp_drop_rate=mlp_drop_rate, + act_type=act_type, + qk_norm=qk_norm, + qk_norm_type=qk_norm_type, + qkv_bias=qkv_bias, + **factory_kwargs, + ) + for _ in range(depth) + ] + ) + + def forward( + self, + x: torch.Tensor, + c: torch.LongTensor, + mask: Optional[torch.Tensor] = None, + ): + self_attn_mask = None + if mask is not None: + batch_size = mask.shape[0] + seq_len = mask.shape[1] + mask = mask.to(x.device) + # batch_size x 1 x seq_len x seq_len + self_attn_mask_1 = mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1) + # batch_size x 1 x seq_len x seq_len + self_attn_mask_2 = self_attn_mask_1.transpose(2, 3) + # batch_size x 1 x seq_len x seq_len, 1 for broadcasting of heads_num + self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool() + # avoids self-attention weight being NaN for padding tokens + self_attn_mask[:, :, :, 0] = True + + for block in self.blocks: + x = block(x, c, self_attn_mask) + return x + + +class SingleTokenRefiner(nn.Module): + """ + A single token refiner block for llm text embedding refine. + """ + + def __init__( + self, + in_channels, + hidden_size, + heads_num, + depth, + mlp_width_ratio: float = 4.0, + mlp_drop_rate: float = 0.0, + act_type: str = "silu", + qk_norm: bool = False, + qk_norm_type: str = "layer", + qkv_bias: bool = True, + attn_mode: str = "torch", + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.attn_mode = attn_mode + assert self.attn_mode == "torch", "Only support 'torch' mode for token refiner." + + self.input_embedder = nn.Linear(in_channels, hidden_size, bias=True, **factory_kwargs) + + act_layer = get_activation_layer(act_type) + # Build timestep embedding layer + self.t_embedder = TimestepEmbedder(hidden_size, act_layer, **factory_kwargs) + # Build context embedding layer + self.c_embedder = TextProjection(in_channels, hidden_size, act_layer, **factory_kwargs) + + self.individual_token_refiner = IndividualTokenRefiner( + hidden_size=hidden_size, + heads_num=heads_num, + depth=depth, + mlp_width_ratio=mlp_width_ratio, + mlp_drop_rate=mlp_drop_rate, + act_type=act_type, + qk_norm=qk_norm, + qk_norm_type=qk_norm_type, + qkv_bias=qkv_bias, + **factory_kwargs, + ) + + def forward( + self, + x: torch.Tensor, + t: torch.LongTensor, + mask: Optional[torch.LongTensor] = None, + ): + timestep_aware_representations = self.t_embedder(t) + + if mask is None: + context_aware_representations = x.mean(dim=1) + else: + mask_float = mask.float().unsqueeze(-1) # [b, s1, 1] + context_aware_representations = (x * mask_float).sum(dim=1) / mask_float.sum(dim=1) + context_aware_representations = self.c_embedder(context_aware_representations) + c = timestep_aware_representations + context_aware_representations + + x = self.input_embedder(x) + + x = self.individual_token_refiner(x, c, mask) + + return x + + +class HunyuanVideoDoubleStreamBlock(nn.Module): + """ + A multimodal dit block with seperate modulation for text and image/video, see more details (SD3): + https://arxiv.org/abs/2403.03206 + (Flux.1): https://github.com/black-forest-labs/flux + """ + + def __init__( + self, + hidden_size: int, + heads_num: int, + mlp_width_ratio: float, + mlp_act_type: str = "gelu_tanh", + qk_norm: bool = True, + qk_norm_type: str = "rms", + qkv_bias: bool = False, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + + self.deterministic = False + self.heads_num = heads_num + head_dim = hidden_size // heads_num + mlp_hidden_dim = int(hidden_size * mlp_width_ratio) + + self.img_mod = ModulateDiT( + hidden_size, + factor=6, + act_layer=get_activation_layer("silu"), + **factory_kwargs, + ) + self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + self.img_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs) + qk_norm_layer = get_norm_layer(qk_norm_type) + self.img_attn_q_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() + ) + self.img_attn_k_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() + ) + self.img_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs) + + self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) + self.img_mlp = MLP( + hidden_size, + mlp_hidden_dim, + act_layer=get_activation_layer(mlp_act_type), + bias=True, + **factory_kwargs, + ) + + self.txt_mod = ModulateDiT( + hidden_size, + factor=6, + act_layer=get_activation_layer("silu"), + **factory_kwargs, + ) + self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + self.txt_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs) + self.txt_attn_q_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() + ) + self.txt_attn_k_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() + ) + self.txt_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs) + + self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) + self.txt_mlp = MLP( + hidden_size, + mlp_hidden_dim, + act_layer=get_activation_layer(mlp_act_type), + bias=True, + **factory_kwargs, + ) + + def enable_deterministic(self): + self.deterministic = True + + def disable_deterministic(self): + self.deterministic = False + + def forward( + self, + img: torch.Tensor, + txt: torch.Tensor, + vec: torch.Tensor, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_kv: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_kv: Optional[int] = None, + freqs_cis: tuple = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + ( + img_mod1_shift, + img_mod1_scale, + img_mod1_gate, + img_mod2_shift, + img_mod2_scale, + img_mod2_gate, + ) = self.img_mod(vec).chunk(6, dim=-1) + ( + txt_mod1_shift, + txt_mod1_scale, + txt_mod1_gate, + txt_mod2_shift, + txt_mod2_scale, + txt_mod2_gate, + ) = self.txt_mod(vec).chunk(6, dim=-1) + + # Prepare image for attention. + img_modulated = self.img_norm1(img) + img_modulated = modulate(img_modulated, shift=img_mod1_shift, scale=img_mod1_scale) + img_qkv = self.img_attn_qkv(img_modulated) + img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num) + # Apply QK-Norm if needed + img_q = self.img_attn_q_norm(img_q).to(img_v) + img_k = self.img_attn_k_norm(img_k).to(img_v) + + # Apply RoPE if needed. + if freqs_cis is not None: + img_qq, img_kk = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False) + assert ( + img_qq.shape == img_q.shape and img_kk.shape == img_k.shape + ), f"img_kk: {img_qq.shape}, img_q: {img_q.shape}, img_kk: {img_kk.shape}, img_k: {img_k.shape}" + img_q, img_k = img_qq, img_kk + + # Prepare txt for attention. + txt_modulated = self.txt_norm1(txt) + txt_modulated = modulate(txt_modulated, shift=txt_mod1_shift, scale=txt_mod1_scale) + txt_qkv = self.txt_attn_qkv(txt_modulated) + txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num) + # Apply QK-Norm if needed. + txt_q = self.txt_attn_q_norm(txt_q).to(txt_v) + txt_k = self.txt_attn_k_norm(txt_k).to(txt_v) + + # Run actual attention. + q = torch.cat((img_q, txt_q), dim=1) + k = torch.cat((img_k, txt_k), dim=1) + v = torch.cat((img_v, txt_v), dim=1) + assert ( + cu_seqlens_q.shape[0] == 2 * img.shape[0] + 1 + ), f"cu_seqlens_q.shape:{cu_seqlens_q.shape}, img.shape[0]:{img.shape[0]}" + attn = attention( + q, + k, + v, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=max_seqlen_kv, + batch_size=img_k.shape[0], + ) + + img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1] :] + + # Calculate the img bloks. + img = img + apply_gate(self.img_attn_proj(img_attn), gate=img_mod1_gate) + img = img + apply_gate( + self.img_mlp(modulate(self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale)), + gate=img_mod2_gate, + ) + + # Calculate the txt bloks. + txt = txt + apply_gate(self.txt_attn_proj(txt_attn), gate=txt_mod1_gate) + txt = txt + apply_gate( + self.txt_mlp(modulate(self.txt_norm2(txt), shift=txt_mod2_shift, scale=txt_mod2_scale)), + gate=txt_mod2_gate, + ) + + return img, txt + + +class HunyuanVideoSingleStreamBlock(nn.Module): + """ + A DiT block with parallel linear layers as described in https://arxiv.org/abs/2302.05442 and adapted modulation + interface. Also refer to (SD3): https://arxiv.org/abs/2403.03206 + (Flux.1): https://github.com/black-forest-labs/flux + """ + + def __init__( + self, + hidden_size: int, + heads_num: int, + mlp_width_ratio: float = 4.0, + mlp_act_type: str = "gelu_tanh", + qk_norm: bool = True, + qk_norm_type: str = "rms", + qk_scale: float = None, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + + self.deterministic = False + self.hidden_size = hidden_size + self.heads_num = heads_num + head_dim = hidden_size // heads_num + mlp_hidden_dim = int(hidden_size * mlp_width_ratio) + self.mlp_hidden_dim = mlp_hidden_dim + self.scale = qk_scale or head_dim**-0.5 + + # qkv and mlp_in + self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + mlp_hidden_dim, **factory_kwargs) + # proj and mlp_out + self.linear2 = nn.Linear(hidden_size + mlp_hidden_dim, hidden_size, **factory_kwargs) + + qk_norm_layer = get_norm_layer(qk_norm_type) + self.q_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() + ) + self.k_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() + ) + + self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + self.mlp_act = get_activation_layer(mlp_act_type)() + self.modulation = ModulateDiT( + hidden_size, + factor=3, + act_layer=get_activation_layer("silu"), + **factory_kwargs, + ) + + def enable_deterministic(self): + self.deterministic = True + + def disable_deterministic(self): + self.deterministic = False + + def forward( + self, + x: torch.Tensor, + vec: torch.Tensor, + txt_len: int, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_kv: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_kv: Optional[int] = None, + freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None, + ) -> torch.Tensor: + mod_shift, mod_scale, mod_gate = self.modulation(vec).chunk(3, dim=-1) + x_mod = modulate(self.pre_norm(x), shift=mod_shift, scale=mod_scale) + qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) + + q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num) + + # Apply QK-Norm if needed. + q = self.q_norm(q).to(v) + k = self.k_norm(k).to(v) + + # Apply RoPE if needed. + if freqs_cis is not None: + img_q, txt_q = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :] + img_k, txt_k = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :] + img_qq, img_kk = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False) + assert ( + img_qq.shape == img_q.shape and img_kk.shape == img_k.shape + ), f"img_kk: {img_qq.shape}, img_q: {img_q.shape}, img_kk: {img_kk.shape}, img_k: {img_k.shape}" + img_q, img_k = img_qq, img_kk + q = torch.cat((img_q, txt_q), dim=1) + k = torch.cat((img_k, txt_k), dim=1) + + # Compute attention. + assert ( + cu_seqlens_q.shape[0] == 2 * x.shape[0] + 1 + ), f"cu_seqlens_q.shape:{cu_seqlens_q.shape}, x.shape[0]:{x.shape[0]}" + attn = attention( + q, + k, + v, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=max_seqlen_kv, + batch_size=x.shape[0], + ) + + # Compute activation in mlp stream, cat again and run second linear layer. + output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) + return x + apply_gate(output, gate=mod_gate) + + +class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin): + """ + HunyuanVideo Transformer backbone + + Inherited from ModelMixin and ConfigMixin for compatibility with diffusers' sampler StableDiffusionPipeline. + + Reference: [1] Flux.1: https://github.com/black-forest-labs/flux [2] MMDiT: http://arxiv.org/abs/2403.03206 + + Parameters ---------- args: argparse.Namespace + The arguments parsed by argparse. + patch_size: list + The size of the patch. + in_channels: int + The number of input channels. + out_channels: int + The number of output channels. + hidden_size: int + The hidden size of the transformer backbone. + heads_num: int + The number of attention heads. + mlp_width_ratio: float + The ratio of the hidden size of the MLP in the transformer block. + mlp_act_type: str + The activation function of the MLP in the transformer block. + depth_double_blocks: int + The number of transformer blocks in the double blocks. + depth_single_blocks: int + The number of transformer blocks in the single blocks. + rope_dim_list: list + The dimension of the rotary embedding for t, h, w. + qkv_bias: bool + Whether to use bias in the qkv linear layer. + qk_norm: bool + Whether to use qk norm. + qk_norm_type: str + The type of qk norm. + guidance_embed: bool + Whether to use guidance embedding for distillation. + text_projection: str + The type of the text projection, default is single_refiner. + use_attention_mask: bool + Whether to use attention mask for text encoder. + dtype: torch.dtype + The dtype of the model. + device: torch.device + The device of the model. + """ + + @register_to_config + def __init__( + self, + args: Any, + patch_size: list = [1, 2, 2], + in_channels: int = 4, # Should be VAE.config.latent_channels. + out_channels: int = None, + hidden_size: int = 3072, + heads_num: int = 24, + mlp_width_ratio: float = 4.0, + mlp_act_type: str = "gelu_tanh", + mm_double_blocks_depth: int = 20, + mm_single_blocks_depth: int = 40, + rope_dim_list: List[int] = [16, 56, 56], + qkv_bias: bool = True, + qk_norm: bool = True, + qk_norm_type: str = "rms", + guidance_embed: bool = False, # For modulation. + text_projection: str = "single_refiner", + use_attention_mask: bool = True, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + + self.patch_size = patch_size + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None else out_channels + self.unpatchify_channels = self.out_channels + self.guidance_embed = guidance_embed + self.rope_dim_list = rope_dim_list + + # Text projection. Default to linear projection. + # Alternative: TokenRefiner. See more details (LI-DiT): http://arxiv.org/abs/2406.11831 + self.use_attention_mask = use_attention_mask + self.text_projection = text_projection + + self.text_states_dim = args.text_states_dim + self.text_states_dim_2 = args.text_states_dim_2 + + if hidden_size % heads_num != 0: + raise ValueError(f"Hidden size {hidden_size} must be divisible by heads_num {heads_num}") + pe_dim = hidden_size // heads_num + if sum(rope_dim_list) != pe_dim: + raise ValueError(f"Got {rope_dim_list} but expected positional dim {pe_dim}") + self.hidden_size = hidden_size + self.heads_num = heads_num + + # image projection + self.img_in = PatchEmbed(self.patch_size, self.in_channels, self.hidden_size, **factory_kwargs) + + # text projection + if self.text_projection == "linear": + self.txt_in = TextProjection( + self.text_states_dim, + self.hidden_size, + get_activation_layer("silu"), + **factory_kwargs, + ) + elif self.text_projection == "single_refiner": + self.txt_in = SingleTokenRefiner(self.text_states_dim, hidden_size, heads_num, depth=2, **factory_kwargs) + else: + raise NotImplementedError(f"Unsupported text_projection: {self.text_projection}") + + # time modulation + self.time_in = TimestepEmbedder(self.hidden_size, get_activation_layer("silu"), **factory_kwargs) + + # text modulation + self.vector_in = MLPEmbedder(self.text_states_dim_2, self.hidden_size, **factory_kwargs) + + # guidance modulation + self.guidance_in = ( + TimestepEmbedder(self.hidden_size, get_activation_layer("silu"), **factory_kwargs) + if guidance_embed + else None + ) + + # double blocks + self.double_blocks = nn.ModuleList( + [ + HunyuanVideoDoubleStreamBlock( + self.hidden_size, + self.heads_num, + mlp_width_ratio=mlp_width_ratio, + mlp_act_type=mlp_act_type, + qk_norm=qk_norm, + qk_norm_type=qk_norm_type, + qkv_bias=qkv_bias, + **factory_kwargs, + ) + for _ in range(mm_double_blocks_depth) + ] + ) + + # single blocks + self.single_blocks = nn.ModuleList( + [ + HunyuanVideoSingleStreamBlock( + self.hidden_size, + self.heads_num, + mlp_width_ratio=mlp_width_ratio, + mlp_act_type=mlp_act_type, + qk_norm=qk_norm, + qk_norm_type=qk_norm_type, + **factory_kwargs, + ) + for _ in range(mm_single_blocks_depth) + ] + ) + + self.final_layer = FinalLayer( + self.hidden_size, + self.patch_size, + self.out_channels, + get_activation_layer("silu"), + **factory_kwargs, + ) + + def enable_deterministic(self): + for block in self.double_blocks: + block.enable_deterministic() + for block in self.single_blocks: + block.enable_deterministic() + + def disable_deterministic(self): + for block in self.double_blocks: + block.disable_deterministic() + for block in self.single_blocks: + block.disable_deterministic() + + def forward( + self, + x: torch.Tensor, + t: torch.Tensor, # Should be in range(0, 1000). + text_states: torch.Tensor = None, + text_mask: torch.Tensor = None, # Now we don't use it. + text_states_2: Optional[torch.Tensor] = None, # Text embedding for modulation. + freqs_cos: Optional[torch.Tensor] = None, + freqs_sin: Optional[torch.Tensor] = None, + guidance: torch.Tensor = None, # Guidance for modulation, should be cfg_scale x 1000. + return_dict: bool = True, + ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: + out = {} + img = x + txt = text_states + _, _, ot, oh, ow = x.shape + tt, th, tw = ( + ot // self.patch_size[0], + oh // self.patch_size[1], + ow // self.patch_size[2], + ) + + # Prepare modulation vectors. + vec = self.time_in(t) + + # text modulation + vec = vec + self.vector_in(text_states_2) + + # guidance modulation + if self.guidance_embed: + if guidance is None: + raise ValueError("Didn't get guidance strength for guidance distilled model.") + + # our timestep_embedding is merged into guidance_in(TimestepEmbedder) + vec = vec + self.guidance_in(guidance) + + # Embed image and text. + img = self.img_in(img) + if self.text_projection == "linear": + txt = self.txt_in(txt) + elif self.text_projection == "single_refiner": + txt = self.txt_in(txt, t, text_mask if self.use_attention_mask else None) + else: + raise NotImplementedError(f"Unsupported text_projection: {self.text_projection}") + + txt_seq_len = txt.shape[1] + img_seq_len = img.shape[1] + + # Compute cu_squlens and max_seqlen for flash attention + cu_seqlens_q = get_cu_seqlens(text_mask, img_seq_len) + cu_seqlens_kv = cu_seqlens_q + max_seqlen_q = img_seq_len + txt_seq_len + max_seqlen_kv = max_seqlen_q + + freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None + # --------------------- Pass through DiT blocks ------------------------ + for _, block in enumerate(self.double_blocks): + double_block_args = [ + img, + txt, + vec, + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv, + freqs_cis, + ] + + img, txt = block(*double_block_args) + + # Merge txt and img to pass through single stream blocks. + x = torch.cat((img, txt), 1) + if len(self.single_blocks) > 0: + for _, block in enumerate(self.single_blocks): + single_block_args = [ + x, + vec, + txt_seq_len, + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv, + (freqs_cos, freqs_sin), + ] + + x = block(*single_block_args) + + img = x[:, :img_seq_len, ...] + + # ---------------------------- Final layer ------------------------------ + img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) + + img = self.unpatchify(img, tt, th, tw) + if return_dict: + out["x"] = img + return out + return img + + def unpatchify(self, x, t, h, w): + """ + x: (N, T, patch_size**2 * C) imgs: (N, H, W, C) + """ + c = self.unpatchify_channels + pt, ph, pw = self.patch_size + assert t * h * w == x.shape[1] + + x = x.reshape(shape=(x.shape[0], t, h, w, c, pt, ph, pw)) + x = torch.einsum("nthwcopq->nctohpwq", x) + imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw)) + + return imgs + + def params_count(self): + counts = { + "double": sum( + [ + sum(p.numel() for p in block.img_attn_qkv.parameters()) + + sum(p.numel() for p in block.img_attn_proj.parameters()) + + sum(p.numel() for p in block.img_mlp.parameters()) + + sum(p.numel() for p in block.txt_attn_qkv.parameters()) + + sum(p.numel() for p in block.txt_attn_proj.parameters()) + + sum(p.numel() for p in block.txt_mlp.parameters()) + for block in self.double_blocks + ] + ), + "single": sum( + [ + sum(p.numel() for p in block.linear1.parameters()) + + sum(p.numel() for p in block.linear2.parameters()) + for block in self.single_blocks + ] + ), + "total": sum(p.numel() for p in self.parameters()), + } + counts["attn+mlp"] = counts["double"] + counts["single"] + return counts diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 5091ff318f1b..2f8dfc010092 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -332,6 +332,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class HunyuanVideoTransformer3DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class I2VGenXLUNet(metaclass=DummyObject): _backends = ["torch"] From 2e61a9dca4d7f9238856dda55f8f150b88f0d9ae Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 5 Dec 2024 21:10:59 +0100 Subject: [PATCH 02/58] copy vae --- src/diffusers/__init__.py | 2 + src/diffusers/models/__init__.py | 2 + src/diffusers/models/autoencoders/__init__.py | 1 + .../autoencoder_kl_hunyuan_video.py | 1486 +++++++++++++++++ .../transformers/transformer_hunyuan_video.py | 14 + src/diffusers/utils/dummy_pt_objects.py | 15 + 6 files changed, 1520 insertions(+) create mode 100644 src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index caefc093ed64..29284ba9e83f 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -83,6 +83,7 @@ "AutoencoderKL", "AutoencoderKLAllegro", "AutoencoderKLCogVideoX", + "AutoencoderKLHunyuanVideo", "AutoencoderKLMochi", "AutoencoderKLTemporalDecoder", "AutoencoderOobleck", @@ -576,6 +577,7 @@ AutoencoderKL, AutoencoderKLAllegro, AutoencoderKLCogVideoX, + AutoencoderKLHunyuanVideo, AutoencoderKLMochi, AutoencoderKLTemporalDecoder, AutoencoderOobleck, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index fdb09fd4ebb5..057123bf4462 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -30,6 +30,7 @@ _import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"] _import_structure["autoencoders.autoencoder_kl_allegro"] = ["AutoencoderKLAllegro"] _import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"] + _import_structure["autoencoders.autoencoder_kl_hunyuan_video"] = ["AutoencoderKLHunyuanVideo"] _import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"] _import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"] _import_structure["autoencoders.autoencoder_oobleck"] = ["AutoencoderOobleck"] @@ -92,6 +93,7 @@ AutoencoderKL, AutoencoderKLAllegro, AutoencoderKLCogVideoX, + AutoencoderKLHunyuanVideo, AutoencoderKLMochi, AutoencoderKLTemporalDecoder, AutoencoderOobleck, diff --git a/src/diffusers/models/autoencoders/__init__.py b/src/diffusers/models/autoencoders/__init__.py index ba45d6671252..46c0af0d6554 100644 --- a/src/diffusers/models/autoencoders/__init__.py +++ b/src/diffusers/models/autoencoders/__init__.py @@ -2,6 +2,7 @@ from .autoencoder_kl import AutoencoderKL from .autoencoder_kl_allegro import AutoencoderKLAllegro from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX +from .autoencoder_kl_hunyuan_video import AutoencoderKLHunyuanVideo from .autoencoder_kl_mochi import AutoencoderKLMochi from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder from .autoencoder_oobleck import AutoencoderOobleck diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py new file mode 100644 index 000000000000..0cdd483de68f --- /dev/null +++ b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py @@ -0,0 +1,1486 @@ +# Copyright 2024 The Hunyuan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import logging +from ...utils.accelerate_utils import apply_forward_hook +from ..activations import get_activation +from ..attention_processor import ( + Attention, + SpatialNorm, +) +from ..modeling_outputs import AutoencoderKLOutput +from ..modeling_utils import ModelMixin +from ..normalization import AdaGroupNorm, RMSNorm +from .vae import BaseOutput, DecoderOutput, DiagonalGaussianDistribution + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def prepare_causal_attention_mask(n_frame: int, n_hw: int, dtype, device, batch_size: int = None): + seq_len = n_frame * n_hw + mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device) + for i in range(seq_len): + i_frame = i // n_hw + mask[i, : (i_frame + 1) * n_hw] = 0 + if batch_size is not None: + mask = mask.unsqueeze(0).expand(batch_size, -1, -1) + return mask + + +class CausalConv3d(nn.Module): + """ + Implements a causal 3D convolution layer where each position only depends on previous timesteps and current spatial + locations. This maintains temporal causality in video generation tasks. + """ + + def __init__( + self, + chan_in, + chan_out, + kernel_size: Union[int, Tuple[int, int, int]], + stride: Union[int, Tuple[int, int, int]] = 1, + dilation: Union[int, Tuple[int, int, int]] = 1, + pad_mode="replicate", + **kwargs, + ): + super().__init__() + + self.pad_mode = pad_mode + padding = ( + kernel_size // 2, + kernel_size // 2, + kernel_size // 2, + kernel_size // 2, + kernel_size - 1, + 0, + ) # W, H, T + self.time_causal_padding = padding + + self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs) + + def forward(self, x): + x = F.pad(x, self.time_causal_padding, mode=self.pad_mode) + return self.conv(x) + + +class UpsampleCausal3D(nn.Module): + """ + A 3D upsampling layer with an optional convolution. + """ + + def __init__( + self, + channels: int, + use_conv: bool = False, + use_conv_transpose: bool = False, + out_channels: Optional[int] = None, + name: str = "conv", + kernel_size: Optional[int] = None, + padding=1, + norm_type=None, + eps=None, + elementwise_affine=None, + bias=True, + interpolate=True, + upsample_factor=(2, 2, 2), + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_conv_transpose = use_conv_transpose + self.name = name + self.interpolate = interpolate + self.upsample_factor = upsample_factor + + if norm_type == "ln_norm": + self.norm = nn.LayerNorm(channels, eps, elementwise_affine) + elif norm_type == "rms_norm": + self.norm = RMSNorm(channels, eps, elementwise_affine) + elif norm_type is None: + self.norm = None + else: + raise ValueError(f"unknown norm_type: {norm_type}") + + conv = None + if use_conv_transpose: + assert False, "Not Implement yet" + if kernel_size is None: + kernel_size = 4 + conv = nn.ConvTranspose2d( + channels, self.out_channels, kernel_size=kernel_size, stride=2, padding=padding, bias=bias + ) + elif use_conv: + if kernel_size is None: + kernel_size = 3 + conv = CausalConv3d(self.channels, self.out_channels, kernel_size=kernel_size, bias=bias) + + if name == "conv": + self.conv = conv + else: + self.Conv2d_0 = conv + + def forward( + self, + hidden_states: torch.FloatTensor, + output_size: Optional[int] = None, + scale: float = 1.0, + ) -> torch.FloatTensor: + assert hidden_states.shape[1] == self.channels + + if self.norm is not None: + assert False, "Not Implement yet" + hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + + if self.use_conv_transpose: + return self.conv(hidden_states) + + # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 + dtype = hidden_states.dtype + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(torch.float32) + + # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 + if hidden_states.shape[0] >= 64: + hidden_states = hidden_states.contiguous() + + # if `output_size` is passed we force the interpolation output + # size and do not make use of `scale_factor=2` + if self.interpolate: + B, C, T, H, W = hidden_states.shape + first_h, other_h = hidden_states.split((1, T - 1), dim=2) + if output_size is None: + if T > 1: + other_h = F.interpolate(other_h, scale_factor=self.upsample_factor, mode="nearest") + + first_h = first_h.squeeze(2) + first_h = F.interpolate(first_h, scale_factor=self.upsample_factor[1:], mode="nearest") + first_h = first_h.unsqueeze(2) + else: + assert False, "Not Implement yet" + other_h = F.interpolate(other_h, size=output_size, mode="nearest") + + if T > 1: + hidden_states = torch.cat((first_h, other_h), dim=2) + else: + hidden_states = first_h + + # If the input is bfloat16, we cast back to bfloat16 + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(dtype) + + if self.use_conv: + if self.name == "conv": + hidden_states = self.conv(hidden_states) + else: + hidden_states = self.Conv2d_0(hidden_states) + + return hidden_states + + +class DownsampleCausal3D(nn.Module): + """ + A 3D downsampling layer with an optional convolution. + """ + + def __init__( + self, + channels: int, + use_conv: bool = False, + out_channels: Optional[int] = None, + padding: int = 1, + name: str = "conv", + kernel_size=3, + norm_type=None, + eps=None, + elementwise_affine=None, + bias=True, + stride=2, + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.padding = padding + stride = stride + self.name = name + + if norm_type == "ln_norm": + self.norm = nn.LayerNorm(channels, eps, elementwise_affine) + elif norm_type == "rms_norm": + self.norm = RMSNorm(channels, eps, elementwise_affine) + elif norm_type is None: + self.norm = None + else: + raise ValueError(f"unknown norm_type: {norm_type}") + + if use_conv: + conv = CausalConv3d(self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, bias=bias) + else: + raise NotImplementedError + + if name == "conv": + self.Conv2d_0 = conv + self.conv = conv + elif name == "Conv2d_0": + self.conv = conv + else: + self.conv = conv + + def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor: + assert hidden_states.shape[1] == self.channels + + if self.norm is not None: + hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + + assert hidden_states.shape[1] == self.channels + + hidden_states = self.conv(hidden_states) + + return hidden_states + + +def get_down_block3d( + down_block_type: str, + num_layers: int, + in_channels: int, + out_channels: int, + temb_channels: int, + add_downsample: bool, + downsample_stride: int, + resnet_eps: float, + resnet_act_fn: str, + transformer_layers_per_block: int = 1, + num_attention_heads: Optional[int] = None, + resnet_groups: Optional[int] = None, + cross_attention_dim: Optional[int] = None, + downsample_padding: Optional[int] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + attention_type: str = "default", + resnet_skip_time_act: bool = False, + resnet_out_scale_factor: float = 1.0, + cross_attention_norm: Optional[str] = None, + attention_head_dim: Optional[int] = None, + downsample_type: Optional[str] = None, + dropout: float = 0.0, +): + # If attn head dim is not defined, we default it to the number of heads + if attention_head_dim is None: + logger.warn( + f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}." + ) + attention_head_dim = num_attention_heads + + down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type + if down_block_type == "DownEncoderBlockCausal3D": + return DownEncoderBlockCausal3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + dropout=dropout, + add_downsample=add_downsample, + downsample_stride=downsample_stride, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + raise ValueError(f"{down_block_type} does not exist.") + + +def get_up_block3d( + up_block_type: str, + num_layers: int, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + add_upsample: bool, + upsample_scale_factor: Tuple, + resnet_eps: float, + resnet_act_fn: str, + resolution_idx: Optional[int] = None, + transformer_layers_per_block: int = 1, + num_attention_heads: Optional[int] = None, + resnet_groups: Optional[int] = None, + cross_attention_dim: Optional[int] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + attention_type: str = "default", + resnet_skip_time_act: bool = False, + resnet_out_scale_factor: float = 1.0, + cross_attention_norm: Optional[str] = None, + attention_head_dim: Optional[int] = None, + upsample_type: Optional[str] = None, + dropout: float = 0.0, +) -> nn.Module: + # If attn head dim is not defined, we default it to the number of heads + if attention_head_dim is None: + logger.warn( + f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}." + ) + attention_head_dim = num_attention_heads + + up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type + if up_block_type == "UpDecoderBlockCausal3D": + return UpDecoderBlockCausal3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + upsample_scale_factor=upsample_scale_factor, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + temb_channels=temb_channels, + ) + raise ValueError(f"{up_block_type} does not exist.") + + +class ResnetBlockCausal3D(nn.Module): + r""" + A Resnet block. + """ + + def __init__( + self, + *, + in_channels: int, + out_channels: Optional[int] = None, + conv_shortcut: bool = False, + dropout: float = 0.0, + temb_channels: int = 512, + groups: int = 32, + groups_out: Optional[int] = None, + pre_norm: bool = True, + eps: float = 1e-6, + non_linearity: str = "swish", + skip_time_act: bool = False, + # default, scale_shift, ada_group, spatial + time_embedding_norm: str = "default", + kernel: Optional[torch.FloatTensor] = None, + output_scale_factor: float = 1.0, + use_in_shortcut: Optional[bool] = None, + up: bool = False, + down: bool = False, + conv_shortcut_bias: bool = True, + conv_3d_out_channels: Optional[int] = None, + ): + super().__init__() + self.pre_norm = pre_norm + self.pre_norm = True + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + self.up = up + self.down = down + self.output_scale_factor = output_scale_factor + self.time_embedding_norm = time_embedding_norm + self.skip_time_act = skip_time_act + + linear_cls = nn.Linear + + if groups_out is None: + groups_out = groups + + if self.time_embedding_norm == "ada_group": + self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps) + elif self.time_embedding_norm == "spatial": + self.norm1 = SpatialNorm(in_channels, temb_channels) + else: + self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) + + self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, stride=1) + + if temb_channels is not None: + if self.time_embedding_norm == "default": + self.time_emb_proj = linear_cls(temb_channels, out_channels) + elif self.time_embedding_norm == "scale_shift": + self.time_emb_proj = linear_cls(temb_channels, 2 * out_channels) + elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial": + self.time_emb_proj = None + else: + raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ") + else: + self.time_emb_proj = None + + if self.time_embedding_norm == "ada_group": + self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps) + elif self.time_embedding_norm == "spatial": + self.norm2 = SpatialNorm(out_channels, temb_channels) + else: + self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) + + self.dropout = torch.nn.Dropout(dropout) + conv_3d_out_channels = conv_3d_out_channels or out_channels + self.conv2 = CausalConv3d(out_channels, conv_3d_out_channels, kernel_size=3, stride=1) + + self.nonlinearity = get_activation(non_linearity) + + self.upsample = self.downsample = None + if self.up: + self.upsample = UpsampleCausal3D(in_channels, use_conv=False) + elif self.down: + self.downsample = DownsampleCausal3D(in_channels, use_conv=False, name="op") + + self.use_in_shortcut = self.in_channels != conv_3d_out_channels if use_in_shortcut is None else use_in_shortcut + + self.conv_shortcut = None + if self.use_in_shortcut: + self.conv_shortcut = CausalConv3d( + in_channels, + conv_3d_out_channels, + kernel_size=1, + stride=1, + bias=conv_shortcut_bias, + ) + + def forward( + self, + input_tensor: torch.FloatTensor, + temb: torch.FloatTensor, + scale: float = 1.0, + ) -> torch.FloatTensor: + hidden_states = input_tensor + + if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial": + hidden_states = self.norm1(hidden_states, temb) + else: + hidden_states = self.norm1(hidden_states) + + hidden_states = self.nonlinearity(hidden_states) + + if self.upsample is not None: + # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 + if hidden_states.shape[0] >= 64: + input_tensor = input_tensor.contiguous() + hidden_states = hidden_states.contiguous() + input_tensor = self.upsample(input_tensor, scale=scale) + hidden_states = self.upsample(hidden_states, scale=scale) + elif self.downsample is not None: + input_tensor = self.downsample(input_tensor, scale=scale) + hidden_states = self.downsample(hidden_states, scale=scale) + + hidden_states = self.conv1(hidden_states) + + if self.time_emb_proj is not None: + if not self.skip_time_act: + temb = self.nonlinearity(temb) + temb = self.time_emb_proj(temb, scale)[:, :, None, None] + + if temb is not None and self.time_embedding_norm == "default": + hidden_states = hidden_states + temb + + if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial": + hidden_states = self.norm2(hidden_states, temb) + else: + hidden_states = self.norm2(hidden_states) + + if temb is not None and self.time_embedding_norm == "scale_shift": + scale, shift = torch.chunk(temb, 2, dim=1) + hidden_states = hidden_states * (1 + scale) + shift + + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + input_tensor = self.conv_shortcut(input_tensor) + + output_tensor = (input_tensor + hidden_states) / self.output_scale_factor + + return output_tensor + + +class UNetMidBlockCausal3D(nn.Module): + """ + A 3D UNet mid-block [`UNetMidBlockCausal3D`] with multiple residual blocks and optional attention blocks. + """ + + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", # default, spatial + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + attn_groups: Optional[int] = None, + resnet_pre_norm: bool = True, + add_attention: bool = True, + attention_head_dim: int = 1, + output_scale_factor: float = 1.0, + ): + super().__init__() + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + self.add_attention = add_attention + + if attn_groups is None: + attn_groups = resnet_groups if resnet_time_scale_shift == "default" else None + + # there is always at least one resnet + resnets = [ + ResnetBlockCausal3D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + if attention_head_dim is None: + logger.warn( + f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}." + ) + attention_head_dim = in_channels + + for _ in range(num_layers): + if self.add_attention: + # assert False, "Not implemented yet" + attentions.append( + Attention( + in_channels, + heads=in_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=attn_groups, + spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + else: + attentions.append(None) + + resnets.append( + ResnetBlockCausal3D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if attn is not None: + B, C, T, H, W = hidden_states.shape + hidden_states = rearrange(hidden_states, "b c f h w -> b (f h w) c") + attention_mask = prepare_causal_attention_mask( + T, H * W, hidden_states.dtype, hidden_states.device, batch_size=B + ) + hidden_states = attn(hidden_states, temb=temb, attention_mask=attention_mask) + hidden_states = rearrange(hidden_states, "b (f h w) c -> b c f h w", f=T, h=H, w=W) + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +class DownEncoderBlockCausal3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_downsample: bool = True, + downsample_stride: int = 2, + downsample_padding: int = 1, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlockCausal3D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + DownsampleCausal3D( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + stride=downsample_stride, + ) + ] + ) + else: + self.downsamplers = None + + def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor: + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb=None, scale=scale) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states, scale) + + return hidden_states + + +class UpDecoderBlockCausal3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + resolution_idx: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", # default, spatial + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + upsample_scale_factor=(2, 2, 2), + temb_channels: Optional[int] = None, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + + resnets.append( + ResnetBlockCausal3D( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList( + [ + UpsampleCausal3D( + out_channels, + use_conv=True, + out_channels=out_channels, + upsample_factor=upsample_scale_factor, + ) + ] + ) + else: + self.upsamplers = None + + self.resolution_idx = resolution_idx + + def forward( + self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0 + ) -> torch.FloatTensor: + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb=temb, scale=scale) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class EncoderCausal3D(nn.Module): + r""" + The `EncoderCausal3D` layer of a variational autoencoder that encodes its input into a latent representation. + """ + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str, ...] = ("DownEncoderBlockCausal3D",), + block_out_channels: Tuple[int, ...] = (64,), + layers_per_block: int = 2, + norm_num_groups: int = 32, + act_fn: str = "silu", + double_z: bool = True, + mid_block_add_attention=True, + time_compression_ratio: int = 4, + spatial_compression_ratio: int = 8, + ): + super().__init__() + self.layers_per_block = layers_per_block + + self.conv_in = CausalConv3d(in_channels, block_out_channels[0], kernel_size=3, stride=1) + self.mid_block = None + self.down_blocks = nn.ModuleList([]) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + num_spatial_downsample_layers = int(np.log2(spatial_compression_ratio)) + num_time_downsample_layers = int(np.log2(time_compression_ratio)) + + if time_compression_ratio == 4: + add_spatial_downsample = bool(i < num_spatial_downsample_layers) + add_time_downsample = bool( + i >= (len(block_out_channels) - 1 - num_time_downsample_layers) and not is_final_block + ) + elif time_compression_ratio == 8: + add_spatial_downsample = bool(i < num_spatial_downsample_layers) + add_time_downsample = bool(i < num_time_downsample_layers) + else: + raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}") + + downsample_stride_HW = (2, 2) if add_spatial_downsample else (1, 1) + downsample_stride_T = (2,) if add_time_downsample else (1,) + downsample_stride = tuple(downsample_stride_T + downsample_stride_HW) + down_block = get_down_block3d( + down_block_type, + num_layers=self.layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + add_downsample=bool(add_spatial_downsample or add_time_downsample), + downsample_stride=downsample_stride, + resnet_eps=1e-6, + downsample_padding=0, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + attention_head_dim=output_channel, + temb_channels=None, + ) + self.down_blocks.append(down_block) + + # mid + self.mid_block = UNetMidBlockCausal3D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_time_scale_shift="default", + attention_head_dim=block_out_channels[-1], + resnet_groups=norm_num_groups, + temb_channels=None, + add_attention=mid_block_add_attention, + ) + + # out + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6) + self.conv_act = nn.SiLU() + + conv_out_channels = 2 * out_channels if double_z else out_channels + self.conv_out = CausalConv3d(block_out_channels[-1], conv_out_channels, kernel_size=3) + + def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor: + r"""The forward method of the `EncoderCausal3D` class.""" + assert len(sample.shape) == 5, "The input tensor should have 5 dimensions" + + sample = self.conv_in(sample) + + # down + for down_block in self.down_blocks: + sample = down_block(sample) + + # middle + sample = self.mid_block(sample) + + # post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + return sample + + +class DecoderCausal3D(nn.Module): + r""" + The `DecoderCausal3D` layer of a variational autoencoder that decodes its latent representation into an output + sample. + """ + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + up_block_types: Tuple[str, ...] = ("UpDecoderBlockCausal3D",), + block_out_channels: Tuple[int, ...] = (64,), + layers_per_block: int = 2, + norm_num_groups: int = 32, + act_fn: str = "silu", + norm_type: str = "group", # group, spatial + mid_block_add_attention=True, + time_compression_ratio: int = 4, + spatial_compression_ratio: int = 8, + ): + super().__init__() + self.layers_per_block = layers_per_block + + self.conv_in = CausalConv3d(in_channels, block_out_channels[-1], kernel_size=3, stride=1) + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + temb_channels = in_channels if norm_type == "spatial" else None + + # mid + self.mid_block = UNetMidBlockCausal3D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_time_scale_shift="default" if norm_type == "group" else norm_type, + attention_head_dim=block_out_channels[-1], + resnet_groups=norm_num_groups, + temb_channels=temb_channels, + add_attention=mid_block_add_attention, + ) + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + num_spatial_upsample_layers = int(np.log2(spatial_compression_ratio)) + num_time_upsample_layers = int(np.log2(time_compression_ratio)) + + if time_compression_ratio == 4: + add_spatial_upsample = bool(i < num_spatial_upsample_layers) + add_time_upsample = bool( + i >= len(block_out_channels) - 1 - num_time_upsample_layers and not is_final_block + ) + else: + raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}") + + upsample_scale_factor_HW = (2, 2) if add_spatial_upsample else (1, 1) + upsample_scale_factor_T = (2,) if add_time_upsample else (1,) + upsample_scale_factor = tuple(upsample_scale_factor_T + upsample_scale_factor_HW) + up_block = get_up_block3d( + up_block_type, + num_layers=self.layers_per_block + 1, + in_channels=prev_output_channel, + out_channels=output_channel, + prev_output_channel=None, + add_upsample=bool(add_spatial_upsample or add_time_upsample), + upsample_scale_factor=upsample_scale_factor, + resnet_eps=1e-6, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + attention_head_dim=output_channel, + temb_channels=temb_channels, + resnet_time_scale_shift=norm_type, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if norm_type == "spatial": + self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels) + else: + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6) + self.conv_act = nn.SiLU() + self.conv_out = CausalConv3d(block_out_channels[0], out_channels, kernel_size=3) + + self.gradient_checkpointing = False + + def forward( + self, + sample: torch.FloatTensor, + latent_embeds: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + r"""The forward method of the `DecoderCausal3D` class.""" + assert len(sample.shape) == 5, "The input tensor should have 5 dimensions" + + sample = self.conv_in(sample) + + upscale_dtype = next(iter(self.up_blocks.parameters())).dtype + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + # up + for up_block in self.up_blocks: + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(up_block), + sample, + latent_embeds, + use_reentrant=False, + ) + else: + # middle + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.mid_block), sample, latent_embeds + ) + sample = sample.to(upscale_dtype) + + # up + for up_block in self.up_blocks: + sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds) + else: + # middle + sample = self.mid_block(sample, latent_embeds) + sample = sample.to(upscale_dtype) + + # up + for up_block in self.up_blocks: + sample = up_block(sample, latent_embeds) + + # post-process + if latent_embeds is None: + sample = self.conv_norm_out(sample) + else: + sample = self.conv_norm_out(sample, latent_embeds) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + return sample + + +@dataclass +class DecoderOutput2(BaseOutput): + sample: torch.FloatTensor + posterior: Optional[DiagonalGaussianDistribution] = None + + +class AutoencoderKLHunyuanVideo(ModelMixin, ConfigMixin): + r""" + A VAE model with KL loss for encoding images/videos into latents and decoding latent representations into + images/videos. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str] = ("DownEncoderBlockCausal3D",), + up_block_types: Tuple[str] = ("UpDecoderBlockCausal3D",), + block_out_channels: Tuple[int] = (64,), + layers_per_block: int = 1, + act_fn: str = "silu", + latent_channels: int = 4, + norm_num_groups: int = 32, + sample_size: int = 32, + sample_tsize: int = 64, + scaling_factor: float = 0.18215, + force_upcast: float = True, + spatial_compression_ratio: int = 8, + time_compression_ratio: int = 4, + mid_block_add_attention: bool = True, + ): + super().__init__() + + self.time_compression_ratio = time_compression_ratio + + self.encoder = EncoderCausal3D( + in_channels=in_channels, + out_channels=latent_channels, + down_block_types=down_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + norm_num_groups=norm_num_groups, + double_z=True, + time_compression_ratio=time_compression_ratio, + spatial_compression_ratio=spatial_compression_ratio, + mid_block_add_attention=mid_block_add_attention, + ) + + self.decoder = DecoderCausal3D( + in_channels=latent_channels, + out_channels=out_channels, + up_block_types=up_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + norm_num_groups=norm_num_groups, + act_fn=act_fn, + time_compression_ratio=time_compression_ratio, + spatial_compression_ratio=spatial_compression_ratio, + mid_block_add_attention=mid_block_add_attention, + ) + + self.quant_conv = nn.Conv3d(2 * latent_channels, 2 * latent_channels, kernel_size=1) + self.post_quant_conv = nn.Conv3d(latent_channels, latent_channels, kernel_size=1) + + self.use_slicing = False + self.use_spatial_tiling = False + self.use_temporal_tiling = False + + # only relevant if vae tiling is enabled + self.tile_sample_min_tsize = sample_tsize + self.tile_latent_min_tsize = sample_tsize // time_compression_ratio + + self.tile_sample_min_size = self.config.sample_size + sample_size = ( + self.config.sample_size[0] + if isinstance(self.config.sample_size, (list, tuple)) + else self.config.sample_size + ) + self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1))) + self.tile_overlap_factor = 0.25 + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (EncoderCausal3D, DecoderCausal3D)): + module.gradient_checkpointing = value + + def enable_temporal_tiling(self, use_tiling: bool = True): + self.use_temporal_tiling = use_tiling + + def disable_temporal_tiling(self): + self.enable_temporal_tiling(False) + + def enable_spatial_tiling(self, use_tiling: bool = True): + self.use_spatial_tiling = use_tiling + + def disable_spatial_tiling(self): + self.enable_spatial_tiling(False) + + def enable_tiling(self, use_tiling: bool = True): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger videos. + """ + self.enable_spatial_tiling(use_tiling) + self.enable_temporal_tiling(use_tiling) + + def disable_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.disable_spatial_tiling() + self.disable_temporal_tiling() + + def enable_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.use_slicing = True + + def disable_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_slicing = False + + @apply_forward_hook + def encode( + self, x: torch.FloatTensor, return_dict: bool = True + ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: + """ + Encode a batch of images/videos into latents. + + Args: + x (`torch.FloatTensor`): Input batch of images/videos. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded images/videos. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. + """ + assert len(x.shape) == 5, "The input tensor should have 5 dimensions" + + if self.use_temporal_tiling and x.shape[2] > self.tile_sample_min_tsize: + return self.temporal_tiled_encode(x, return_dict=return_dict) + + if self.use_spatial_tiling and ( + x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size + ): + return self.spatial_tiled_encode(x, return_dict=return_dict) + + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self.encoder(x) + + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + + if not return_dict: + return (posterior,) + + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: + assert len(z.shape) == 5, "The input tensor should have 5 dimensions" + + if self.use_temporal_tiling and z.shape[2] > self.tile_latent_min_tsize: + return self.temporal_tiled_decode(z, return_dict=return_dict) + + if self.use_spatial_tiling and ( + z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size + ): + return self.spatial_tiled_decode(z, return_dict=return_dict) + + z = self.post_quant_conv(z) + dec = self.decoder(z) + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + @apply_forward_hook + def decode( + self, z: torch.FloatTensor, return_dict: bool = True, generator=None + ) -> Union[DecoderOutput, torch.FloatTensor]: + """ + Decode a batch of images/videos. + + Args: + z (`torch.FloatTensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + + """ + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z).sample + + if not return_dict: + return (decoded,) + + return DecoderOutput(sample=decoded) + + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-2], b.shape[-2], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * ( + y / blend_extent + ) + return b + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-1], b.shape[-1], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * ( + x / blend_extent + ) + return b + + def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-3], b.shape[-3], blend_extent) + for x in range(blend_extent): + b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * ( + x / blend_extent + ) + return b + + def spatial_tiled_encode( + self, x: torch.FloatTensor, return_dict: bool = True, return_moments: bool = False + ) -> AutoencoderKLOutput: + r"""Encode a batch of images/videos using a tiled encoder. + + When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several + steps. This is useful to keep memory use constant regardless of image/videos size. The end result of tiled + encoding is different from non-tiled encoding because each tile uses a different encoder. To avoid tiling + artifacts, the tiles overlap and are blended together to form a smooth output. You may still see tile-sized + changes in the output, but they should be much less noticeable. + + Args: + x (`torch.FloatTensor`): Input batch of images/videos. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`: + If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain + `tuple` is returned. + """ + overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) + row_limit = self.tile_latent_min_size - blend_extent + + # Split video into tiles and encode them separately. + rows = [] + for i in range(0, x.shape[-2], overlap_size): + row = [] + for j in range(0, x.shape[-1], overlap_size): + tile = x[:, :, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size] + tile = self.encoder(tile) + tile = self.quant_conv(tile) + row.append(tile) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=-1)) + + moments = torch.cat(result_rows, dim=-2) + if return_moments: + return moments + + posterior = DiagonalGaussianDistribution(moments) + if not return_dict: + return (posterior,) + + return AutoencoderKLOutput(latent_dist=posterior) + + def spatial_tiled_decode( + self, z: torch.FloatTensor, return_dict: bool = True + ) -> Union[DecoderOutput, torch.FloatTensor]: + r""" + Decode a batch of images/videos using a tiled decoder. + + Args: + z (`torch.FloatTensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor) + row_limit = self.tile_sample_min_size - blend_extent + + # Split z into overlapping tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, z.shape[-2], overlap_size): + row = [] + for j in range(0, z.shape[-1], overlap_size): + tile = z[:, :, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size] + tile = self.post_quant_conv(tile) + decoded = self.decoder(tile) + row.append(decoded) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=-1)) + + dec = torch.cat(result_rows, dim=-2) + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + def temporal_tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput: + B, C, T, H, W = x.shape + overlap_size = int(self.tile_sample_min_tsize * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_latent_min_tsize * self.tile_overlap_factor) + t_limit = self.tile_latent_min_tsize - blend_extent + + # Split the video into tiles and encode them separately. + row = [] + for i in range(0, T, overlap_size): + tile = x[:, :, i : i + self.tile_sample_min_tsize + 1, :, :] + if self.use_spatial_tiling and ( + tile.shape[-1] > self.tile_sample_min_size or tile.shape[-2] > self.tile_sample_min_size + ): + tile = self.spatial_tiled_encode(tile, return_moments=True) + else: + tile = self.encoder(tile) + tile = self.quant_conv(tile) + if i > 0: + tile = tile[:, :, 1:, :, :] + row.append(tile) + result_row = [] + for i, tile in enumerate(row): + if i > 0: + tile = self.blend_t(row[i - 1], tile, blend_extent) + result_row.append(tile[:, :, :t_limit, :, :]) + else: + result_row.append(tile[:, :, : t_limit + 1, :, :]) + + moments = torch.cat(result_row, dim=2) + posterior = DiagonalGaussianDistribution(moments) + + if not return_dict: + return (posterior,) + + return AutoencoderKLOutput(latent_dist=posterior) + + def temporal_tiled_decode( + self, z: torch.FloatTensor, return_dict: bool = True + ) -> Union[DecoderOutput, torch.FloatTensor]: + # Split z into overlapping tiles and decode them separately. + + B, C, T, H, W = z.shape + overlap_size = int(self.tile_latent_min_tsize * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_sample_min_tsize * self.tile_overlap_factor) + t_limit = self.tile_sample_min_tsize - blend_extent + + row = [] + for i in range(0, T, overlap_size): + tile = z[:, :, i : i + self.tile_latent_min_tsize + 1, :, :] + if self.use_spatial_tiling and ( + tile.shape[-1] > self.tile_latent_min_size or tile.shape[-2] > self.tile_latent_min_size + ): + decoded = self.spatial_tiled_decode(tile, return_dict=True).sample + else: + tile = self.post_quant_conv(tile) + decoded = self.decoder(tile) + if i > 0: + decoded = decoded[:, :, 1:, :, :] + row.append(decoded) + result_row = [] + for i, tile in enumerate(row): + if i > 0: + tile = self.blend_t(row[i - 1], tile, blend_extent) + result_row.append(tile[:, :, :t_limit, :, :]) + else: + result_row.append(tile[:, :, : t_limit + 1, :, :]) + + dec = torch.cat(result_row, dim=2) + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + def forward( + self, + sample: torch.FloatTensor, + sample_posterior: bool = False, + return_dict: bool = True, + return_posterior: bool = False, + generator: Optional[torch.Generator] = None, + ) -> Union[DecoderOutput2, torch.FloatTensor]: + r""" + Args: + sample (`torch.FloatTensor`): Input sample. + sample_posterior (`bool`, *optional*, defaults to `False`): + Whether to sample from the posterior. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + """ + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z).sample + + if not return_dict: + if return_posterior: + return (dec, posterior) + else: + return (dec,) + if return_posterior: + return DecoderOutput2(sample=dec, posterior=posterior) + else: + return DecoderOutput2(sample=dec) diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py index 8c9c26f9cc03..e18aeef3a357 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py @@ -1,3 +1,17 @@ +# Copyright 2024 The Hunyuan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import collections.abc import itertools import math diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 2f8dfc010092..305bc2371fca 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -92,6 +92,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class AutoencoderKLHunyuanVideo(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class AutoencoderKLMochi(metaclass=DummyObject): _backends = ["torch"] From d885a6b020c0f09245ab94018a31e8576cb1939e Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 5 Dec 2024 21:25:13 +0100 Subject: [PATCH 03/58] copy pipeline --- src/diffusers/__init__.py | 2 + .../autoencoder_kl_hunyuan_video.py | 5 +- src/diffusers/pipelines/__init__.py | 2 + .../pipelines/hunyuan_video/__init__.py | 48 + .../hunyuan_video/pipeline_hunyuan_video.py | 970 ++++++++++++++++++ .../pipelines/hunyuan_video/text_encoder.py | 345 +++++++ .../dummy_torch_and_transformers_objects.py | 15 + 7 files changed, 1383 insertions(+), 4 deletions(-) create mode 100644 src/diffusers/pipelines/hunyuan_video/__init__.py create mode 100644 src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py create mode 100644 src/diffusers/pipelines/hunyuan_video/text_encoder.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 29284ba9e83f..40a99574ed8f 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -284,6 +284,7 @@ "HunyuanDiTControlNetPipeline", "HunyuanDiTPAGPipeline", "HunyuanDiTPipeline", + "HunyuanVideoPipeline", "I2VGenXLPipeline", "IFImg2ImgPipeline", "IFImg2ImgSuperResolutionPipeline", @@ -757,6 +758,7 @@ HunyuanDiTControlNetPipeline, HunyuanDiTPAGPipeline, HunyuanDiTPipeline, + HunyuanVideoPipeline, I2VGenXLPipeline, IFImg2ImgPipeline, IFImg2ImgSuperResolutionPipeline, diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py index 0cdd483de68f..ea3867a1d017 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py @@ -25,10 +25,7 @@ from ...utils import logging from ...utils.accelerate_utils import apply_forward_hook from ..activations import get_activation -from ..attention_processor import ( - Attention, - SpatialNorm, -) +from ..attention_processor import Attention, SpatialNorm from ..modeling_outputs import AutoencoderKLOutput from ..modeling_utils import ModelMixin from ..normalization import AdaGroupNorm, RMSNorm diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 6d3a20511696..3e049bbebf25 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -209,6 +209,7 @@ "IFSuperResolutionPipeline", ] _import_structure["hunyuandit"] = ["HunyuanDiTPipeline"] + _import_structure["hunyuan_video"] = ["HunyuanVideoPipeline"] _import_structure["kandinsky"] = [ "KandinskyCombinedPipeline", "KandinskyImg2ImgCombinedPipeline", @@ -539,6 +540,7 @@ FluxPriorReduxPipeline, ReduxImageEncoder, ) + from .hunyuan_video import HunyuanVideoPipeline from .hunyuandit import HunyuanDiTPipeline from .i2vgen_xl import I2VGenXLPipeline from .kandinsky import ( diff --git a/src/diffusers/pipelines/hunyuan_video/__init__.py b/src/diffusers/pipelines/hunyuan_video/__init__.py new file mode 100644 index 000000000000..978ed7f96110 --- /dev/null +++ b/src/diffusers/pipelines/hunyuan_video/__init__.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_hunyuan_video"] = ["HunyuanVideoPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_hunyuan_video import HunyuanVideoPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py new file mode 100644 index 000000000000..69e954d73cac --- /dev/null +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py @@ -0,0 +1,970 @@ +# Copyright 2024 The HunyuanVideo Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...configuration_utils import FrozenDict +from ...image_processor import VaeImageProcessor +from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + BaseOutput, + deprecate, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .text_encoder import TextEncoder + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """""" + + +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +PRECISION_TO_TYPE = { + "fp32": torch.float32, + "fp16": torch.float16, + "bf16": torch.bfloat16, +} + + +@dataclass +class HunyuanVideoPipelineOutput(BaseOutput): + videos: Union[torch.Tensor, np.ndarray] + + +class HunyuanVideoPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-video generation using HunyuanVideo. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`TextEncoder`]): + Frozen text-encoder. + text_encoder_2 ([`TextEncoder`]): + Frozen text-encoder_2. + transformer ([`HYVideoDiffusionTransformer`]): + A `HYVideoDiffusionTransformer` to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _optional_components = ["text_encoder_2"] + _exclude_from_cpu_offload = ["transformer"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKLHunyuanVideo, + text_encoder: TextEncoder, + transformer: HunyuanVideoTransformer3DModel, + scheduler: KarrasDiffusionSchedulers, + text_encoder_2: Optional[TextEncoder] = None, + progress_bar_config: Dict[str, Any] = None, + args=None, + ): + super().__init__() + + # ========================================================================================== + if progress_bar_config is None: + progress_bar_config = {} + if not hasattr(self, "_progress_bar_config"): + self._progress_bar_config = {} + self._progress_bar_config.update(progress_bar_config) + + self.args = args + # ========================================================================================== + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + transformer=transformer, + scheduler=scheduler, + text_encoder_2=text_encoder_2, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + def encode_prompt( + self, + prompt, + device, + num_videos_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_attention_mask: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + text_encoder: Optional[TextEncoder] = None, + data_type: Optional[str] = "image", + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_videos_per_prompt (`int`): + number of videos that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the video generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + attention_mask (`torch.Tensor`, *optional*): + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + negative_attention_mask (`torch.Tensor`, *optional*): + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + text_encoder (TextEncoder, *optional*): + data_type (`str`, *optional*): + """ + if text_encoder is None: + text_encoder = self.text_encoder + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(text_encoder.model, lora_scale) + else: + scale_lora_layers(text_encoder.model, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, text_encoder.tokenizer) + + text_inputs = text_encoder.text2tokens(prompt, data_type=data_type) + + if clip_skip is None: + prompt_outputs = text_encoder.encode(text_inputs, data_type=data_type, device=device) + prompt_embeds = prompt_outputs.hidden_state + else: + prompt_outputs = text_encoder.encode( + text_inputs, + output_hidden_states=True, + data_type=data_type, + device=device, + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_outputs.hidden_states_list[-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = text_encoder.model.text_model.final_layer_norm(prompt_embeds) + + attention_mask = prompt_outputs.attention_mask + if attention_mask is not None: + attention_mask = attention_mask.to(device) + bs_embed, seq_len = attention_mask.shape + attention_mask = attention_mask.repeat(1, num_videos_per_prompt) + attention_mask = attention_mask.view(bs_embed * num_videos_per_prompt, seq_len) + + if text_encoder is not None: + prompt_embeds_dtype = text_encoder.dtype + elif self.transformer is not None: + prompt_embeds_dtype = self.transformer.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + if prompt_embeds.ndim == 2: + bs_embed, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt) + prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, -1) + else: + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, text_encoder.tokenizer) + + # max_length = prompt_embeds.shape[1] + uncond_input = text_encoder.text2tokens(uncond_tokens, data_type=data_type) + + negative_prompt_outputs = text_encoder.encode(uncond_input, data_type=data_type, device=device) + negative_prompt_embeds = negative_prompt_outputs.hidden_state + + negative_attention_mask = negative_prompt_outputs.attention_mask + if negative_attention_mask is not None: + negative_attention_mask = negative_attention_mask.to(device) + _, seq_len = negative_attention_mask.shape + negative_attention_mask = negative_attention_mask.repeat(1, num_videos_per_prompt) + negative_attention_mask = negative_attention_mask.view(batch_size * num_videos_per_prompt, seq_len) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + if negative_prompt_embeds.ndim == 2: + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, -1) + else: + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + if text_encoder is not None: + if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(text_encoder.model, lora_scale) + + return ( + prompt_embeds, + negative_prompt_embeds, + attention_mask, + negative_attention_mask, + ) + + def decode_latents(self, latents, enable_tiling=True): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + if enable_tiling: + self.vae.enable_tiling() + image = self.vae.decode(latents, return_dict=False)[0] + else: + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + if image.ndim == 4: + image = image.cpu().permute(0, 2, 3, 1).float() + else: + image = image.cpu().float() + return image + + def prepare_extra_func_kwargs(self, func, kwargs): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + extra_step_kwargs = {} + + for k, v in kwargs.items(): + accepts = k in set(inspect.signature(func).parameters.keys()) + if accepts: + extra_step_kwargs[k] = v + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + video_length, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + vae_ver="88-4c-sd", + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if video_length is not None: + if "884" in vae_ver: + if video_length != 1 and (video_length - 1) % 4 != 0: + raise ValueError(f"`video_length` has to be 1 or a multiple of 4 but is {video_length}.") + elif "888" in vae_ver: + if video_length != 1 and (video_length - 1) % 8 != 0: + raise ValueError(f"`video_length` has to be 1 or a multiple of 8 but is {video_length}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + video_length, + dtype, + device, + generator, + latents=None, + ): + shape = ( + batch_size, + num_channels_latents, + video_length, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # Check existence to make it compatible with FlowMatchEulerDiscreteScheduler + if hasattr(self.scheduler, "init_noise_sigma"): + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, + w: torch.Tensor, + embedding_dim: int = 512, + dtype: torch.dtype = torch.float32, + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + # return self._guidance_scale > 1 and self.transformer.config.time_cond_proj_dim is None + return self._guidance_scale > 1 + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]], + height: int, + width: int, + video_length: int, + data_type: str = "video", + num_inference_steps: int = 50, + timesteps: List[int] = None, + sigmas: List[float] = None, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_videos_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_attention_mask: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[ + Union[ + Callable[[int, int, Dict], None], + PipelineCallback, + MultiPipelineCallbacks, + ] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None, + vae_ver: str = "88-4c-sd", + enable_tiling: bool = False, + n_tokens: Optional[int] = None, + embedded_guidance_scale: Optional[float] = None, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`): + The height in pixels of the generated image. + width (`int`): + The width in pixels of the generated image. + video_length (`int`): + The number of frames in the generated video. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`HunyuanVideoPipelineOutput`] instead of a plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when + using zero terminal SNR. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~HunyuanVideoPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`HunyuanVideoPipelineOutput`] is returned, otherwise a `tuple` is returned + where the first element is a list with the generated images and the second element is a list of `bool`s + indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 0. Default height and width to unet + # height = height or self.transformer.config.sample_size * self.vae_scale_factor + # width = width or self.transformer.config.sample_size * self.vae_scale_factor + # to deal with lora scaling and other possible forward hooks + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + video_length, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + vae_ver=vae_ver, + ) + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + + ( + prompt_embeds, + negative_prompt_embeds, + prompt_mask, + negative_prompt_mask, + ) = self.encode_prompt( + prompt, + device, + num_videos_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + attention_mask=attention_mask, + negative_prompt_embeds=negative_prompt_embeds, + negative_attention_mask=negative_attention_mask, + lora_scale=lora_scale, + clip_skip=self.clip_skip, + data_type=data_type, + ) + if self.text_encoder_2 is not None: + ( + prompt_embeds_2, + negative_prompt_embeds_2, + prompt_mask_2, + negative_prompt_mask_2, + ) = self.encode_prompt( + prompt, + device, + num_videos_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=None, + attention_mask=None, + negative_prompt_embeds=None, + negative_attention_mask=None, + lora_scale=lora_scale, + clip_skip=self.clip_skip, + text_encoder=self.text_encoder_2, + data_type=data_type, + ) + else: + prompt_embeds_2 = None + negative_prompt_embeds_2 = None + prompt_mask_2 = None + negative_prompt_mask_2 = None + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + if prompt_mask is not None: + prompt_mask = torch.cat([negative_prompt_mask, prompt_mask]) + if prompt_embeds_2 is not None: + prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2]) + if prompt_mask_2 is not None: + prompt_mask_2 = torch.cat([negative_prompt_mask_2, prompt_mask_2]) + + # 4. Prepare timesteps + extra_set_timesteps_kwargs = self.prepare_extra_func_kwargs( + self.scheduler.set_timesteps, {"n_tokens": n_tokens} + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + sigmas, + **extra_set_timesteps_kwargs, + ) + + if "884" in vae_ver: + video_length = (video_length - 1) // 4 + 1 + elif "888" in vae_ver: + video_length = (video_length - 1) // 8 + 1 + else: + video_length = video_length + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + video_length, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_func_kwargs( + self.scheduler.step, + {"generator": generator, "eta": eta}, + ) + + target_dtype = PRECISION_TO_TYPE[self.args.precision] + autocast_enabled = (target_dtype != torch.float32) and not self.args.disable_autocast + vae_dtype = PRECISION_TO_TYPE[self.args.vae_precision] + vae_autocast_enabled = (vae_dtype != torch.float32) and not self.args.disable_autocast + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + # if is_progress_bar: + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + t_expand = t.repeat(latent_model_input.shape[0]) + guidance_expand = ( + torch.tensor( + [embedded_guidance_scale] * latent_model_input.shape[0], + dtype=torch.float32, + device=device, + ).to(target_dtype) + * 1000.0 + if embedded_guidance_scale is not None + else None + ) + + # predict the noise residual + with torch.autocast(device_type="cuda", dtype=target_dtype, enabled=autocast_enabled): + noise_pred = self.transformer( # For an input image (129, 192, 336) (1, 256, 256) + latent_model_input, # [2, 16, 33, 24, 42] + t_expand, # [2] + text_states=prompt_embeds, # [2, 256, 4096] + text_mask=prompt_mask, # [2, 256] + text_states_2=prompt_embeds_2, # [2, 768] + freqs_cos=freqs_cis[0], # [seqlen, head_dim] + freqs_sin=freqs_cis[1], # [seqlen, head_dim] + guidance=guidance_expand, + return_dict=True, + )["x"] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg( + noise_pred, + noise_pred_text, + guidance_rescale=self.guidance_rescale, + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if not output_type == "latent": + expand_temporal_dim = False + if len(latents.shape) == 4: + latents = latents.unsqueeze(2) + expand_temporal_dim = True + elif len(latents.shape) == 5: + pass + else: + raise ValueError( + f"Only support latents with shape (b, c, h, w) or (b, c, f, h, w), but got {latents.shape}." + ) + + if hasattr(self.vae.config, "shift_factor") and self.vae.config.shift_factor: + latents = latents / self.vae.config.scaling_factor + self.vae.config.shift_factor + else: + latents = latents / self.vae.config.scaling_factor + + with torch.autocast(device_type="cuda", dtype=vae_dtype, enabled=vae_autocast_enabled): + if enable_tiling: + self.vae.enable_tiling() + image = self.vae.decode(latents, return_dict=False, generator=generator)[0] + else: + image = self.vae.decode(latents, return_dict=False, generator=generator)[0] + + if expand_temporal_dim or image.shape[2] == 1: + image = image.squeeze(2) + + else: + image = latents + + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + image = image.cpu().float() + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return image + + return HunyuanVideoPipelineOutput(videos=image) diff --git a/src/diffusers/pipelines/hunyuan_video/text_encoder.py b/src/diffusers/pipelines/hunyuan_video/text_encoder.py new file mode 100644 index 000000000000..d7281799d8c6 --- /dev/null +++ b/src/diffusers/pipelines/hunyuan_video/text_encoder.py @@ -0,0 +1,345 @@ +import os +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +import torch.nn as nn +from transformers import AutoModel, AutoTokenizer, CLIPTextModel, CLIPTokenizer +from transformers.utils import ModelOutput + + +PRECISION_TO_TYPE = { + "fp32": torch.float32, + "fp16": torch.float16, + "bf16": torch.bfloat16, +} + +MODEL_BASE = os.getenv("MODEL_BASE", "./ckpts") + +# Text Encoder +TEXT_ENCODER_PATH = { + "clipL": f"{MODEL_BASE}/text_encoder_2", + "llm": f"{MODEL_BASE}/text_encoder", +} + +# Tokenizer +TOKENIZER_PATH = { + "clipL": f"{MODEL_BASE}/text_encoder_2", + "llm": f"{MODEL_BASE}/text_encoder", +} + + +def use_default(value, default): + return value if value is not None else default + + +def load_text_encoder( + text_encoder_type, + text_encoder_precision=None, + text_encoder_path=None, + logger=None, + device=None, +): + if text_encoder_path is None: + text_encoder_path = TEXT_ENCODER_PATH[text_encoder_type] + if logger is not None: + logger.info(f"Loading text encoder model ({text_encoder_type}) from: {text_encoder_path}") + + if text_encoder_type == "clipL": + text_encoder = CLIPTextModel.from_pretrained(text_encoder_path) + text_encoder.final_layer_norm = text_encoder.text_model.final_layer_norm + elif text_encoder_type == "llm": + text_encoder = AutoModel.from_pretrained(text_encoder_path, low_cpu_mem_usage=True) + text_encoder.final_layer_norm = text_encoder.norm + else: + raise ValueError(f"Unsupported text encoder type: {text_encoder_type}") + # from_pretrained will ensure that the model is in eval mode. + + if text_encoder_precision is not None: + text_encoder = text_encoder.to(dtype=PRECISION_TO_TYPE[text_encoder_precision]) + + text_encoder.requires_grad_(False) + + if logger is not None: + logger.info(f"Text encoder to dtype: {text_encoder.dtype}") + + if device is not None: + text_encoder = text_encoder.to(device) + + return text_encoder, text_encoder_path + + +def load_tokenizer(tokenizer_type, tokenizer_path=None, padding_side="right", logger=None): + if tokenizer_path is None: + tokenizer_path = TOKENIZER_PATH[tokenizer_type] + if logger is not None: + logger.info(f"Loading tokenizer ({tokenizer_type}) from: {tokenizer_path}") + + if tokenizer_type == "clipL": + tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path, max_length=77) + elif tokenizer_type == "llm": + tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, padding_side=padding_side) + else: + raise ValueError(f"Unsupported tokenizer type: {tokenizer_type}") + + return tokenizer, tokenizer_path + + +@dataclass +class TextEncoderModelOutput(ModelOutput): + """ + Base class for model's outputs that also contains a pooling of the last hidden states. + + Args: + hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + hidden_states_list (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of + the model at the output of each layer plus the optional initial embedding outputs. + text_outputs (`list`, *optional*, returned when `return_texts=True` is passed): + List of decoded texts. + """ + + hidden_state: torch.FloatTensor = None + attention_mask: Optional[torch.LongTensor] = None + hidden_states_list: Optional[Tuple[torch.FloatTensor, ...]] = None + text_outputs: Optional[list] = None + + +class TextEncoder(nn.Module): + def __init__( + self, + text_encoder_type: str, + max_length: int, + text_encoder_precision: Optional[str] = None, + text_encoder_path: Optional[str] = None, + tokenizer_type: Optional[str] = None, + tokenizer_path: Optional[str] = None, + output_key: Optional[str] = None, + use_attention_mask: bool = True, + input_max_length: Optional[int] = None, + prompt_template: Optional[dict] = None, + prompt_template_video: Optional[dict] = None, + hidden_state_skip_layer: Optional[int] = None, + apply_final_norm: bool = False, + reproduce: bool = False, + logger=None, + device=None, + ): + super().__init__() + self.text_encoder_type = text_encoder_type + self.max_length = max_length + self.precision = text_encoder_precision + self.model_path = text_encoder_path + self.tokenizer_type = tokenizer_type if tokenizer_type is not None else text_encoder_type + self.tokenizer_path = tokenizer_path if tokenizer_path is not None else text_encoder_path + self.use_attention_mask = use_attention_mask + if prompt_template_video is not None: + assert use_attention_mask is True, "Attention mask is True required when training videos." + self.input_max_length = input_max_length if input_max_length is not None else max_length + self.prompt_template = prompt_template + self.prompt_template_video = prompt_template_video + self.hidden_state_skip_layer = hidden_state_skip_layer + self.apply_final_norm = apply_final_norm + self.reproduce = reproduce + self.logger = logger + + self.use_template = self.prompt_template is not None + if self.use_template: + assert ( + isinstance(self.prompt_template, dict) and "template" in self.prompt_template + ), f"`prompt_template` must be a dictionary with a key 'template', got {self.prompt_template}" + assert "{}" in str(self.prompt_template["template"]), ( + "`prompt_template['template']` must contain a placeholder `{}` for the input text, " + f"got {self.prompt_template['template']}" + ) + + self.use_video_template = self.prompt_template_video is not None + if self.use_video_template: + if self.prompt_template_video is not None: + assert ( + isinstance(self.prompt_template_video, dict) and "template" in self.prompt_template_video + ), f"`prompt_template_video` must be a dictionary with a key 'template', got {self.prompt_template_video}" + assert "{}" in str(self.prompt_template_video["template"]), ( + "`prompt_template_video['template']` must contain a placeholder `{}` for the input text, " + f"got {self.prompt_template_video['template']}" + ) + + if "t5" in text_encoder_type: + self.output_key = output_key or "last_hidden_state" + elif "clip" in text_encoder_type: + self.output_key = output_key or "pooler_output" + elif "llm" in text_encoder_type or "glm" in text_encoder_type: + self.output_key = output_key or "last_hidden_state" + else: + raise ValueError(f"Unsupported text encoder type: {text_encoder_type}") + + self.model, self.model_path = load_text_encoder( + text_encoder_type=self.text_encoder_type, + text_encoder_precision=self.precision, + text_encoder_path=self.model_path, + logger=self.logger, + device=device, + ) + self.dtype = self.model.dtype + self.device = self.model.device + + self.tokenizer, self.tokenizer_path = load_tokenizer( + tokenizer_type=self.tokenizer_type, + tokenizer_path=self.tokenizer_path, + padding_side="right", + logger=self.logger, + ) + + def __repr__(self): + return f"{self.text_encoder_type} ({self.precision} - {self.model_path})" + + @staticmethod + def apply_text_to_template(text, template, prevent_empty_text=True): + """ + Apply text to template. + + Args: + text (str): Input text. + template (str or list): Template string or list of chat conversation. + prevent_empty_text (bool): If Ture, we will prevent the user text from being empty + by adding a space. Defaults to True. + """ + if isinstance(template, str): + # Will send string to tokenizer. Used for llm + return template.format(text) + else: + raise TypeError(f"Unsupported template type: {type(template)}") + + def text2tokens(self, text, data_type="image"): + """ + Tokenize the input text. + + Args: + text (str or list): Input text. + """ + tokenize_input_type = "str" + if self.use_template: + if data_type == "image": + prompt_template = self.prompt_template["template"] + elif data_type == "video": + prompt_template = self.prompt_template_video["template"] + else: + raise ValueError(f"Unsupported data type: {data_type}") + if isinstance(text, (list, tuple)): + text = [self.apply_text_to_template(one_text, prompt_template) for one_text in text] + if isinstance(text[0], list): + tokenize_input_type = "list" + elif isinstance(text, str): + text = self.apply_text_to_template(text, prompt_template) + if isinstance(text, list): + tokenize_input_type = "list" + else: + raise TypeError(f"Unsupported text type: {type(text)}") + + kwargs = { + "truncation": True, + "max_length": self.max_length, + "padding": "max_length", + "return_tensors": "pt", + } + if tokenize_input_type == "str": + return self.tokenizer( + text, + return_length=False, + return_overflowing_tokens=False, + return_attention_mask=True, + **kwargs, + ) + elif tokenize_input_type == "list": + return self.tokenizer.apply_chat_template( + text, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + **kwargs, + ) + else: + raise ValueError(f"Unsupported tokenize_input_type: {tokenize_input_type}") + + def encode( + self, + batch_encoding, + use_attention_mask=None, + output_hidden_states=False, + do_sample=None, + hidden_state_skip_layer=None, + return_texts=False, + data_type="image", + device=None, + ): + """ + Args: + batch_encoding (dict): Batch encoding from tokenizer. + use_attention_mask (bool): Whether to use attention mask. If None, use self.use_attention_mask. + Defaults to None. + output_hidden_states (bool): Whether to output hidden states. If False, return the value of + self.output_key. If True, return the entire output. If set self.hidden_state_skip_layer, + output_hidden_states will be set True. Defaults to False. + do_sample (bool): Whether to sample from the model. Used for Decoder-Only LLMs. Defaults to None. + When self.produce is False, do_sample is set to True by default. + hidden_state_skip_layer (int): Number of hidden states to hidden_state_skip_layer. 0 means the last layer. + If None, self.output_key will be used. Defaults to None. + return_texts (bool): Whether to return the decoded texts. Defaults to False. + """ + device = self.model.device if device is None else device + use_attention_mask = use_default(use_attention_mask, self.use_attention_mask) + hidden_state_skip_layer = use_default(hidden_state_skip_layer, self.hidden_state_skip_layer) + do_sample = use_default(do_sample, not self.reproduce) + attention_mask = batch_encoding["attention_mask"].to(device) if use_attention_mask else None + outputs = self.model( + input_ids=batch_encoding["input_ids"].to(device), + attention_mask=attention_mask, + output_hidden_states=output_hidden_states or hidden_state_skip_layer is not None, + ) + if hidden_state_skip_layer is not None: + last_hidden_state = outputs.hidden_states[-(hidden_state_skip_layer + 1)] + # Real last hidden state already has layer norm applied. So here we only apply it + # for intermediate layers. + if hidden_state_skip_layer > 0 and self.apply_final_norm: + last_hidden_state = self.model.final_layer_norm(last_hidden_state) + else: + last_hidden_state = outputs[self.output_key] + + # Remove hidden states of instruction tokens, only keep prompt tokens. + if self.use_template: + if data_type == "image": + crop_start = self.prompt_template.get("crop_start", -1) + elif data_type == "video": + crop_start = self.prompt_template_video.get("crop_start", -1) + else: + raise ValueError(f"Unsupported data type: {data_type}") + if crop_start > 0: + last_hidden_state = last_hidden_state[:, crop_start:] + attention_mask = attention_mask[:, crop_start:] if use_attention_mask else None + + if output_hidden_states: + return TextEncoderModelOutput(last_hidden_state, attention_mask, outputs.hidden_states) + return TextEncoderModelOutput(last_hidden_state, attention_mask) + + def forward( + self, + text, + use_attention_mask=None, + output_hidden_states=False, + do_sample=False, + hidden_state_skip_layer=None, + return_texts=False, + ): + batch_encoding = self.text2tokens(text) + return self.encode( + batch_encoding, + use_attention_mask=use_attention_mask, + output_hidden_states=output_hidden_states, + do_sample=do_sample, + hidden_state_skip_layer=hidden_state_skip_layer, + return_texts=return_texts, + ) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 4fc7cd6aefff..aa505e5c9bab 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -572,6 +572,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class HunyuanVideoPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class I2VGenXLPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] From 332c771fb1123793f5a6f6fdec2c6bb574ecdb34 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 5 Dec 2024 21:25:34 +0100 Subject: [PATCH 04/58] make fix-copies --- .../pipelines/hunyuan_video/pipeline_hunyuan_video.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py index 69e954d73cac..4564acd65da5 100644 --- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py @@ -526,10 +526,7 @@ def prepare_latents( # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding def get_guidance_scale_embedding( - self, - w: torch.Tensor, - embedding_dim: int = 512, - dtype: torch.dtype = torch.float32, + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 ) -> torch.Tensor: """ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 From 77097473300a04e5aea00d62cb86f876e18406e5 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 10 Dec 2024 17:01:19 +0100 Subject: [PATCH 05/58] refactor; make original code work with diffusers; test latents for comparison generated with this commit --- .../transformers/transformer_hunyuan_video.py | 243 +++++------------- .../hunyuan_video/pipeline_hunyuan_video.py | 197 +++----------- .../pipelines/hunyuan_video/text_encoder.py | 44 +--- 3 files changed, 108 insertions(+), 376 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py index e18aeef3a357..4dfd4816f968 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py @@ -320,8 +320,6 @@ def __init__( dim: int, elementwise_affine=True, eps: float = 1e-6, - device=None, - dtype=None, ): """ Initialize the RMSNorm normalization layer. @@ -335,11 +333,10 @@ def __init__( weight (nn.Parameter): Learnable scaling parameter. """ - factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.eps = eps if elementwise_affine: - self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs)) + self.weight = nn.Parameter(torch.ones(dim)) def _norm(self, x): """ @@ -437,16 +434,10 @@ def __init__( hidden_size: int, factor: int, act_layer: Callable, - dtype=None, - device=None, ): - factory_kwargs = {"dtype": dtype, "device": device} super().__init__() self.act = act_layer() - self.linear = nn.Linear(hidden_size, factor * hidden_size, bias=True, **factory_kwargs) - # Zero-initialize the modulation - nn.init.zeros_(self.linear.weight) - nn.init.zeros_(self.linear.bias) + self.linear = nn.Linear(hidden_size, factor * hidden_size, bias=True) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.linear(self.act(x)) @@ -465,10 +456,7 @@ def __init__( bias=True, drop=0.0, use_conv=False, - device=None, - dtype=None, ): - factory_kwargs = {"device": device, "dtype": dtype} super().__init__() out_features = out_features or in_channels hidden_channels = hidden_channels or in_channels @@ -476,11 +464,11 @@ def __init__( drop_probs = to_2tuple(drop) linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear - self.fc1 = linear_layer(in_channels, hidden_channels, bias=bias[0], **factory_kwargs) + self.fc1 = linear_layer(in_channels, hidden_channels, bias=bias[0]) self.act = act_layer() self.drop1 = nn.Dropout(drop_probs[0]) - self.norm = norm_layer(hidden_channels, **factory_kwargs) if norm_layer is not None else nn.Identity() - self.fc2 = linear_layer(hidden_channels, out_features, bias=bias[1], **factory_kwargs) + self.norm = norm_layer(hidden_channels) if norm_layer is not None else nn.Identity() + self.fc2 = linear_layer(hidden_channels, out_features, bias=bias[1]) self.drop2 = nn.Dropout(drop_probs[1]) def forward(self, x): @@ -497,12 +485,11 @@ def forward(self, x): class MLPEmbedder(nn.Module): """copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py""" - def __init__(self, in_dim: int, hidden_dim: int, device=None, dtype=None): - factory_kwargs = {"device": device, "dtype": dtype} + def __init__(self, in_dim: int, hidden_dim: int): super().__init__() - self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True, **factory_kwargs) + self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True) self.silu = nn.SiLU() - self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True, **factory_kwargs) + self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.out_layer(self.silu(self.in_layer(x))) @@ -511,14 +498,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class FinalLayer(nn.Module): """The final layer of DiT.""" - def __init__(self, hidden_size, patch_size, out_channels, act_layer, device=None, dtype=None): - factory_kwargs = {"device": device, "dtype": dtype} + def __init__(self, hidden_size, patch_size, out_channels, act_layer): super().__init__() # Just use LayerNorm for the final layer - self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) if isinstance(patch_size, int): - self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, **factory_kwargs) + self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) else: self.linear = nn.Linear( hidden_size, @@ -531,11 +517,8 @@ def __init__(self, hidden_size, patch_size, out_channels, act_layer, device=None # Here we don't distinguish between the modulate types. Just use the simple one. self.adaLN_modulation = nn.Sequential( act_layer(), - nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs), + nn.Linear(hidden_size, 2 * hidden_size, bias=True), ) - # Zero-initialize the modulation - nn.init.zeros_(self.adaLN_modulation[1].weight) - nn.init.zeros_(self.adaLN_modulation[1].bias) def forward(self, x, c): shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) @@ -566,17 +549,14 @@ def __init__( norm_layer=None, flatten=True, bias=True, - dtype=None, - device=None, ): - factory_kwargs = {"dtype": dtype, "device": device} super().__init__() patch_size = to_2tuple(patch_size) self.patch_size = patch_size self.flatten = flatten self.proj = nn.Conv3d( - in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, **factory_kwargs + in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias ) nn.init.xavier_uniform_(self.proj.weight.view(self.proj.weight.size(0), -1)) if bias: @@ -599,12 +579,11 @@ class TextProjection(nn.Module): Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py """ - def __init__(self, in_channels, hidden_size, act_layer, dtype=None, device=None): - factory_kwargs = {"dtype": dtype, "device": device} + def __init__(self, in_channels, hidden_size, act_layer): super().__init__() - self.linear_1 = nn.Linear(in_features=in_channels, out_features=hidden_size, bias=True, **factory_kwargs) + self.linear_1 = nn.Linear(in_features=in_channels, out_features=hidden_size, bias=True) self.act_1 = act_layer() - self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True, **factory_kwargs) + self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True) def forward(self, caption): hidden_states = self.linear_1(caption) @@ -650,10 +629,7 @@ def __init__( frequency_embedding_size=256, max_period=10000, out_size=None, - dtype=None, - device=None, ): - factory_kwargs = {"dtype": dtype, "device": device} super().__init__() self.frequency_embedding_size = frequency_embedding_size self.max_period = max_period @@ -661,9 +637,9 @@ def __init__( out_size = hidden_size self.mlp = nn.Sequential( - nn.Linear(frequency_embedding_size, hidden_size, bias=True, **factory_kwargs), + nn.Linear(frequency_embedding_size, hidden_size, bias=True), act_layer(), - nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs), + nn.Linear(hidden_size, out_size, bias=True), ) nn.init.normal_(self.mlp[0].weight, std=0.02) nn.init.normal_(self.mlp[2].weight, std=0.02) @@ -685,43 +661,36 @@ def __init__( qk_norm: bool = False, qk_norm_type: str = "layer", qkv_bias: bool = True, - dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, ): - factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.heads_num = heads_num head_dim = hidden_size // heads_num mlp_hidden_dim = int(hidden_size * mlp_width_ratio) - self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs) - self.self_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs) + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6) + self.self_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias) qk_norm_layer = get_norm_layer(qk_norm_type) self.self_attn_q_norm = ( - qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() ) self.self_attn_k_norm = ( - qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() ) - self.self_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs) + self.self_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias) - self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs) + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6) act_layer = get_activation_layer(act_type) self.mlp = MLP( in_channels=hidden_size, hidden_channels=mlp_hidden_dim, act_layer=act_layer, drop=mlp_drop_rate, - **factory_kwargs, ) self.adaLN_modulation = nn.Sequential( act_layer(), - nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs), + nn.Linear(hidden_size, 2 * hidden_size, bias=True), ) - # Zero-initialize the modulation - nn.init.zeros_(self.adaLN_modulation[1].weight) - nn.init.zeros_(self.adaLN_modulation[1].bias) def forward( self, @@ -761,10 +730,7 @@ def __init__( qk_norm: bool = False, qk_norm_type: str = "layer", qkv_bias: bool = True, - dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, ): - factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.blocks = nn.ModuleList( [ @@ -777,7 +743,6 @@ def __init__( qk_norm=qk_norm, qk_norm_type=qk_norm_type, qkv_bias=qkv_bias, - **factory_kwargs, ) for _ in range(depth) ] @@ -793,7 +758,7 @@ def forward( if mask is not None: batch_size = mask.shape[0] seq_len = mask.shape[1] - mask = mask.to(x.device) + mask = mask.to(x.device).bool() # batch_size x 1 x seq_len x seq_len self_attn_mask_1 = mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1) # batch_size x 1 x seq_len x seq_len @@ -826,21 +791,18 @@ def __init__( qk_norm_type: str = "layer", qkv_bias: bool = True, attn_mode: str = "torch", - dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, ): - factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.attn_mode = attn_mode assert self.attn_mode == "torch", "Only support 'torch' mode for token refiner." - self.input_embedder = nn.Linear(in_channels, hidden_size, bias=True, **factory_kwargs) + self.input_embedder = nn.Linear(in_channels, hidden_size, bias=True) act_layer = get_activation_layer(act_type) # Build timestep embedding layer - self.t_embedder = TimestepEmbedder(hidden_size, act_layer, **factory_kwargs) + self.t_embedder = TimestepEmbedder(hidden_size, act_layer) # Build context embedding layer - self.c_embedder = TextProjection(in_channels, hidden_size, act_layer, **factory_kwargs) + self.c_embedder = TextProjection(in_channels, hidden_size, act_layer) self.individual_token_refiner = IndividualTokenRefiner( hidden_size=hidden_size, @@ -852,7 +814,6 @@ def __init__( qk_norm=qk_norm, qk_norm_type=qk_norm_type, qkv_bias=qkv_bias, - **factory_kwargs, ) def forward( @@ -861,6 +822,7 @@ def forward( t: torch.LongTensor, mask: Optional[torch.LongTensor] = None, ): + original_dtype = x.dtype timestep_aware_representations = self.t_embedder(t) if mask is None: @@ -868,11 +830,12 @@ def forward( else: mask_float = mask.float().unsqueeze(-1) # [b, s1, 1] context_aware_representations = (x * mask_float).sum(dim=1) / mask_float.sum(dim=1) + context_aware_representations = context_aware_representations.to(original_dtype) + context_aware_representations = self.c_embedder(context_aware_representations) c = timestep_aware_representations + context_aware_representations x = self.input_embedder(x) - x = self.individual_token_refiner(x, c, mask) return x @@ -894,10 +857,7 @@ def __init__( qk_norm: bool = True, qk_norm_type: str = "rms", qkv_bias: bool = False, - dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, ): - factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.deterministic = False @@ -908,62 +868,52 @@ def __init__( self.img_mod = ModulateDiT( hidden_size, factor=6, - act_layer=get_activation_layer("silu"), - **factory_kwargs, + act_layer=get_activation_layer("silu") ) - self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) + self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) - self.img_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs) + self.img_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias) qk_norm_layer = get_norm_layer(qk_norm_type) self.img_attn_q_norm = ( - qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() ) self.img_attn_k_norm = ( - qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() ) - self.img_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs) + self.img_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias) - self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) + self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.img_mlp = MLP( hidden_size, mlp_hidden_dim, act_layer=get_activation_layer(mlp_act_type), bias=True, - **factory_kwargs, ) self.txt_mod = ModulateDiT( hidden_size, factor=6, act_layer=get_activation_layer("silu"), - **factory_kwargs, ) - self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) + self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) - self.txt_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs) + self.txt_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias) self.txt_attn_q_norm = ( - qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() ) self.txt_attn_k_norm = ( - qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() ) - self.txt_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs) + self.txt_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias) - self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) + self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.txt_mlp = MLP( hidden_size, mlp_hidden_dim, act_layer=get_activation_layer(mlp_act_type), bias=True, - **factory_kwargs, ) - def enable_deterministic(self): - self.deterministic = True - - def disable_deterministic(self): - self.deterministic = False - def forward( self, img: torch.Tensor, @@ -1071,10 +1021,7 @@ def __init__( qk_norm: bool = True, qk_norm_type: str = "rms", qk_scale: float = None, - dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, ): - factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.deterministic = False @@ -1086,34 +1033,27 @@ def __init__( self.scale = qk_scale or head_dim**-0.5 # qkv and mlp_in - self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + mlp_hidden_dim, **factory_kwargs) + self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + mlp_hidden_dim) # proj and mlp_out - self.linear2 = nn.Linear(hidden_size + mlp_hidden_dim, hidden_size, **factory_kwargs) + self.linear2 = nn.Linear(hidden_size + mlp_hidden_dim, hidden_size) qk_norm_layer = get_norm_layer(qk_norm_type) self.q_norm = ( - qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() ) self.k_norm = ( - qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() ) - self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) + self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.mlp_act = get_activation_layer(mlp_act_type)() self.modulation = ModulateDiT( hidden_size, factor=3, act_layer=get_activation_layer("silu"), - **factory_kwargs, ) - def enable_deterministic(self): - self.deterministic = True - - def disable_deterministic(self): - self.deterministic = False - def forward( self, x: torch.Tensor, @@ -1218,7 +1158,6 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin): @register_to_config def __init__( self, - args: Any, patch_size: list = [1, 2, 2], in_channels: int = 4, # Should be VAE.config.latent_channels. out_channels: int = None, @@ -1233,12 +1172,9 @@ def __init__( qk_norm: bool = True, qk_norm_type: str = "rms", guidance_embed: bool = False, # For modulation. - text_projection: str = "single_refiner", - use_attention_mask: bool = True, - dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, + text_states_dim: int = 4096, + text_states_dim_2: int = 768, ): - factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.patch_size = patch_size @@ -1248,14 +1184,6 @@ def __init__( self.guidance_embed = guidance_embed self.rope_dim_list = rope_dim_list - # Text projection. Default to linear projection. - # Alternative: TokenRefiner. See more details (LI-DiT): http://arxiv.org/abs/2406.11831 - self.use_attention_mask = use_attention_mask - self.text_projection = text_projection - - self.text_states_dim = args.text_states_dim - self.text_states_dim_2 = args.text_states_dim_2 - if hidden_size % heads_num != 0: raise ValueError(f"Hidden size {hidden_size} must be divisible by heads_num {heads_num}") pe_dim = hidden_size // heads_num @@ -1265,30 +1193,20 @@ def __init__( self.heads_num = heads_num # image projection - self.img_in = PatchEmbed(self.patch_size, self.in_channels, self.hidden_size, **factory_kwargs) + self.img_in = PatchEmbed(self.patch_size, self.in_channels, self.hidden_size) # text projection - if self.text_projection == "linear": - self.txt_in = TextProjection( - self.text_states_dim, - self.hidden_size, - get_activation_layer("silu"), - **factory_kwargs, - ) - elif self.text_projection == "single_refiner": - self.txt_in = SingleTokenRefiner(self.text_states_dim, hidden_size, heads_num, depth=2, **factory_kwargs) - else: - raise NotImplementedError(f"Unsupported text_projection: {self.text_projection}") + self.txt_in = SingleTokenRefiner(text_states_dim, hidden_size, heads_num, depth=2) # time modulation - self.time_in = TimestepEmbedder(self.hidden_size, get_activation_layer("silu"), **factory_kwargs) + self.time_in = TimestepEmbedder(self.hidden_size, get_activation_layer("silu")) # text modulation - self.vector_in = MLPEmbedder(self.text_states_dim_2, self.hidden_size, **factory_kwargs) + self.vector_in = MLPEmbedder(text_states_dim_2, self.hidden_size) # guidance modulation self.guidance_in = ( - TimestepEmbedder(self.hidden_size, get_activation_layer("silu"), **factory_kwargs) + TimestepEmbedder(self.hidden_size, get_activation_layer("silu")) if guidance_embed else None ) @@ -1304,7 +1222,6 @@ def __init__( qk_norm=qk_norm, qk_norm_type=qk_norm_type, qkv_bias=qkv_bias, - **factory_kwargs, ) for _ in range(mm_double_blocks_depth) ] @@ -1320,7 +1237,6 @@ def __init__( mlp_act_type=mlp_act_type, qk_norm=qk_norm, qk_norm_type=qk_norm_type, - **factory_kwargs, ) for _ in range(mm_single_blocks_depth) ] @@ -1331,21 +1247,8 @@ def __init__( self.patch_size, self.out_channels, get_activation_layer("silu"), - **factory_kwargs, ) - def enable_deterministic(self): - for block in self.double_blocks: - block.enable_deterministic() - for block in self.single_blocks: - block.enable_deterministic() - - def disable_deterministic(self): - for block in self.double_blocks: - block.disable_deterministic() - for block in self.single_blocks: - block.disable_deterministic() - def forward( self, x: torch.Tensor, @@ -1384,12 +1287,7 @@ def forward( # Embed image and text. img = self.img_in(img) - if self.text_projection == "linear": - txt = self.txt_in(txt) - elif self.text_projection == "single_refiner": - txt = self.txt_in(txt, t, text_mask if self.use_attention_mask else None) - else: - raise NotImplementedError(f"Unsupported text_projection: {self.text_projection}") + txt = self.txt_in(txt, t, text_mask) txt_seq_len = txt.shape[1] img_seq_len = img.shape[1] @@ -1457,28 +1355,3 @@ def unpatchify(self, x, t, h, w): imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw)) return imgs - - def params_count(self): - counts = { - "double": sum( - [ - sum(p.numel() for p in block.img_attn_qkv.parameters()) - + sum(p.numel() for p in block.img_attn_proj.parameters()) - + sum(p.numel() for p in block.img_mlp.parameters()) - + sum(p.numel() for p in block.txt_attn_qkv.parameters()) - + sum(p.numel() for p in block.txt_attn_proj.parameters()) - + sum(p.numel() for p in block.txt_mlp.parameters()) - for block in self.double_blocks - ] - ), - "single": sum( - [ - sum(p.numel() for p in block.linear1.parameters()) - + sum(p.numel() for p in block.linear2.parameters()) - for block in self.single_blocks - ] - ), - "total": sum(p.numel() for p in self.parameters()), - } - counts["attn+mlp"] = counts["double"] + counts["single"] - return counts diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py index 4564acd65da5..3b9fc8ae3002 100644 --- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py @@ -20,20 +20,14 @@ import torch from ...callbacks import MultiPipelineCallbacks, PipelineCallback -from ...configuration_utils import FrozenDict from ...image_processor import VaeImageProcessor -from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel -from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( - USE_PEFT_BACKEND, BaseOutput, deprecate, logging, replace_example_docstring, - scale_lora_layers, - unscale_lora_layers, ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline @@ -162,48 +156,9 @@ def __init__( transformer: HunyuanVideoTransformer3DModel, scheduler: KarrasDiffusionSchedulers, text_encoder_2: Optional[TextEncoder] = None, - progress_bar_config: Dict[str, Any] = None, - args=None, ): super().__init__() - # ========================================================================================== - if progress_bar_config is None: - progress_bar_config = {} - if not hasattr(self, "_progress_bar_config"): - self._progress_bar_config = {} - self._progress_bar_config.update(progress_bar_config) - - self.args = args - # ========================================================================================== - - if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: - deprecation_message = ( - f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" - f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " - "to update the config accordingly as leaving `steps_offset` might led to incorrect results" - " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," - " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" - " file" - ) - deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(scheduler.config) - new_config["steps_offset"] = 1 - scheduler._internal_dict = FrozenDict(new_config) - - if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: - deprecation_message = ( - f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." - " `clip_sample` should be set to False in the configuration file. Please make sure to update the" - " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" - " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" - " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" - ) - deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(scheduler.config) - new_config["clip_sample"] = False - scheduler._internal_dict = FrozenDict(new_config) - self.register_modules( vae=vae, text_encoder=text_encoder, @@ -225,7 +180,6 @@ def encode_prompt( attention_mask: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, negative_attention_mask: Optional[torch.Tensor] = None, - lora_scale: Optional[float] = None, clip_skip: Optional[int] = None, text_encoder: Optional[TextEncoder] = None, data_type: Optional[str] = "image", @@ -266,17 +220,6 @@ def encode_prompt( if text_encoder is None: text_encoder = self.text_encoder - # set lora scale so that monkey patched LoRA - # function of text encoder can correctly access it - if lora_scale is not None and isinstance(self, LoraLoaderMixin): - self._lora_scale = lora_scale - - # dynamically adjust the LoRA scale - if not USE_PEFT_BACKEND: - adjust_lora_scale_text_encoder(text_encoder.model, lora_scale) - else: - scale_lora_layers(text_encoder.model, lora_scale) - if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): @@ -285,14 +228,13 @@ def encode_prompt( batch_size = prompt_embeds.shape[0] if prompt_embeds is None: - # textual inversion: process multi-vector tokens if necessary - if isinstance(self, TextualInversionLoaderMixin): - prompt = self.maybe_convert_prompt(prompt, text_encoder.tokenizer) - text_inputs = text_encoder.text2tokens(prompt, data_type=data_type) if clip_skip is None: prompt_outputs = text_encoder.encode(text_inputs, data_type=data_type, device=device) + # TODO(aryan): Don't know why it doesn't work without this + torch.cuda.synchronize() + prompt_embeds = prompt_outputs.hidden_state else: prompt_outputs = text_encoder.encode( @@ -359,10 +301,6 @@ def encode_prompt( else: uncond_tokens = negative_prompt - # textual inversion: process multi-vector tokens if necessary - if isinstance(self, TextualInversionLoaderMixin): - uncond_tokens = self.maybe_convert_prompt(uncond_tokens, text_encoder.tokenizer) - # max_length = prompt_embeds.shape[1] uncond_input = text_encoder.text2tokens(uncond_tokens, data_type=data_type) @@ -389,11 +327,6 @@ def encode_prompt( negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) - if text_encoder is not None: - if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(text_encoder.model, lora_scale) - return ( prompt_embeds, negative_prompt_embeds, @@ -419,19 +352,6 @@ def decode_latents(self, latents, enable_tiling=True): image = image.cpu().float() return image - def prepare_extra_func_kwargs(self, func, kwargs): - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. - # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - extra_step_kwargs = {} - - for k, v in kwargs.items(): - accepts = k in set(inspect.signature(func).parameters.keys()) - if accepts: - extra_step_kwargs[k] = v - return extra_step_kwargs - def check_inputs( self, prompt, @@ -442,19 +362,10 @@ def check_inputs( prompt_embeds=None, negative_prompt_embeds=None, callback_on_step_end_tensor_inputs=None, - vae_ver="88-4c-sd", ): if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - if video_length is not None: - if "884" in vae_ver: - if video_length != 1 and (video_length - 1) % 4 != 0: - raise ValueError(f"`video_length` has to be 1 or a multiple of 4 but is {video_length}.") - elif "888" in vae_ver: - if video_length != 1 and (video_length - 1) % 8 != 0: - raise ValueError(f"`video_length` has to be 1 or a multiple of 8 but is {video_length}.") - if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): @@ -592,9 +503,9 @@ def interrupt(self): def __call__( self, prompt: Union[str, List[str]], - height: int, - width: int, - video_length: int, + height: int = 720, + width: int = 1280, + video_length: int = 129, data_type: str = "video", num_inference_steps: int = 50, timesteps: List[int] = None, @@ -611,7 +522,6 @@ def __call__( negative_attention_mask: Optional[torch.Tensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, guidance_rescale: float = 0.0, clip_skip: Optional[int] = None, callback_on_step_end: Optional[ @@ -623,7 +533,6 @@ def __call__( ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None, - vae_ver: str = "88-4c-sd", enable_tiling: bool = False, n_tokens: Optional[int] = None, embedded_guidance_scale: Optional[float] = None, @@ -675,14 +584,10 @@ def __call__( negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. - output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`HunyuanVideoPipelineOutput`] instead of a plain tuple. - cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in - [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). guidance_rescale (`float`, *optional*, defaults to 0.0): Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when @@ -712,11 +617,6 @@ def __call__( if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs - # 0. Default height and width to unet - # height = height or self.transformer.config.sample_size * self.vae_scale_factor - # width = width or self.transformer.config.sample_size * self.vae_scale_factor - # to deal with lora scaling and other possible forward hooks - # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, @@ -727,13 +627,11 @@ def __call__( prompt_embeds, negative_prompt_embeds, callback_on_step_end_tensor_inputs, - vae_ver=vae_ver, ) self._guidance_scale = guidance_scale self._guidance_rescale = guidance_rescale self._clip_skip = clip_skip - self._cross_attention_kwargs = cross_attention_kwargs self._interrupt = False # 2. Define call parameters @@ -744,13 +642,10 @@ def __call__( else: batch_size = prompt_embeds.shape[0] - device = self._execution_device + # TODO(aryan): No idea why it won't run without this + device = torch.device(self._execution_device) # 3. Encode input prompt - lora_scale = ( - self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None - ) - ( prompt_embeds, negative_prompt_embeds, @@ -766,7 +661,6 @@ def __call__( attention_mask=attention_mask, negative_prompt_embeds=negative_prompt_embeds, negative_attention_mask=negative_attention_mask, - lora_scale=lora_scale, clip_skip=self.clip_skip, data_type=data_type, ) @@ -786,7 +680,6 @@ def __call__( attention_mask=None, negative_prompt_embeds=None, negative_attention_mask=None, - lora_scale=lora_scale, clip_skip=self.clip_skip, text_encoder=self.text_encoder_2, data_type=data_type, @@ -810,26 +703,18 @@ def __call__( prompt_mask_2 = torch.cat([negative_prompt_mask_2, prompt_mask_2]) # 4. Prepare timesteps - extra_set_timesteps_kwargs = self.prepare_extra_func_kwargs( - self.scheduler.set_timesteps, {"n_tokens": n_tokens} - ) timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, device, timesteps, sigmas, - **extra_set_timesteps_kwargs, ) - if "884" in vae_ver: - video_length = (video_length - 1) // 4 + 1 - elif "888" in vae_ver: - video_length = (video_length - 1) // 8 + 1 - else: - video_length = video_length + video_length = (video_length - 1) // 4 + 1 # 5. Prepare latent variables + target_dtype = torch.bfloat16 # Note(aryan): This has been hardcoded for now from the original repo num_channels_latents = self.transformer.config.in_channels latents = self.prepare_latents( batch_size * num_videos_per_prompt, @@ -837,22 +722,17 @@ def __call__( height, width, video_length, - prompt_embeds.dtype, + target_dtype, device, generator, latents, ) - # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - extra_step_kwargs = self.prepare_extra_func_kwargs( - self.scheduler.step, - {"generator": generator, "eta": eta}, - ) - - target_dtype = PRECISION_TO_TYPE[self.args.precision] - autocast_enabled = (target_dtype != torch.float32) and not self.args.disable_autocast - vae_dtype = PRECISION_TO_TYPE[self.args.vae_precision] - vae_autocast_enabled = (vae_dtype != torch.float32) and not self.args.disable_autocast + prompt_embeds = prompt_embeds.to(target_dtype) + prompt_mask = prompt_mask.to(target_dtype) + if prompt_embeds_2 is not None: + prompt_embeds_2 = prompt_embeds_2.to(target_dtype) + vae_dtype = torch.float16 # Note(aryan): This has been hardcoded for now from the original repo # 7. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order @@ -866,9 +746,10 @@ def __call__( # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - t_expand = t.repeat(latent_model_input.shape[0]) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype) + guidance_expand = ( torch.tensor( [embedded_guidance_scale] * latent_model_input.shape[0], @@ -880,19 +761,17 @@ def __call__( else None ) - # predict the noise residual - with torch.autocast(device_type="cuda", dtype=target_dtype, enabled=autocast_enabled): - noise_pred = self.transformer( # For an input image (129, 192, 336) (1, 256, 256) - latent_model_input, # [2, 16, 33, 24, 42] - t_expand, # [2] - text_states=prompt_embeds, # [2, 256, 4096] - text_mask=prompt_mask, # [2, 256] - text_states_2=prompt_embeds_2, # [2, 768] - freqs_cos=freqs_cis[0], # [seqlen, head_dim] - freqs_sin=freqs_cis[1], # [seqlen, head_dim] - guidance=guidance_expand, - return_dict=True, - )["x"] + noise_pred = self.transformer( # For an input image (129, 192, 336) (1, 256, 256) + latent_model_input, # [2, 16, 33, 24, 42] + timestep, # [2] + text_states=prompt_embeds, # [2, 256, 4096] + text_mask=prompt_mask, # [2, 256] + text_states_2=prompt_embeds_2, # [2, 768] + freqs_cos=freqs_cis[0], # [seqlen, head_dim] + freqs_sin=freqs_cis[1], # [seqlen, head_dim] + guidance=guidance_expand, + return_dict=True, + )["x"] # perform guidance if self.do_classifier_free_guidance: @@ -908,7 +787,9 @@ def __call__( ) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + torch.save(latents, f"latents_{i}.pt") if callback_on_step_end is not None: callback_kwargs = {} @@ -924,6 +805,7 @@ def __call__( if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() + latents = latents.to(vae_dtype) if not output_type == "latent": expand_temporal_dim = False if len(latents.shape) == 4: @@ -941,12 +823,11 @@ def __call__( else: latents = latents / self.vae.config.scaling_factor - with torch.autocast(device_type="cuda", dtype=vae_dtype, enabled=vae_autocast_enabled): - if enable_tiling: - self.vae.enable_tiling() - image = self.vae.decode(latents, return_dict=False, generator=generator)[0] - else: - image = self.vae.decode(latents, return_dict=False, generator=generator)[0] + if enable_tiling: + self.vae.enable_tiling() + image = self.vae.decode(latents, return_dict=False, generator=generator)[0] + else: + image = self.vae.decode(latents, return_dict=False, generator=generator)[0] if expand_temporal_dim or image.shape[2] == 1: image = image.squeeze(2) diff --git a/src/diffusers/pipelines/hunyuan_video/text_encoder.py b/src/diffusers/pipelines/hunyuan_video/text_encoder.py index d7281799d8c6..4a259cafa5f7 100644 --- a/src/diffusers/pipelines/hunyuan_video/text_encoder.py +++ b/src/diffusers/pipelines/hunyuan_video/text_encoder.py @@ -14,20 +14,6 @@ "bf16": torch.bfloat16, } -MODEL_BASE = os.getenv("MODEL_BASE", "./ckpts") - -# Text Encoder -TEXT_ENCODER_PATH = { - "clipL": f"{MODEL_BASE}/text_encoder_2", - "llm": f"{MODEL_BASE}/text_encoder", -} - -# Tokenizer -TOKENIZER_PATH = { - "clipL": f"{MODEL_BASE}/text_encoder_2", - "llm": f"{MODEL_BASE}/text_encoder", -} - def use_default(value, default): return value if value is not None else default @@ -37,13 +23,10 @@ def load_text_encoder( text_encoder_type, text_encoder_precision=None, text_encoder_path=None, - logger=None, device=None, ): if text_encoder_path is None: - text_encoder_path = TEXT_ENCODER_PATH[text_encoder_type] - if logger is not None: - logger.info(f"Loading text encoder model ({text_encoder_type}) from: {text_encoder_path}") + raise ValueError("text_encoder_path must be provided.") if text_encoder_type == "clipL": text_encoder = CLIPTextModel.from_pretrained(text_encoder_path) @@ -60,20 +43,15 @@ def load_text_encoder( text_encoder.requires_grad_(False) - if logger is not None: - logger.info(f"Text encoder to dtype: {text_encoder.dtype}") - if device is not None: text_encoder = text_encoder.to(device) return text_encoder, text_encoder_path -def load_tokenizer(tokenizer_type, tokenizer_path=None, padding_side="right", logger=None): +def load_tokenizer(tokenizer_type, tokenizer_path=None, padding_side="right"): if tokenizer_path is None: - tokenizer_path = TOKENIZER_PATH[tokenizer_type] - if logger is not None: - logger.info(f"Loading tokenizer ({tokenizer_type}) from: {tokenizer_path}") + raise ValueError("tokenizer_path must be provided.") if tokenizer_type == "clipL": tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path, max_length=77) @@ -126,8 +104,6 @@ def __init__( hidden_state_skip_layer: Optional[int] = None, apply_final_norm: bool = False, reproduce: bool = False, - logger=None, - device=None, ): super().__init__() self.text_encoder_type = text_encoder_type @@ -145,7 +121,6 @@ def __init__( self.hidden_state_skip_layer = hidden_state_skip_layer self.apply_final_norm = apply_final_norm self.reproduce = reproduce - self.logger = logger self.use_template = self.prompt_template is not None if self.use_template: @@ -181,17 +156,15 @@ def __init__( text_encoder_type=self.text_encoder_type, text_encoder_precision=self.precision, text_encoder_path=self.model_path, - logger=self.logger, - device=device, + device="cuda", ) self.dtype = self.model.dtype - self.device = self.model.device + self.device = "cuda" self.tokenizer, self.tokenizer_path = load_tokenizer( tokenizer_type=self.tokenizer_type, tokenizer_path=self.tokenizer_path, padding_side="right", - logger=self.logger, ) def __repr__(self): @@ -295,8 +268,13 @@ def encode( hidden_state_skip_layer = use_default(hidden_state_skip_layer, self.hidden_state_skip_layer) do_sample = use_default(do_sample, not self.reproduce) attention_mask = batch_encoding["attention_mask"].to(device) if use_attention_mask else None + input_ids = batch_encoding["input_ids"].to(device) + + # No idea why it doesn't work without this + torch.cuda.synchronize() + outputs = self.model( - input_ids=batch_encoding["input_ids"].to(device), + input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=output_hidden_states or hidden_state_skip_layer is not None, ) From fbe5031e8bb1aaafef7dc6b13c2a4733676582dc Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 10 Dec 2024 21:23:18 +0100 Subject: [PATCH 06/58] move rope into pipeline; remove flash attention; refactor --- .../transformers/transformer_hunyuan_video.py | 172 +------ .../hunyuan_video/pipeline_hunyuan_video.py | 479 ++++++++++-------- 2 files changed, 265 insertions(+), 386 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py index 4dfd4816f968..d4c4025d1985 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py @@ -49,32 +49,6 @@ } -def get_cu_seqlens(text_mask, img_len): - """Calculate cu_seqlens_q, cu_seqlens_kv using text_mask and img_len - - Args: - text_mask (torch.Tensor): the mask of text - img_len (int): the length of image - - Returns: - torch.Tensor: the calculated cu_seqlens for flash attention - """ - batch_size = text_mask.shape[0] - text_len = text_mask.sum(dim=1) - max_len = text_mask.shape[1] + img_len - - cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device="cuda") - - for i in range(batch_size): - s = text_len[i] + img_len - s1 = i * max_len + s - s2 = (i + 1) * max_len - cu_seqlens[2 * i + 1] = s1 - cu_seqlens[2 * i + 2] = s2 - - return cu_seqlens - - def attention( q, k, @@ -83,34 +57,8 @@ def attention( drop_rate=0, attn_mask=None, causal=False, - cu_seqlens_q=None, - cu_seqlens_kv=None, - max_seqlen_q=None, - max_seqlen_kv=None, batch_size=1, ): - """ - Perform QKV self attention. - - Args: - q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads. - k (torch.Tensor): Key tensor with shape [b, s1, a, d] - v (torch.Tensor): Value tensor with shape [b, s1, a, d] - mode (str): Attention mode. Choose from 'self_flash', 'cross_flash', 'torch', and 'vanilla'. - drop_rate (float): Dropout rate in attention map. (default: 0) - attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla). - (default: None) - causal (bool): Whether to use causal attention. (default: False) - cu_seqlens_q (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch, - used to index into q. - cu_seqlens_kv (torch.Tensor): - dtype torch.int32. The cumulative sequence lengths of the sequences in the batch, used to index into kv. - max_seqlen_q (int): The maximum sequence length in the batch of q. - max_seqlen_kv (int): The maximum sequence length in the batch of k and v. - - Returns: - torch.Tensor: Output tensor after self attention with shape [b, s, ad] - """ pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode] q = pre_attn_layout(q) k = pre_attn_layout(k) @@ -120,45 +68,6 @@ def attention( if attn_mask is not None and attn_mask.dtype != torch.bool: attn_mask = attn_mask.to(q.dtype) x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal) - elif mode == "flash": - x = flash_attn_varlen_func( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_kv, - max_seqlen_q, - max_seqlen_kv, - ) - # x with shape [(bxs), a, d] - x = x.view(batch_size, max_seqlen_q, x.shape[-2], x.shape[-1]) # reshape x to [b, s, a, d] - elif mode == "vanilla": - scale_factor = 1 / math.sqrt(q.size(-1)) - - b, a, s, _ = q.shape - s1 = k.size(2) - attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device) - if causal: - # Only applied to self attention - assert attn_mask is None, "Causal mask and attn_mask cannot be used together" - temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(diagonal=0) - attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) - attn_bias.to(q.dtype) - - if attn_mask is not None: - if attn_mask.dtype == torch.bool: - attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) - else: - attn_bias += attn_mask - - # TODO: Maybe force q and k to be float32 to avoid numerical overflow - attn = (q @ k.transpose(-2, -1)) * scale_factor - attn += attn_bias - attn = attn.softmax(dim=-1) - attn = torch.dropout(attn, p=drop_rate, train=True) - x = attn @ v - else: - raise NotImplementedError(f"Unsupported attention mode: {mode}") x = post_attn_layout(x) b, s, a, d = x.shape @@ -407,25 +316,6 @@ def modulate(x, shift=None, scale=None): return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) -def apply_gate(x, gate=None, tanh=False): - """AI is creating summary for apply_gate - - Args: - x (torch.Tensor): input tensor. - gate (torch.Tensor, optional): gate tensor. Defaults to None. - tanh (bool, optional): whether to use tanh function. Defaults to False. - - Returns: - torch.Tensor: the output tensor after apply gate. - """ - if gate is None: - return x - if tanh: - return x * gate.unsqueeze(1).tanh() - else: - return x * gate.unsqueeze(1) - - class ModulateDiT(nn.Module): """Modulation layer for DiT.""" @@ -641,8 +531,6 @@ def __init__( act_layer(), nn.Linear(hidden_size, out_size, bias=True), ) - nn.init.normal_(self.mlp[0].weight, std=0.02) - nn.init.normal_(self.mlp[2].weight, std=0.02) def forward(self, t): t_freq = timestep_embedding(t, self.frequency_embedding_size, self.max_period).type(self.mlp[0].weight.dtype) @@ -710,10 +598,10 @@ def forward( # Self-Attention attn = attention(q, k, v, mode="torch", attn_mask=attn_mask) - x = x + apply_gate(self.self_attn_proj(attn), gate_msa) + x = x + self.self_attn_proj(attn) * gate_msa.unsqueeze(1) # FFN Layer - x = x + apply_gate(self.mlp(self.norm2(x)), gate_mlp) + x = x + self.mlp(self.norm2(x)) * gate_mlp.unsqueeze(1) return x @@ -919,10 +807,6 @@ def forward( img: torch.Tensor, txt: torch.Tensor, vec: torch.Tensor, - cu_seqlens_q: Optional[torch.Tensor] = None, - cu_seqlens_kv: Optional[torch.Tensor] = None, - max_seqlen_q: Optional[int] = None, - max_seqlen_kv: Optional[int] = None, freqs_cis: tuple = None, ) -> Tuple[torch.Tensor, torch.Tensor]: ( @@ -972,36 +856,23 @@ def forward( q = torch.cat((img_q, txt_q), dim=1) k = torch.cat((img_k, txt_k), dim=1) v = torch.cat((img_v, txt_v), dim=1) - assert ( - cu_seqlens_q.shape[0] == 2 * img.shape[0] + 1 - ), f"cu_seqlens_q.shape:{cu_seqlens_q.shape}, img.shape[0]:{img.shape[0]}" attn = attention( q, k, v, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_kv=cu_seqlens_kv, - max_seqlen_q=max_seqlen_q, - max_seqlen_kv=max_seqlen_kv, batch_size=img_k.shape[0], ) img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1] :] # Calculate the img bloks. - img = img + apply_gate(self.img_attn_proj(img_attn), gate=img_mod1_gate) - img = img + apply_gate( - self.img_mlp(modulate(self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale)), - gate=img_mod2_gate, - ) + img = img + self.img_attn_proj(img_attn) * img_mod1_gate.unsqueeze(1) + img = img + self.img_mlp(modulate(self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale)) * img_mod2_gate.unsqueeze(1) # Calculate the txt bloks. - txt = txt + apply_gate(self.txt_attn_proj(txt_attn), gate=txt_mod1_gate) - txt = txt + apply_gate( - self.txt_mlp(modulate(self.txt_norm2(txt), shift=txt_mod2_shift, scale=txt_mod2_scale)), - gate=txt_mod2_gate, - ) - + txt = txt + self.txt_attn_proj(txt_attn) * txt_mod1_gate.unsqueeze(1) + txt = txt + self.txt_mlp(modulate(self.txt_norm2(txt), shift=txt_mod2_shift, scale=txt_mod2_scale)) * txt_mod2_gate.unsqueeze(1) + return img, txt @@ -1059,10 +930,6 @@ def forward( x: torch.Tensor, vec: torch.Tensor, txt_len: int, - cu_seqlens_q: Optional[torch.Tensor] = None, - cu_seqlens_kv: Optional[torch.Tensor] = None, - max_seqlen_q: Optional[int] = None, - max_seqlen_kv: Optional[int] = None, freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None, ) -> torch.Tensor: mod_shift, mod_scale, mod_gate = self.modulation(vec).chunk(3, dim=-1) @@ -1087,24 +954,17 @@ def forward( q = torch.cat((img_q, txt_q), dim=1) k = torch.cat((img_k, txt_k), dim=1) - # Compute attention. - assert ( - cu_seqlens_q.shape[0] == 2 * x.shape[0] + 1 - ), f"cu_seqlens_q.shape:{cu_seqlens_q.shape}, x.shape[0]:{x.shape[0]}" attn = attention( q, k, v, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_kv=cu_seqlens_kv, - max_seqlen_q=max_seqlen_q, - max_seqlen_kv=max_seqlen_kv, batch_size=x.shape[0], ) # Compute activation in mlp stream, cat again and run second linear layer. output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) - return x + apply_gate(output, gate=mod_gate) + output = x + output * mod_gate.unsqueeze(1) + return output class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin): @@ -1292,12 +1152,6 @@ def forward( txt_seq_len = txt.shape[1] img_seq_len = img.shape[1] - # Compute cu_squlens and max_seqlen for flash attention - cu_seqlens_q = get_cu_seqlens(text_mask, img_seq_len) - cu_seqlens_kv = cu_seqlens_q - max_seqlen_q = img_seq_len + txt_seq_len - max_seqlen_kv = max_seqlen_q - freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None # --------------------- Pass through DiT blocks ------------------------ for _, block in enumerate(self.double_blocks): @@ -1305,10 +1159,6 @@ def forward( img, txt, vec, - cu_seqlens_q, - cu_seqlens_kv, - max_seqlen_q, - max_seqlen_kv, freqs_cis, ] @@ -1322,10 +1172,6 @@ def forward( x, vec, txt_seq_len, - cu_seqlens_q, - cu_seqlens_kv, - max_seqlen_q, - max_seqlen_kv, (freqs_cos, freqs_sin), ] diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py index 3b9fc8ae3002..dbd2393a55ac 100644 --- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py @@ -39,20 +39,6 @@ EXAMPLE_DOC_STRING = """""" -def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): - """ - Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and - Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 - """ - std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) - std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) - # rescale the results from guidance (fixes overexposure) - noise_pred_rescaled = noise_cfg * (std_text / std_cfg) - # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images - noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg - return noise_cfg - - def retrieve_timesteps( scheduler, num_inference_steps: Optional[int] = None, @@ -174,12 +160,8 @@ def encode_prompt( prompt, device, num_videos_per_prompt, - do_classifier_free_guidance, - negative_prompt=None, prompt_embeds: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, - negative_prompt_embeds: Optional[torch.Tensor] = None, - negative_attention_mask: Optional[torch.Tensor] = None, clip_skip: Optional[int] = None, text_encoder: Optional[TextEncoder] = None, data_type: Optional[str] = "image", @@ -194,23 +176,10 @@ def encode_prompt( torch device num_videos_per_prompt (`int`): number of videos that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the video generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is - less than `1`). prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. attention_mask (`torch.Tensor`, *optional*): - negative_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - negative_attention_mask (`torch.Tensor`, *optional*): - lora_scale (`float`, *optional*): - A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. @@ -280,90 +249,22 @@ def encode_prompt( prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1) - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance and negative_prompt_embeds is None: - uncond_tokens: List[str] - if negative_prompt is None: - uncond_tokens = [""] * batch_size - elif prompt is not None and type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt] - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - else: - uncond_tokens = negative_prompt - - # max_length = prompt_embeds.shape[1] - uncond_input = text_encoder.text2tokens(uncond_tokens, data_type=data_type) - - negative_prompt_outputs = text_encoder.encode(uncond_input, data_type=data_type, device=device) - negative_prompt_embeds = negative_prompt_outputs.hidden_state - - negative_attention_mask = negative_prompt_outputs.attention_mask - if negative_attention_mask is not None: - negative_attention_mask = negative_attention_mask.to(device) - _, seq_len = negative_attention_mask.shape - negative_attention_mask = negative_attention_mask.repeat(1, num_videos_per_prompt) - negative_attention_mask = negative_attention_mask.view(batch_size * num_videos_per_prompt, seq_len) - - if do_classifier_free_guidance: - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = negative_prompt_embeds.shape[1] - - negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) - - if negative_prompt_embeds.ndim == 2: - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, -1) - else: - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) - return ( prompt_embeds, - negative_prompt_embeds, attention_mask, - negative_attention_mask, ) - def decode_latents(self, latents, enable_tiling=True): - deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" - deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) - - latents = 1 / self.vae.config.scaling_factor * latents - if enable_tiling: - self.vae.enable_tiling() - image = self.vae.decode(latents, return_dict=False)[0] - else: - image = self.vae.decode(latents, return_dict=False)[0] - image = (image / 2 + 0.5).clamp(0, 1) - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 - if image.ndim == 4: - image = image.cpu().permute(0, 2, 3, 1).float() - else: - image = image.cpu().float() - return image - def check_inputs( self, prompt, + prompt_2, height, width, video_length, - negative_prompt=None, prompt_embeds=None, - negative_prompt_embeds=None, callback_on_step_end_tensor_inputs=None, ): - if height % 8 != 0 or width % 8 != 0: + if height % 16 != 0 or width % 16 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") if callback_on_step_end_tensor_inputs is not None and not all( @@ -378,26 +279,19 @@ def check_inputs( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if negative_prompt is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) - - if prompt_embeds is not None and negative_prompt_embeds is not None: - if prompt_embeds.shape != negative_prompt_embeds.shape: - raise ValueError( - "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" - f" {negative_prompt_embeds.shape}." - ) + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") def prepare_latents( self, @@ -465,30 +359,230 @@ def get_guidance_scale_embedding( emb = torch.nn.functional.pad(emb, (0, 1)) assert emb.shape == (w.shape[0], embedding_dim) return emb + + def get_rotary_pos_embed(self, video_length, height, width): + def _to_tuple(x, dim=2): + if isinstance(x, int): + return (x,) * dim + elif len(x) == dim: + return x + else: + raise ValueError(f"Expected length {dim} or int, but got {x}") + + + def get_meshgrid_nd(start, *args, dim=2): + """ + Get n-D meshgrid with start, stop and num. + + Args: + start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop, + step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num + should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in + n-tuples. + *args: See above. + dim (int): Dimension of the meshgrid. Defaults to 2. + + Returns: + grid (np.ndarray): [dim, ...] + """ + if len(args) == 0: + # start is grid_size + num = _to_tuple(start, dim=dim) + start = (0,) * dim + stop = num + elif len(args) == 1: + # start is start, args[0] is stop, step is 1 + start = _to_tuple(start, dim=dim) + stop = _to_tuple(args[0], dim=dim) + num = [stop[i] - start[i] for i in range(dim)] + elif len(args) == 2: + # start is start, args[0] is stop, args[1] is num + start = _to_tuple(start, dim=dim) # Left-Top eg: 12,0 + stop = _to_tuple(args[0], dim=dim) # Right-Bottom eg: 20,32 + num = _to_tuple(args[1], dim=dim) # Target Size eg: 32,124 + else: + raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}") + + # PyTorch implement of np.linspace(start[i], stop[i], num[i], endpoint=False) + axis_grid = [] + for i in range(dim): + a, b, n = start[i], stop[i], num[i] + g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n] + axis_grid.append(g) + grid = torch.meshgrid(*axis_grid, indexing="ij") # dim x [W, H, D] + grid = torch.stack(grid, dim=0) # [dim, W, H, D] + + return grid + + def get_1d_rotary_pos_embed( + dim: int, + pos: Union[torch.FloatTensor, int], + theta: float = 10000.0, + use_real: bool = False, + theta_rescale_factor: float = 1.0, + interpolation_factor: float = 1.0, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + Precompute the frequency tensor for complex exponential (cis) with given dimensions. + (Note: `cis` means `cos + i * sin`, where i is the imaginary unit.) + + This function calculates a frequency tensor with complex exponential using the given dimension 'dim' + and the end index 'end'. The 'theta' parameter scales the frequencies. + The returned tensor contains complex values in complex64 data type. + + Args: + dim (int): Dimension of the frequency tensor. + pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar + theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0. + use_real (bool, optional): If True, return real part and imaginary part separately. + Otherwise, return complex numbers. + theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0. + + Returns: + freqs_cis: Precomputed frequency tensor with complex exponential. [S, D/2] + freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. [S, D] + """ + if isinstance(pos, int): + pos = torch.arange(pos).float() + + # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning + # has some connection to NTK literature + if theta_rescale_factor != 1.0: + theta *= theta_rescale_factor ** (dim / (dim - 2)) + + freqs = 1.0 / ( + theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim) + ) # [D/2] + # assert interpolation_factor == 1.0, f"interpolation_factor: {interpolation_factor}" + freqs = torch.outer(pos * interpolation_factor, freqs) # [S, D/2] + if use_real: + freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D] + freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D] + return freqs_cos, freqs_sin + else: + freqs_cis = torch.polar( + torch.ones_like(freqs), freqs + ) # complex64 # [S, D/2] + return freqs_cis + + + def get_nd_rotary_pos_embed( + rope_dim_list, + start, + *args, + theta=10000.0, + use_real=False, + theta_rescale_factor: Union[float, List[float]] = 1.0, + interpolation_factor: Union[float, List[float]] = 1.0, + ): + """ + This is a n-d version of precompute_freqs_cis, which is a RoPE for tokens with n-d structure. + + Args: + rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal to n. + sum(rope_dim_list) should equal to head_dim of attention layer. + start (int | tuple of int | list of int): If len(args) == 0, start is num; If len(args) == 1, start is start, + args[0] is stop, step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. + *args: See above. + theta (float): Scaling factor for frequency computation. Defaults to 10000.0. + use_real (bool): If True, return real part and imaginary part separately. Otherwise, return complex numbers. + Some libraries such as TensorRT does not support complex64 data type. So it is useful to provide a real + part and an imaginary part separately. + theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0. + + Returns: + pos_embed (torch.Tensor): [HW, D/2] + """ + + grid = get_meshgrid_nd( + start, *args, dim=len(rope_dim_list) + ) # [3, W, H, D] / [2, W, H] + + if isinstance(theta_rescale_factor, int) or isinstance(theta_rescale_factor, float): + theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list) + elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1: + theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list) + assert len(theta_rescale_factor) == len( + rope_dim_list + ), "len(theta_rescale_factor) should equal to len(rope_dim_list)" + + if isinstance(interpolation_factor, int) or isinstance(interpolation_factor, float): + interpolation_factor = [interpolation_factor] * len(rope_dim_list) + elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1: + interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list) + assert len(interpolation_factor) == len( + rope_dim_list + ), "len(interpolation_factor) should equal to len(rope_dim_list)" + + # use 1/ndim of dimensions to encode grid_axis + embs = [] + for i in range(len(rope_dim_list)): + emb = get_1d_rotary_pos_embed( + rope_dim_list[i], + grid[i].reshape(-1), + theta, + use_real=use_real, + theta_rescale_factor=theta_rescale_factor[i], + interpolation_factor=interpolation_factor[i], + ) # 2 x [WHD, rope_dim_list[i]] + embs.append(emb) + + if use_real: + cos = torch.cat([emb[0] for emb in embs], dim=1) # (WHD, D/2) + sin = torch.cat([emb[1] for emb in embs], dim=1) # (WHD, D/2) + return cos, sin + else: + emb = torch.cat(embs, dim=1) # (WHD, D/2) + return emb + + + target_ndim = 3 + ndim = 5 - 2 + # 884 + latents_size = [(video_length - 1) // 4 + 1, height // 8, width // 8] + + assert all( + s % self.transformer.config.patch_size[idx] == 0 + for idx, s in enumerate(latents_size) + ), ( + f"Latent size(last {ndim} dimensions) should be divisible by patch size ({self.transformer.config.patch_size}), " + f"but got {latents_size}." + ) + rope_sizes = [ + s // self.transformer.config.patch_size[idx] for idx, s in enumerate(latents_size) + ] + + if len(rope_sizes) != target_ndim: + rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes # time axis + head_dim = self.transformer.config.hidden_size // self.transformer.config.heads_num + rope_dim_list = self.transformer.config.rope_dim_list + if rope_dim_list is None: + rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)] + assert ( + sum(rope_dim_list) == head_dim + ), "sum(rope_dim_list) should equal to head_dim of attention layer" + + freqs_cos, freqs_sin = get_nd_rotary_pos_embed( + rope_dim_list, + rope_sizes, + theta=256, + use_real=True, + theta_rescale_factor=1, + ) + + return freqs_cos, freqs_sin @property def guidance_scale(self): return self._guidance_scale - @property - def guidance_rescale(self): - return self._guidance_rescale - @property def clip_skip(self): return self._clip_skip - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - @property - def do_classifier_free_guidance(self): - # return self._guidance_scale > 1 and self.transformer.config.time_cond_proj_dim is None - return self._guidance_scale > 1 - @property - def cross_attention_kwargs(self): - return self._cross_attention_kwargs + def guidance_scale(self): + return self._guidance_scale @property def num_timesteps(self): @@ -502,47 +596,45 @@ def interrupt(self): @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, - prompt: Union[str, List[str]], + prompt: Union[str, List[str]] = None, + prompt_2: Union[str, List[str]] = None, height: int = 720, width: int = 1280, video_length: int = 129, data_type: str = "video", num_inference_steps: int = 50, - timesteps: List[int] = None, sigmas: List[float] = None, - guidance_scale: float = 7.5, - negative_prompt: Optional[Union[str, List[str]]] = None, + guidance_scale: float = 6.0, num_videos_per_prompt: Optional[int] = 1, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, prompt_embeds: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - negative_prompt_embeds: Optional[torch.Tensor] = None, - negative_attention_mask: Optional[torch.Tensor] = None, + prompt_embeds_2: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, - guidance_rescale: float = 0.0, clip_skip: Optional[int] = None, callback_on_step_end: Optional[ Union[ Callable[[int, int, Dict], None], PipelineCallback, - MultiPipelineCallbacks, + MultiPipelineCallbacks ] ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], - freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None, enable_tiling: bool = False, - n_tokens: Optional[int] = None, - embedded_guidance_scale: Optional[float] = None, ): r""" The call function to the pipeline for generation. Args: - prompt (`str` or `List[str]`): - The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead. height (`int`): The height in pixels of the generated image. width (`int`): @@ -560,12 +652,13 @@ def __call__( Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. - guidance_scale (`float`, *optional*, defaults to 7.5): - A higher guidance scale value encourages the model to generate images closely linked to the text - `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide what to not include in image generation. If not defined, you need to - pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + guidance_scale (`float`, defaults to `6.0`): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). + Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate + images that are closely linked to the text `prompt`, usually at the expense of lower image quality. + Note that the only available HunyuanVideo model is CFG-distilled, which means that traditional guidance + between unconditional and conditional latent is not applied. num_videos_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): @@ -581,17 +674,10 @@ def __call__( prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, text embeddings are generated from the `prompt` input argument. - negative_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If - not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`HunyuanVideoPipelineOutput`] instead of a plain tuple. - guidance_rescale (`float`, *optional*, defaults to 0.0): - Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are - Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when - using zero terminal SNR. clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. @@ -620,17 +706,15 @@ def __call__( # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, + prompt_2, height, width, video_length, - negative_prompt, prompt_embeds, - negative_prompt_embeds, callback_on_step_end_tensor_inputs, ) self._guidance_scale = guidance_scale - self._guidance_rescale = guidance_rescale self._clip_skip = clip_skip self._interrupt = False @@ -648,71 +732,43 @@ def __call__( # 3. Encode input prompt ( prompt_embeds, - negative_prompt_embeds, prompt_mask, - negative_prompt_mask, ) = self.encode_prompt( prompt, device, num_videos_per_prompt, - self.do_classifier_free_guidance, - negative_prompt, prompt_embeds=prompt_embeds, - attention_mask=attention_mask, - negative_prompt_embeds=negative_prompt_embeds, - negative_attention_mask=negative_attention_mask, + attention_mask=prompt_attention_mask, clip_skip=self.clip_skip, data_type=data_type, ) + if self.text_encoder_2 is not None: ( prompt_embeds_2, - negative_prompt_embeds_2, prompt_mask_2, - negative_prompt_mask_2, ) = self.encode_prompt( prompt, device, num_videos_per_prompt, - self.do_classifier_free_guidance, - negative_prompt, - prompt_embeds=None, + prompt_embeds=prompt_embeds_2, attention_mask=None, - negative_prompt_embeds=None, - negative_attention_mask=None, clip_skip=self.clip_skip, text_encoder=self.text_encoder_2, data_type=data_type, ) else: prompt_embeds_2 = None - negative_prompt_embeds_2 = None prompt_mask_2 = None - negative_prompt_mask_2 = None - - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - if self.do_classifier_free_guidance: - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - if prompt_mask is not None: - prompt_mask = torch.cat([negative_prompt_mask, prompt_mask]) - if prompt_embeds_2 is not None: - prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2]) - if prompt_mask_2 is not None: - prompt_mask_2 = torch.cat([negative_prompt_mask_2, prompt_mask_2]) # 4. Prepare timesteps timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, device, - timesteps, - sigmas, + sigmas=sigmas, ) - video_length = (video_length - 1) // 4 + 1 - # 5. Prepare latent variables target_dtype = torch.bfloat16 # Note(aryan): This has been hardcoded for now from the original repo num_channels_latents = self.transformer.config.in_channels @@ -721,7 +777,7 @@ def __call__( num_channels_latents, height, width, - video_length, + (video_length - 1) // 4 + 1, target_dtype, device, generator, @@ -738,58 +794,35 @@ def __call__( num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(timesteps) + image_rotary_emb = self.get_rotary_pos_embed(video_length, height, width) + # if is_progress_bar: with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: continue - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype) + timestep = t.expand(latents.shape[0]).to(latents.dtype) - guidance_expand = ( - torch.tensor( - [embedded_guidance_scale] * latent_model_input.shape[0], - dtype=torch.float32, - device=device, - ).to(target_dtype) - * 1000.0 - if embedded_guidance_scale is not None - else None - ) + guidance_expand = torch.tensor([guidance_scale] * latents.shape[0], dtype=torch.float32, device=device).to(target_dtype) * 1000.0 noise_pred = self.transformer( # For an input image (129, 192, 336) (1, 256, 256) - latent_model_input, # [2, 16, 33, 24, 42] + latents, # [2, 16, 33, 24, 42] timestep, # [2] text_states=prompt_embeds, # [2, 256, 4096] text_mask=prompt_mask, # [2, 256] text_states_2=prompt_embeds_2, # [2, 768] - freqs_cos=freqs_cis[0], # [seqlen, head_dim] - freqs_sin=freqs_cis[1], # [seqlen, head_dim] + freqs_cos=image_rotary_emb[0], # [seqlen, head_dim] + freqs_sin=image_rotary_emb[1], # [seqlen, head_dim] guidance=guidance_expand, return_dict=True, )["x"] - # perform guidance - if self.do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) - - if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: - # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf - noise_pred = rescale_noise_cfg( - noise_pred, - noise_pred_text, - guidance_rescale=self.guidance_rescale, - ) - # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] - torch.save(latents, f"latents_{i}.pt") + torch.save(latents, f"diffusers_refactor_latents_{i}.pt") if callback_on_step_end is not None: callback_kwargs = {} From 0894c3842e6a8787b5fda8b6ffd9f0820ce642a2 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 10 Dec 2024 22:43:42 +0100 Subject: [PATCH 07/58] begin conversion script --- scripts/convert_hunyuan_video_to_diffusers.py | 96 +++++++++++++++++++ .../transformers/transformer_hunyuan_video.py | 50 +++++----- .../hunyuan_video/pipeline_hunyuan_video.py | 11 +-- 3 files changed, 122 insertions(+), 35 deletions(-) create mode 100644 scripts/convert_hunyuan_video_to_diffusers.py diff --git a/scripts/convert_hunyuan_video_to_diffusers.py b/scripts/convert_hunyuan_video_to_diffusers.py new file mode 100644 index 000000000000..c59a2c608095 --- /dev/null +++ b/scripts/convert_hunyuan_video_to_diffusers.py @@ -0,0 +1,96 @@ +import argparse +from typing import Any, Dict + +import torch +from accelerate import init_empty_weights + +from diffusers import HunyuanVideoTransformer3DModel + + +TRANSFORMER_KEYS_RENAME_DICT = { + # "time_in.mlp.0": "time_text_embed.timestep_embedder.linear_1", + # "time_in.mlp.2": "time_text_embed.timestep_embedder.linear_2", + # "guidance_in.mlp.0": "time_text_embed.guidance_embedder.linear_1", + # "guidance_in.mlp.2": "time_text_embed.guidance_embedder.linear_2", + # "vector_in.in_layer": "time_text_embed.text_embedder.linear_1", + # "vector_in.out_layer": "time_text_embed.text_embedder.linear_2", + "double_blocks": "transformer_blocks", + "single_blocks": "single_transformer_blocks", +} + +TRANSFORMER_SPECIAL_KEYS_REMAP = {} + +VAE_KEYS_RENAME_DICT = {} + +VAE_SPECIAL_KEYS_REMAP = {} + + +def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]: + state_dict[new_key] = state_dict.pop(old_key) + + +def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]: + state_dict = saved_dict + if "model" in saved_dict.keys(): + state_dict = state_dict["model"] + if "module" in saved_dict.keys(): + state_dict = state_dict["module"] + if "state_dict" in saved_dict.keys(): + state_dict = state_dict["state_dict"] + return state_dict + + +def convert_transformer(ckpt_path: str): + original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=True)) + + with init_empty_weights(): + transformer = HunyuanVideoTransformer3DModel() + + for key in list(original_state_dict.keys()): + new_key = key[:] + for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items(): + new_key = new_key.replace(replace_key, rename_key) + update_state_dict_(original_state_dict, key, new_key) + + for key in list(original_state_dict.keys()): + for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items(): + if special_key not in key: + continue + handler_fn_inplace(key, original_state_dict) + + transformer.load_state_dict(original_state_dict, strict=True, assign=True) + return transformer + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint" + ) + parser.add_argument("--save_pipeline", action="store_true") + parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved") + parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the model in.") + return parser.parse_args() + +DTYPE_MAPPING = { + "fp32": torch.float32, + "fp16": torch.float16, + "bf16": torch.bfloat16, +} + + +if __name__ == "__main__": + args = get_args() + + transformer = None + dtype = DTYPE_MAPPING[args.dtype] + + if args.save_pipeline: + assert args.transformer_ckpt_path is not None and args.vae_ckpt_path is not None + + if args.transformer_ckpt_path is not None: + transformer = convert_transformer(args.transformer_ckpt_path) + if not args.save_pipeline: + transformer.save_pretrained( + args.output_path, safe_serialization=True, max_shard_size="5GB" + ) diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py index d4c4025d1985..68a342a60862 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py @@ -1019,10 +1019,10 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin): def __init__( self, patch_size: list = [1, 2, 2], - in_channels: int = 4, # Should be VAE.config.latent_channels. - out_channels: int = None, - hidden_size: int = 3072, - heads_num: int = 24, + in_channels: int = 16, # Should be VAE.config.latent_channels. + out_channels: int = 16, + num_attention_heads: int = 24, + attention_head_dim: int = 128, mlp_width_ratio: float = 4.0, mlp_act_type: str = "gelu_tanh", mm_double_blocks_depth: int = 20, @@ -1031,12 +1031,13 @@ def __init__( qkv_bias: bool = True, qk_norm: bool = True, qk_norm_type: str = "rms", - guidance_embed: bool = False, # For modulation. + guidance_embed: bool = True, text_states_dim: int = 4096, text_states_dim_2: int = 768, ): super().__init__() + inner_dim = num_attention_heads * attention_head_dim self.patch_size = patch_size self.in_channels = in_channels self.out_channels = in_channels if out_channels is None else out_channels @@ -1044,39 +1045,34 @@ def __init__( self.guidance_embed = guidance_embed self.rope_dim_list = rope_dim_list - if hidden_size % heads_num != 0: - raise ValueError(f"Hidden size {hidden_size} must be divisible by heads_num {heads_num}") - pe_dim = hidden_size // heads_num - if sum(rope_dim_list) != pe_dim: - raise ValueError(f"Got {rope_dim_list} but expected positional dim {pe_dim}") - self.hidden_size = hidden_size - self.heads_num = heads_num + if sum(rope_dim_list) != attention_head_dim: + raise ValueError(f"Got {rope_dim_list} but expected positional dim {attention_head_dim}") # image projection - self.img_in = PatchEmbed(self.patch_size, self.in_channels, self.hidden_size) + self.img_in = PatchEmbed(self.patch_size, self.in_channels, inner_dim) # text projection - self.txt_in = SingleTokenRefiner(text_states_dim, hidden_size, heads_num, depth=2) + self.txt_in = SingleTokenRefiner(text_states_dim, inner_dim, num_attention_heads, depth=2) # time modulation - self.time_in = TimestepEmbedder(self.hidden_size, get_activation_layer("silu")) + self.time_in = TimestepEmbedder(inner_dim, get_activation_layer("silu")) # text modulation - self.vector_in = MLPEmbedder(text_states_dim_2, self.hidden_size) + self.vector_in = MLPEmbedder(text_states_dim_2, inner_dim) # guidance modulation self.guidance_in = ( - TimestepEmbedder(self.hidden_size, get_activation_layer("silu")) + TimestepEmbedder(inner_dim, get_activation_layer("silu")) if guidance_embed else None ) # double blocks - self.double_blocks = nn.ModuleList( + self.transformer_blocks = nn.ModuleList( [ HunyuanVideoDoubleStreamBlock( - self.hidden_size, - self.heads_num, + inner_dim, + num_attention_heads, mlp_width_ratio=mlp_width_ratio, mlp_act_type=mlp_act_type, qk_norm=qk_norm, @@ -1088,11 +1084,11 @@ def __init__( ) # single blocks - self.single_blocks = nn.ModuleList( + self.single_transformer_blocks = nn.ModuleList( [ HunyuanVideoSingleStreamBlock( - self.hidden_size, - self.heads_num, + inner_dim, + num_attention_heads, mlp_width_ratio=mlp_width_ratio, mlp_act_type=mlp_act_type, qk_norm=qk_norm, @@ -1103,7 +1099,7 @@ def __init__( ) self.final_layer = FinalLayer( - self.hidden_size, + inner_dim, self.patch_size, self.out_channels, get_activation_layer("silu"), @@ -1154,7 +1150,7 @@ def forward( freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None # --------------------- Pass through DiT blocks ------------------------ - for _, block in enumerate(self.double_blocks): + for _, block in enumerate(self.transformer_blocks): double_block_args = [ img, txt, @@ -1166,8 +1162,8 @@ def forward( # Merge txt and img to pass through single stream blocks. x = torch.cat((img, txt), 1) - if len(self.single_blocks) > 0: - for _, block in enumerate(self.single_blocks): + if len(self.single_transformer_blocks) > 0: + for _, block in enumerate(self.single_transformer_blocks): single_block_args = [ x, vec, diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py index dbd2393a55ac..a1b0e10652af 100644 --- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py @@ -554,16 +554,9 @@ def get_nd_rotary_pos_embed( if len(rope_sizes) != target_ndim: rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes # time axis - head_dim = self.transformer.config.hidden_size // self.transformer.config.heads_num - rope_dim_list = self.transformer.config.rope_dim_list - if rope_dim_list is None: - rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)] - assert ( - sum(rope_dim_list) == head_dim - ), "sum(rope_dim_list) should equal to head_dim of attention layer" freqs_cos, freqs_sin = get_nd_rotary_pos_embed( - rope_dim_list, + self.transformer.config.rope_dim_list, rope_sizes, theta=256, use_real=True, @@ -862,6 +855,8 @@ def __call__( else: image = self.vae.decode(latents, return_dict=False, generator=generator)[0] + torch.save(image, "diffusers_latents_decoded.pt") + if expand_temporal_dim or image.shape[2] == 1: image = image.squeeze(2) From a159e58bcc8b15393514b29ed9409b6e21b0fb44 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 10 Dec 2024 23:17:39 +0100 Subject: [PATCH 08/58] make style --- scripts/convert_hunyuan_video_to_diffusers.py | 5 +- .../transformers/transformer_hunyuan_video.py | 69 ++++-------- .../hunyuan_video/pipeline_hunyuan_video.py | 106 +++++++----------- .../pipelines/hunyuan_video/text_encoder.py | 3 +- 4 files changed, 65 insertions(+), 118 deletions(-) diff --git a/scripts/convert_hunyuan_video_to_diffusers.py b/scripts/convert_hunyuan_video_to_diffusers.py index c59a2c608095..4689a815153b 100644 --- a/scripts/convert_hunyuan_video_to_diffusers.py +++ b/scripts/convert_hunyuan_video_to_diffusers.py @@ -72,6 +72,7 @@ def get_args(): parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the model in.") return parser.parse_args() + DTYPE_MAPPING = { "fp32": torch.float32, "fp16": torch.float16, @@ -91,6 +92,4 @@ def get_args(): if args.transformer_ckpt_path is not None: transformer = convert_transformer(args.transformer_ckpt_path) if not args.save_pipeline: - transformer.save_pretrained( - args.output_path, safe_serialization=True, max_shard_size="5GB" - ) + transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py index 68a342a60862..cc1240df0b78 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py @@ -16,7 +16,7 @@ import itertools import math from functools import partial -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -57,7 +57,6 @@ def attention( drop_rate=0, attn_mask=None, causal=False, - batch_size=1, ): pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode] q = pre_attn_layout(q) @@ -445,9 +444,7 @@ def __init__( self.patch_size = patch_size self.flatten = flatten - self.proj = nn.Conv3d( - in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias - ) + self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias) nn.init.xavier_uniform_(self.proj.weight.view(self.proj.weight.size(0), -1)) if bias: nn.init.zeros_(self.proj.bias) @@ -719,7 +716,7 @@ def forward( mask_float = mask.float().unsqueeze(-1) # [b, s1, 1] context_aware_representations = (x * mask_float).sum(dim=1) / mask_float.sum(dim=1) context_aware_representations = context_aware_representations.to(original_dtype) - + context_aware_representations = self.c_embedder(context_aware_representations) c = timestep_aware_representations + context_aware_representations @@ -753,21 +750,13 @@ def __init__( head_dim = hidden_size // heads_num mlp_hidden_dim = int(hidden_size * mlp_width_ratio) - self.img_mod = ModulateDiT( - hidden_size, - factor=6, - act_layer=get_activation_layer("silu") - ) + self.img_mod = ModulateDiT(hidden_size, factor=6, act_layer=get_activation_layer("silu")) self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.img_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias) qk_norm_layer = get_norm_layer(qk_norm_type) - self.img_attn_q_norm = ( - qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() - ) - self.img_attn_k_norm = ( - qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() - ) + self.img_attn_q_norm = qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() + self.img_attn_k_norm = qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() self.img_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias) self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) @@ -786,12 +775,8 @@ def __init__( self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.txt_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias) - self.txt_attn_q_norm = ( - qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() - ) - self.txt_attn_k_norm = ( - qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() - ) + self.txt_attn_q_norm = qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() + self.txt_attn_k_norm = qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() self.txt_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias) self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) @@ -856,23 +841,22 @@ def forward( q = torch.cat((img_q, txt_q), dim=1) k = torch.cat((img_k, txt_k), dim=1) v = torch.cat((img_v, txt_v), dim=1) - attn = attention( - q, - k, - v, - batch_size=img_k.shape[0], - ) + attn = attention(q, k, v) img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1] :] # Calculate the img bloks. img = img + self.img_attn_proj(img_attn) * img_mod1_gate.unsqueeze(1) - img = img + self.img_mlp(modulate(self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale)) * img_mod2_gate.unsqueeze(1) + img = img + self.img_mlp( + modulate(self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale) + ) * img_mod2_gate.unsqueeze(1) # Calculate the txt bloks. txt = txt + self.txt_attn_proj(txt_attn) * txt_mod1_gate.unsqueeze(1) - txt = txt + self.txt_mlp(modulate(self.txt_norm2(txt), shift=txt_mod2_shift, scale=txt_mod2_scale)) * txt_mod2_gate.unsqueeze(1) - + txt = txt + self.txt_mlp( + modulate(self.txt_norm2(txt), shift=txt_mod2_shift, scale=txt_mod2_scale) + ) * txt_mod2_gate.unsqueeze(1) + return img, txt @@ -909,12 +893,8 @@ def __init__( self.linear2 = nn.Linear(hidden_size + mlp_hidden_dim, hidden_size) qk_norm_layer = get_norm_layer(qk_norm_type) - self.q_norm = ( - qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() - ) - self.k_norm = ( - qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() - ) + self.q_norm = qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() + self.k_norm = qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) @@ -954,12 +934,7 @@ def forward( q = torch.cat((img_q, txt_q), dim=1) k = torch.cat((img_k, txt_k), dim=1) - attn = attention( - q, - k, - v, - batch_size=x.shape[0], - ) + attn = attention(q, k, v) # Compute activation in mlp stream, cat again and run second linear layer. output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) @@ -1061,11 +1036,7 @@ def __init__( self.vector_in = MLPEmbedder(text_states_dim_2, inner_dim) # guidance modulation - self.guidance_in = ( - TimestepEmbedder(inner_dim, get_activation_layer("silu")) - if guidance_embed - else None - ) + self.guidance_in = TimestepEmbedder(inner_dim, get_activation_layer("silu")) if guidance_embed else None # double blocks self.transformer_blocks = nn.ModuleList( diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py index a1b0e10652af..69719eb0ac0c 100644 --- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py @@ -14,7 +14,7 @@ import inspect from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, Dict, List, Optional, Tuple, Union import numpy as np import torch @@ -25,7 +25,6 @@ from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( BaseOutput, - deprecate, logging, replace_example_docstring, ) @@ -189,13 +188,6 @@ def encode_prompt( if text_encoder is None: text_encoder = self.text_encoder - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - if prompt_embeds is None: text_inputs = text_encoder.text2tokens(prompt, data_type=data_type) @@ -203,7 +195,7 @@ def encode_prompt( prompt_outputs = text_encoder.encode(text_inputs, data_type=data_type, device=device) # TODO(aryan): Don't know why it doesn't work without this torch.cuda.synchronize() - + prompt_embeds = prompt_outputs.hidden_state else: prompt_outputs = text_encoder.encode( @@ -359,7 +351,7 @@ def get_guidance_scale_embedding( emb = torch.nn.functional.pad(emb, (0, 1)) assert emb.shape == (w.shape[0], embedding_dim) return emb - + def get_rotary_pos_embed(self, video_length, height, width): def _to_tuple(x, dim=2): if isinstance(x, int): @@ -369,15 +361,15 @@ def _to_tuple(x, dim=2): else: raise ValueError(f"Expected length {dim} or int, but got {x}") - def get_meshgrid_nd(start, *args, dim=2): """ Get n-D meshgrid with start, stop and num. Args: - start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop, - step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num - should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in + start (int or tuple): + If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop, step is 1; If + len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num should + be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in n-tuples. *args: See above. dim (int): Dimension of the meshgrid. Defaults to 2. @@ -423,12 +415,12 @@ def get_1d_rotary_pos_embed( interpolation_factor: float = 1.0, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: """ - Precompute the frequency tensor for complex exponential (cis) with given dimensions. - (Note: `cis` means `cos + i * sin`, where i is the imaginary unit.) + Precompute the frequency tensor for complex exponential (cis) with given dimensions. (Note: `cis` means + `cos + i * sin`, where i is the imaginary unit.) - This function calculates a frequency tensor with complex exponential using the given dimension 'dim' - and the end index 'end'. The 'theta' parameter scales the frequencies. - The returned tensor contains complex values in complex64 data type. + This function calculates a frequency tensor with complex exponential using the given dimension 'dim' and + the end index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex + values in complex64 data type. Args: dim (int): Dimension of the frequency tensor. @@ -439,8 +431,8 @@ def get_1d_rotary_pos_embed( theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0. Returns: - freqs_cis: Precomputed frequency tensor with complex exponential. [S, D/2] - freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. [S, D] + freqs_cis: Precomputed frequency tensor with complex exponential. [S, D/2] freqs_cos, freqs_sin: + Precomputed frequency tensor with real and imaginary parts separately. [S, D] """ if isinstance(pos, int): pos = torch.arange(pos).float() @@ -450,9 +442,7 @@ def get_1d_rotary_pos_embed( if theta_rescale_factor != 1.0: theta *= theta_rescale_factor ** (dim / (dim - 2)) - freqs = 1.0 / ( - theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim) - ) # [D/2] + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [D/2] # assert interpolation_factor == 1.0, f"interpolation_factor: {interpolation_factor}" freqs = torch.outer(pos * interpolation_factor, freqs) # [S, D/2] if use_real: @@ -460,12 +450,9 @@ def get_1d_rotary_pos_embed( freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D] return freqs_cos, freqs_sin else: - freqs_cis = torch.polar( - torch.ones_like(freqs), freqs - ) # complex64 # [S, D/2] + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2] return freqs_cis - def get_nd_rotary_pos_embed( rope_dim_list, start, @@ -481,12 +468,14 @@ def get_nd_rotary_pos_embed( Args: rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal to n. sum(rope_dim_list) should equal to head_dim of attention layer. - start (int | tuple of int | list of int): If len(args) == 0, start is num; If len(args) == 1, start is start, - args[0] is stop, step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. + start (int | tuple of int | list of int): + If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop, step is 1; If + len(args) == 2, start is start, args[0] is stop, args[1] is num. *args: See above. theta (float): Scaling factor for frequency computation. Defaults to 10000.0. - use_real (bool): If True, return real part and imaginary part separately. Otherwise, return complex numbers. - Some libraries such as TensorRT does not support complex64 data type. So it is useful to provide a real + use_real (bool): + If True, return real part and imaginary part separately. Otherwise, return complex numbers. Some + libraries such as TensorRT does not support complex64 data type. So it is useful to provide a real part and an imaginary part separately. theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0. @@ -494,9 +483,7 @@ def get_nd_rotary_pos_embed( pos_embed (torch.Tensor): [HW, D/2] """ - grid = get_meshgrid_nd( - start, *args, dim=len(rope_dim_list) - ) # [3, W, H, D] / [2, W, H] + grid = get_meshgrid_nd(start, *args, dim=len(rope_dim_list)) # [3, W, H, D] / [2, W, H] if isinstance(theta_rescale_factor, int) or isinstance(theta_rescale_factor, float): theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list) @@ -534,27 +521,21 @@ def get_nd_rotary_pos_embed( else: emb = torch.cat(embs, dim=1) # (WHD, D/2) return emb - - + target_ndim = 3 ndim = 5 - 2 # 884 latents_size = [(video_length - 1) // 4 + 1, height // 8, width // 8] - assert all( - s % self.transformer.config.patch_size[idx] == 0 - for idx, s in enumerate(latents_size) - ), ( + assert all(s % self.transformer.config.patch_size[idx] == 0 for idx, s in enumerate(latents_size)), ( f"Latent size(last {ndim} dimensions) should be divisible by patch size ({self.transformer.config.patch_size}), " f"but got {latents_size}." ) - rope_sizes = [ - s // self.transformer.config.patch_size[idx] for idx, s in enumerate(latents_size) - ] + rope_sizes = [s // self.transformer.config.patch_size[idx] for idx, s in enumerate(latents_size)] if len(rope_sizes) != target_ndim: rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes # time axis - + freqs_cos, freqs_sin = get_nd_rotary_pos_embed( self.transformer.config.rope_dim_list, rope_sizes, @@ -562,7 +543,7 @@ def get_nd_rotary_pos_embed( use_real=True, theta_rescale_factor=1, ) - + return freqs_cos, freqs_sin @property @@ -573,10 +554,6 @@ def guidance_scale(self): def clip_skip(self): return self._clip_skip - @property - def guidance_scale(self): - return self._guidance_scale - @property def num_timesteps(self): return self._num_timesteps @@ -609,11 +586,7 @@ def __call__( return_dict: bool = True, clip_skip: Optional[int] = None, callback_on_step_end: Optional[ - Union[ - Callable[[int, int, Dict], None], - PipelineCallback, - MultiPipelineCallbacks - ] + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], enable_tiling: bool = False, @@ -647,11 +620,12 @@ def __call__( will be used. guidance_scale (`float`, defaults to `6.0`): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). - Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate - images that are closely linked to the text `prompt`, usually at the expense of lower image quality. - Note that the only available HunyuanVideo model is CFG-distilled, which means that traditional guidance - between unconditional and conditional latent is not applied. + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. Note that the only available HunyuanVideo model is + CFG-distilled, which means that traditional guidance between unconditional and conditional latent is + not applied. num_videos_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): @@ -797,8 +771,13 @@ def __call__( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]).to(latents.dtype) - - guidance_expand = torch.tensor([guidance_scale] * latents.shape[0], dtype=torch.float32, device=device).to(target_dtype) * 1000.0 + + guidance_expand = ( + torch.tensor([guidance_scale] * latents.shape[0], dtype=torch.float32, device=device).to( + target_dtype + ) + * 1000.0 + ) noise_pred = self.transformer( # For an input image (129, 192, 336) (1, 256, 256) latents, # [2, 16, 33, 24, 42] @@ -825,7 +804,6 @@ def __call__( latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): diff --git a/src/diffusers/pipelines/hunyuan_video/text_encoder.py b/src/diffusers/pipelines/hunyuan_video/text_encoder.py index 4a259cafa5f7..9acfcab85767 100644 --- a/src/diffusers/pipelines/hunyuan_video/text_encoder.py +++ b/src/diffusers/pipelines/hunyuan_video/text_encoder.py @@ -1,4 +1,3 @@ -import os from dataclasses import dataclass from typing import Optional, Tuple @@ -272,7 +271,7 @@ def encode( # No idea why it doesn't work without this torch.cuda.synchronize() - + outputs = self.model( input_ids=input_ids, attention_mask=attention_mask, From 5bce9384fe79d6d74854948e1387ca0d66ecd181 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 10 Dec 2024 23:22:37 +0100 Subject: [PATCH 09/58] refactor attention --- .../transformers/transformer_hunyuan_video.py | 44 ++++--------------- 1 file changed, 8 insertions(+), 36 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py index cc1240df0b78..c6c3d408efed 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py @@ -27,48 +27,23 @@ from diffusers.models import ModelMixin -try: - from flash_attn.flash_attn_interface import flash_attn_varlen_func -except ImportError: - flash_attn_varlen_func = None - - -MEMORY_LAYOUT = { - "flash": ( - lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]), - lambda x: x, - ), - "torch": ( - lambda x: x.transpose(1, 2), - lambda x: x.transpose(1, 2), - ), - "vanilla": ( - lambda x: x.transpose(1, 2), - lambda x: x.transpose(1, 2), - ), -} - - def attention( q, k, v, - mode="torch", drop_rate=0, attn_mask=None, causal=False, ): - pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode] - q = pre_attn_layout(q) - k = pre_attn_layout(k) - v = pre_attn_layout(v) + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) - if mode == "torch": - if attn_mask is not None and attn_mask.dtype != torch.bool: - attn_mask = attn_mask.to(q.dtype) - x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal) + if attn_mask is not None and attn_mask.dtype != torch.bool: + attn_mask = attn_mask.to(q.dtype) + x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal) - x = post_attn_layout(x) + x = x.transpose(1, 2) b, s, a, d = x.shape out = x.reshape(b, s, -1) return out @@ -593,7 +568,7 @@ def forward( k = self.self_attn_k_norm(k).to(v) # Self-Attention - attn = attention(q, k, v, mode="torch", attn_mask=attn_mask) + attn = attention(q, k, v, attn_mask=attn_mask) x = x + self.self_attn_proj(attn) * gate_msa.unsqueeze(1) @@ -675,11 +650,8 @@ def __init__( qk_norm: bool = False, qk_norm_type: str = "layer", qkv_bias: bool = True, - attn_mode: str = "torch", ): super().__init__() - self.attn_mode = attn_mode - assert self.attn_mode == "torch", "Only support 'torch' mode for token refiner." self.input_embedder = nn.Linear(in_channels, hidden_size, bias=True) From 491a5b4c1abc533780f20375ca8d4c5ff5798404 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 10 Dec 2024 23:50:14 +0100 Subject: [PATCH 10/58] refactor --- .../transformers/transformer_hunyuan_video.py | 147 +++++++----------- .../hunyuan_video/pipeline_hunyuan_video.py | 33 ++-- 2 files changed, 63 insertions(+), 117 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py index c6c3d408efed..09173773c0aa 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py @@ -23,8 +23,9 @@ import torch.nn.functional as F from einops import rearrange -from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.models import ModelMixin +from ...configuration_utils import ConfigMixin, register_to_config +from ..modeling_utils import ModelMixin +from ..modeling_outputs import Transformer2DModelOutput def attention( @@ -49,22 +50,6 @@ def attention( return out -def _ntuple(n): - def parse(x): - if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): - x = tuple(x) - if len(x) == 1: - x = tuple(itertools.repeat(x[0], n)) - return x - return tuple(itertools.repeat(x, n)) - - return parse - - -to_1tuple = _ntuple(1) -to_2tuple = _ntuple(2) - - def get_activation_layer(act_type): """get activation layer @@ -324,16 +309,14 @@ def __init__( super().__init__() out_features = out_features or in_channels hidden_channels = hidden_channels or in_channels - bias = to_2tuple(bias) - drop_probs = to_2tuple(drop) linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear - self.fc1 = linear_layer(in_channels, hidden_channels, bias=bias[0]) + self.fc1 = linear_layer(in_channels, hidden_channels, bias=bias) self.act = act_layer() - self.drop1 = nn.Dropout(drop_probs[0]) + self.drop1 = nn.Dropout(drop) self.norm = norm_layer(hidden_channels) if norm_layer is not None else nn.Identity() - self.fc2 = linear_layer(hidden_channels, out_features, bias=bias[1]) - self.drop2 = nn.Dropout(drop_probs[1]) + self.fc2 = linear_layer(hidden_channels, out_features, bias=bias) + self.drop2 = nn.Dropout(drop) def forward(self, x): x = self.fc1(x) @@ -415,15 +398,10 @@ def __init__( bias=True, ): super().__init__() - patch_size = to_2tuple(patch_size) - self.patch_size = patch_size + + patch_size = tuple(patch_size) self.flatten = flatten - self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias) - nn.init.xavier_uniform_(self.proj.weight.view(self.proj.weight.size(0), -1)) - if bias: - nn.init.zeros_(self.proj.bias) - self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() def forward(self, x): @@ -965,8 +943,9 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin): @register_to_config def __init__( self, - patch_size: list = [1, 2, 2], - in_channels: int = 16, # Should be VAE.config.latent_channels. + patch_size: int = 2, + patch_size_t: int = 1, + in_channels: int = 16, out_channels: int = 16, num_attention_heads: int = 24, attention_head_dim: int = 128, @@ -981,14 +960,12 @@ def __init__( guidance_embed: bool = True, text_states_dim: int = 4096, text_states_dim_2: int = 768, - ): + ) -> None: super().__init__() inner_dim = num_attention_heads * attention_head_dim - self.patch_size = patch_size self.in_channels = in_channels self.out_channels = in_channels if out_channels is None else out_channels - self.unpatchify_channels = self.out_channels self.guidance_embed = guidance_embed self.rope_dim_list = rope_dim_list @@ -996,7 +973,7 @@ def __init__( raise ValueError(f"Got {rope_dim_list} but expected positional dim {attention_head_dim}") # image projection - self.img_in = PatchEmbed(self.patch_size, self.in_channels, inner_dim) + self.img_in = PatchEmbed((patch_size_t, patch_size, patch_size), self.in_channels, inner_dim) # text projection self.txt_in = SingleTokenRefiner(text_states_dim, inner_dim, num_attention_heads, depth=2) @@ -1010,7 +987,6 @@ def __init__( # guidance modulation self.guidance_in = TimestepEmbedder(inner_dim, get_activation_layer("silu")) if guidance_embed else None - # double blocks self.transformer_blocks = nn.ModuleList( [ HunyuanVideoDoubleStreamBlock( @@ -1026,7 +1002,6 @@ def __init__( ] ) - # single blocks self.single_transformer_blocks = nn.ModuleList( [ HunyuanVideoSingleStreamBlock( @@ -1043,100 +1018,82 @@ def __init__( self.final_layer = FinalLayer( inner_dim, - self.patch_size, + (patch_size_t, patch_size, patch_size), self.out_channels, get_activation_layer("silu"), ) def forward( self, - x: torch.Tensor, - t: torch.Tensor, # Should be in range(0, 1000). - text_states: torch.Tensor = None, - text_mask: torch.Tensor = None, # Now we don't use it. - text_states_2: Optional[torch.Tensor] = None, # Text embedding for modulation. + hidden_states: torch.Tensor, + timestep: torch.Tensor, # Should be in range(0, 1000). + encoder_hidden_states: torch.Tensor = None, + encoder_attention_mask: torch.Tensor = None, # Now we don't use it. + encoder_hidden_states_2: Optional[torch.Tensor] = None, # Text embedding for modulation. freqs_cos: Optional[torch.Tensor] = None, freqs_sin: Optional[torch.Tensor] = None, - guidance: torch.Tensor = None, # Guidance for modulation, should be cfg_scale x 1000. + guidance: torch.Tensor = None, return_dict: bool = True, ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: - out = {} - img = x - txt = text_states - _, _, ot, oh, ow = x.shape - tt, th, tw = ( - ot // self.patch_size[0], - oh // self.patch_size[1], - ow // self.patch_size[2], - ) + batch_size, num_channels, num_frames, height, width = hidden_states.shape + p = self.config.patch_size + p_t = self.config.patch_size_t + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p + post_patch_width = width // p # Prepare modulation vectors. - vec = self.time_in(t) + temb = self.time_in(timestep) # text modulation - vec = vec + self.vector_in(text_states_2) + temb = temb + self.vector_in(encoder_hidden_states_2) # guidance modulation if self.guidance_embed: if guidance is None: raise ValueError("Didn't get guidance strength for guidance distilled model.") - # our timestep_embedding is merged into guidance_in(TimestepEmbedder) - vec = vec + self.guidance_in(guidance) + temb = temb + self.guidance_in(guidance) # Embed image and text. - img = self.img_in(img) - txt = self.txt_in(txt, t, text_mask) + hidden_states = self.img_in(hidden_states) + encoder_hidden_states = self.txt_in(encoder_hidden_states, timestep, encoder_attention_mask) - txt_seq_len = txt.shape[1] - img_seq_len = img.shape[1] + txt_seq_len = encoder_hidden_states.shape[1] + img_seq_len = hidden_states.shape[1] freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None - # --------------------- Pass through DiT blocks ------------------------ for _, block in enumerate(self.transformer_blocks): double_block_args = [ - img, - txt, - vec, + hidden_states, + encoder_hidden_states, + temb, freqs_cis, ] - img, txt = block(*double_block_args) + hidden_states, encoder_hidden_states = block(*double_block_args) - # Merge txt and img to pass through single stream blocks. - x = torch.cat((img, txt), 1) + hidden_states = torch.cat((hidden_states, encoder_hidden_states), 1) if len(self.single_transformer_blocks) > 0: for _, block in enumerate(self.single_transformer_blocks): single_block_args = [ - x, - vec, + hidden_states, + temb, txt_seq_len, (freqs_cos, freqs_sin), ] - x = block(*single_block_args) - - img = x[:, :img_seq_len, ...] - - # ---------------------------- Final layer ------------------------------ - img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) + hidden_states = block(*single_block_args) - img = self.unpatchify(img, tt, th, tw) - if return_dict: - out["x"] = img - return out - return img - - def unpatchify(self, x, t, h, w): - """ - x: (N, T, patch_size**2 * C) imgs: (N, H, W, C) - """ - c = self.unpatchify_channels - pt, ph, pw = self.patch_size - assert t * h * w == x.shape[1] + hidden_states = hidden_states[:, :img_seq_len, ...] - x = x.reshape(shape=(x.shape[0], t, h, w, c, pt, ph, pw)) - x = torch.einsum("nthwcopq->nctohpwq", x) - imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw)) + hidden_states = self.final_layer(hidden_states, temb) - return imgs + hidden_states = hidden_states.reshape(batch_size, post_patch_num_frames, post_patch_height, post_patch_width, -1, p_t, p, p) + hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7) + hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) + + if not return_dict: + return (hidden_states,) + + return Transformer2DModelOutput(sample=hidden_states) diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py index 69719eb0ac0c..475b228660b7 100644 --- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py @@ -522,19 +522,8 @@ def get_nd_rotary_pos_embed( emb = torch.cat(embs, dim=1) # (WHD, D/2) return emb - target_ndim = 3 - ndim = 5 - 2 - # 884 latents_size = [(video_length - 1) // 4 + 1, height // 8, width // 8] - - assert all(s % self.transformer.config.patch_size[idx] == 0 for idx, s in enumerate(latents_size)), ( - f"Latent size(last {ndim} dimensions) should be divisible by patch size ({self.transformer.config.patch_size}), " - f"but got {latents_size}." - ) - rope_sizes = [s // self.transformer.config.patch_size[idx] for idx, s in enumerate(latents_size)] - - if len(rope_sizes) != target_ndim: - rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes # time axis + rope_sizes = [latents_size[0] // self.transformer.config.patch_size_t, latents_size[1] // self.transformer.config.patch_size, latents_size[2] // self.transformer.config.patch_size] freqs_cos, freqs_sin = get_nd_rotary_pos_embed( self.transformer.config.rope_dim_list, @@ -779,17 +768,17 @@ def __call__( * 1000.0 ) - noise_pred = self.transformer( # For an input image (129, 192, 336) (1, 256, 256) - latents, # [2, 16, 33, 24, 42] - timestep, # [2] - text_states=prompt_embeds, # [2, 256, 4096] - text_mask=prompt_mask, # [2, 256] - text_states_2=prompt_embeds_2, # [2, 768] - freqs_cos=image_rotary_emb[0], # [seqlen, head_dim] - freqs_sin=image_rotary_emb[1], # [seqlen, head_dim] + noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=prompt_mask, + encoder_hidden_states_2=prompt_embeds_2, + freqs_cos=image_rotary_emb[0], + freqs_sin=image_rotary_emb[1], guidance=guidance_expand, - return_dict=True, - )["x"] + return_dict=False, + )[0] # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] From a47a710807e9705279ccdecbf5cad1aea7d25017 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 11 Dec 2024 11:08:17 +0100 Subject: [PATCH 11/58] refactor final layer --- scripts/convert_hunyuan_video_to_diffusers.py | 13 +++- .../transformers/transformer_hunyuan_video.py | 61 ++++--------------- 2 files changed, 23 insertions(+), 51 deletions(-) diff --git a/scripts/convert_hunyuan_video_to_diffusers.py b/scripts/convert_hunyuan_video_to_diffusers.py index 4689a815153b..ff535a456877 100644 --- a/scripts/convert_hunyuan_video_to_diffusers.py +++ b/scripts/convert_hunyuan_video_to_diffusers.py @@ -7,6 +7,13 @@ from diffusers import HunyuanVideoTransformer3DModel +def remap_norm_scale_shift_(key, state_dict): + weight = state_dict.pop(key) + shift, scale = weight.chunk(2, dim=0) + new_weight = torch.cat([scale, shift], dim=0) + state_dict[key.replace("final_layer.adaLN_modulation.1", "norm_out.linear")] = new_weight + + TRANSFORMER_KEYS_RENAME_DICT = { # "time_in.mlp.0": "time_text_embed.timestep_embedder.linear_1", # "time_in.mlp.2": "time_text_embed.timestep_embedder.linear_2", @@ -16,9 +23,13 @@ # "vector_in.out_layer": "time_text_embed.text_embedder.linear_2", "double_blocks": "transformer_blocks", "single_blocks": "single_transformer_blocks", + "final_layer.norm_final": "norm_out.norm", + "final_layer.linear": "proj_out" } -TRANSFORMER_SPECIAL_KEYS_REMAP = {} +TRANSFORMER_SPECIAL_KEYS_REMAP = { + "final_layer.adaLN_modulation.1": remap_norm_scale_shift_, +} VAE_KEYS_RENAME_DICT = {} diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py index 09173773c0aa..e16d910de4d6 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import collections.abc -import itertools import math from functools import partial from typing import Callable, Dict, List, Optional, Tuple, Union @@ -26,15 +24,14 @@ from ...configuration_utils import ConfigMixin, register_to_config from ..modeling_utils import ModelMixin from ..modeling_outputs import Transformer2DModelOutput +from ..normalization import AdaLayerNormContinuous def attention( q, k, v, - drop_rate=0, - attn_mask=None, - causal=False, + attn_mask=None ): q = q.transpose(1, 2) k = k.transpose(1, 2) @@ -42,7 +39,7 @@ def attention( if attn_mask is not None and attn_mask.dtype != torch.bool: attn_mask = attn_mask.to(q.dtype) - x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal) + x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0, is_causal=False) x = x.transpose(1, 2) b, s, a, d = x.shape @@ -342,38 +339,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.out_layer(self.silu(self.in_layer(x))) -class FinalLayer(nn.Module): - """The final layer of DiT.""" - - def __init__(self, hidden_size, patch_size, out_channels, act_layer): - super().__init__() - - # Just use LayerNorm for the final layer - self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) - if isinstance(patch_size, int): - self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) - else: - self.linear = nn.Linear( - hidden_size, - patch_size[0] * patch_size[1] * patch_size[2] * out_channels, - bias=True, - ) - nn.init.zeros_(self.linear.weight) - nn.init.zeros_(self.linear.bias) - - # Here we don't distinguish between the modulate types. Just use the simple one. - self.adaLN_modulation = nn.Sequential( - act_layer(), - nn.Linear(hidden_size, 2 * hidden_size, bias=True), - ) - - def forward(self, x, c): - shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) - x = modulate(self.norm_final(x), shift=shift, scale=scale) - x = self.linear(x) - return x - - class PatchEmbed(nn.Module): """2D Image to Patch Embedding @@ -620,7 +585,7 @@ def __init__( self, in_channels, hidden_size, - heads_num, + num_attention_heads, depth, mlp_width_ratio: float = 4.0, mlp_drop_rate: float = 0.0, @@ -641,7 +606,7 @@ def __init__( self.individual_token_refiner = IndividualTokenRefiner( hidden_size=hidden_size, - heads_num=heads_num, + heads_num=num_attention_heads, depth=depth, mlp_width_ratio=mlp_width_ratio, mlp_drop_rate=mlp_drop_rate, @@ -964,8 +929,7 @@ def __init__( super().__init__() inner_dim = num_attention_heads * attention_head_dim - self.in_channels = in_channels - self.out_channels = in_channels if out_channels is None else out_channels + out_channels = out_channels or in_channels self.guidance_embed = guidance_embed self.rope_dim_list = rope_dim_list @@ -973,7 +937,7 @@ def __init__( raise ValueError(f"Got {rope_dim_list} but expected positional dim {attention_head_dim}") # image projection - self.img_in = PatchEmbed((patch_size_t, patch_size, patch_size), self.in_channels, inner_dim) + self.img_in = PatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim) # text projection self.txt_in = SingleTokenRefiner(text_states_dim, inner_dim, num_attention_heads, depth=2) @@ -1016,12 +980,8 @@ def __init__( ] ) - self.final_layer = FinalLayer( - inner_dim, - (patch_size_t, patch_size, patch_size), - self.out_channels, - get_activation_layer("silu"), - ) + self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels) def forward( self, @@ -1087,7 +1047,8 @@ def forward( hidden_states = hidden_states[:, :img_seq_len, ...] - hidden_states = self.final_layer(hidden_states, temb) + hidden_states = self.norm_out(hidden_states, temb) + hidden_states = self.proj_out(hidden_states) hidden_states = hidden_states.reshape(batch_size, post_patch_num_frames, post_patch_height, post_patch_width, -1, p_t, p, p) hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7) From ee6880d2907ce44c1fb81337daf1d4afc5575443 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 11 Dec 2024 12:12:44 +0100 Subject: [PATCH 12/58] their mlp -> our feedforward --- scripts/convert_hunyuan_video_to_diffusers.py | 4 +- src/diffusers/models/activations.py | 12 +++ src/diffusers/models/attention.py | 4 +- .../transformers/transformer_hunyuan_video.py | 102 +++++------------- 4 files changed, 44 insertions(+), 78 deletions(-) diff --git a/scripts/convert_hunyuan_video_to_diffusers.py b/scripts/convert_hunyuan_video_to_diffusers.py index ff535a456877..ebcf11389bc2 100644 --- a/scripts/convert_hunyuan_video_to_diffusers.py +++ b/scripts/convert_hunyuan_video_to_diffusers.py @@ -24,7 +24,9 @@ def remap_norm_scale_shift_(key, state_dict): "double_blocks": "transformer_blocks", "single_blocks": "single_transformer_blocks", "final_layer.norm_final": "norm_out.norm", - "final_layer.linear": "proj_out" + "final_layer.linear": "proj_out", + "fc1": "net.0.proj", + "fc2": "net.2" } TRANSFORMER_SPECIAL_KEYS_REMAP = { diff --git a/src/diffusers/models/activations.py b/src/diffusers/models/activations.py index f4318fc3cd39..2a1be0b0c386 100644 --- a/src/diffusers/models/activations.py +++ b/src/diffusers/models/activations.py @@ -146,6 +146,18 @@ def forward(self, hidden_states): return hidden_states * self.activation(gate) +class SiLU(nn.Module): + def __init__(self, dim_in: int, dim_out: int, bias: bool = True): + super().__init__() + + self.proj = nn.Linear(dim_in, dim_out, bias=bias) + self.activation = nn.SiLU() + + def forward(self, hidden_states): + hidden_states = self.proj(hidden_states) + return self.activation(hidden_states) + + class ApproximateGELU(nn.Module): r""" The approximate form of the Gaussian Error Linear Unit (GELU). For more details, see section 2 of this diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 02ed1f965abf..bd79a5a5d4ff 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -19,7 +19,7 @@ from ..utils import deprecate, logging from ..utils.torch_utils import maybe_allow_in_graph -from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, SwiGLU +from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, SwiGLU, SiLU from .attention_processor import Attention, JointAttnProcessor2_0 from .embeddings import SinusoidalPositionalEmbedding from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX @@ -1222,6 +1222,8 @@ def __init__( act_fn = ApproximateGELU(dim, inner_dim, bias=bias) elif activation_fn == "swiglu": act_fn = SwiGLU(dim, inner_dim, bias=bias) + elif activation_fn == "silu": + act_fn = SiLU(dim, inner_dim, bias=bias) self.net = nn.ModuleList([]) # project in diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py index e16d910de4d6..dcbb24146e72 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py @@ -22,6 +22,7 @@ from einops import rearrange from ...configuration_utils import ConfigMixin, register_to_config +from ..attention import FeedForward from ..modeling_utils import ModelMixin from ..modeling_outputs import Transformer2DModelOutput from ..normalization import AdaLayerNormContinuous @@ -289,43 +290,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.linear(self.act(x)) -class MLP(nn.Module): - """MLP as used in Vision Transformer, MLP-Mixer and related networks""" - - def __init__( - self, - in_channels, - hidden_channels=None, - out_features=None, - act_layer=nn.GELU, - norm_layer=None, - bias=True, - drop=0.0, - use_conv=False, - ): - super().__init__() - out_features = out_features or in_channels - hidden_channels = hidden_channels or in_channels - linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear - - self.fc1 = linear_layer(in_channels, hidden_channels, bias=bias) - self.act = act_layer() - self.drop1 = nn.Dropout(drop) - self.norm = norm_layer(hidden_channels) if norm_layer is not None else nn.Identity() - self.fc2 = linear_layer(hidden_channels, out_features, bias=bias) - self.drop2 = nn.Dropout(drop) - - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - x = self.drop1(x) - x = self.norm(x) - x = self.fc2(x) - x = self.drop2(x) - return x - - -# class MLPEmbedder(nn.Module): """copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py""" @@ -483,12 +447,8 @@ def __init__( self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6) act_layer = get_activation_layer(act_type) - self.mlp = MLP( - in_channels=hidden_size, - hidden_channels=mlp_hidden_dim, - act_layer=act_layer, - drop=mlp_drop_rate, - ) + + self.mlp = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="silu", dropout=mlp_drop_rate) self.adaLN_modulation = nn.Sequential( act_layer(), @@ -498,7 +458,7 @@ def __init__( def forward( self, x: torch.Tensor, - c: torch.Tensor, # timestep_aware_representations + context_aware_representations + c: torch.Tensor, attn_mask: torch.Tensor = None, ): gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1) @@ -675,12 +635,7 @@ def __init__( self.img_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias) self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) - self.img_mlp = MLP( - hidden_size, - mlp_hidden_dim, - act_layer=get_activation_layer(mlp_act_type), - bias=True, - ) + self.img_mlp = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="gelu-approximate") self.txt_mod = ModulateDiT( hidden_size, @@ -695,19 +650,14 @@ def __init__( self.txt_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias) self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) - self.txt_mlp = MLP( - hidden_size, - mlp_hidden_dim, - act_layer=get_activation_layer(mlp_act_type), - bias=True, - ) + self.txt_mlp = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="gelu-approximate") def forward( self, - img: torch.Tensor, - txt: torch.Tensor, - vec: torch.Tensor, - freqs_cis: tuple = None, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: ( img_mod1_shift, @@ -716,7 +666,7 @@ def forward( img_mod2_shift, img_mod2_scale, img_mod2_gate, - ) = self.img_mod(vec).chunk(6, dim=-1) + ) = self.img_mod(temb).chunk(6, dim=-1) ( txt_mod1_shift, txt_mod1_scale, @@ -724,10 +674,10 @@ def forward( txt_mod2_shift, txt_mod2_scale, txt_mod2_gate, - ) = self.txt_mod(vec).chunk(6, dim=-1) + ) = self.txt_mod(temb).chunk(6, dim=-1) # Prepare image for attention. - img_modulated = self.img_norm1(img) + img_modulated = self.img_norm1(hidden_states) img_modulated = modulate(img_modulated, shift=img_mod1_shift, scale=img_mod1_scale) img_qkv = self.img_attn_qkv(img_modulated) img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num) @@ -744,7 +694,7 @@ def forward( img_q, img_k = img_qq, img_kk # Prepare txt for attention. - txt_modulated = self.txt_norm1(txt) + txt_modulated = self.txt_norm1(encoder_hidden_states) txt_modulated = modulate(txt_modulated, shift=txt_mod1_shift, scale=txt_mod1_scale) txt_qkv = self.txt_attn_qkv(txt_modulated) txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num) @@ -758,21 +708,21 @@ def forward( v = torch.cat((img_v, txt_v), dim=1) attn = attention(q, k, v) - img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1] :] + img_attn, txt_attn = attn[:, : hidden_states.shape[1]], attn[:, hidden_states.shape[1] :] # Calculate the img bloks. - img = img + self.img_attn_proj(img_attn) * img_mod1_gate.unsqueeze(1) - img = img + self.img_mlp( - modulate(self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale) + hidden_states = hidden_states + self.img_attn_proj(img_attn) * img_mod1_gate.unsqueeze(1) + hidden_states = hidden_states + self.img_mlp( + modulate(self.img_norm2(hidden_states), shift=img_mod2_shift, scale=img_mod2_scale) ) * img_mod2_gate.unsqueeze(1) # Calculate the txt bloks. - txt = txt + self.txt_attn_proj(txt_attn) * txt_mod1_gate.unsqueeze(1) - txt = txt + self.txt_mlp( - modulate(self.txt_norm2(txt), shift=txt_mod2_shift, scale=txt_mod2_scale) + encoder_hidden_states = encoder_hidden_states + self.txt_attn_proj(txt_attn) * txt_mod1_gate.unsqueeze(1) + encoder_hidden_states = encoder_hidden_states + self.txt_mlp( + modulate(self.txt_norm2(encoder_hidden_states), shift=txt_mod2_shift, scale=txt_mod2_scale) ) * txt_mod2_gate.unsqueeze(1) - return img, txt + return hidden_states, encoder_hidden_states class HunyuanVideoSingleStreamBlock(nn.Module): @@ -986,10 +936,10 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - timestep: torch.Tensor, # Should be in range(0, 1000). - encoder_hidden_states: torch.Tensor = None, - encoder_attention_mask: torch.Tensor = None, # Now we don't use it. - encoder_hidden_states_2: Optional[torch.Tensor] = None, # Text embedding for modulation. + timestep: torch.LongTensor, + encoder_hidden_states: torch.Tensor, + encoder_attention_mask: torch.Tensor, + encoder_hidden_states_2: torch.Tensor, freqs_cos: Optional[torch.Tensor] = None, freqs_sin: Optional[torch.Tensor] = None, guidance: torch.Tensor = None, From a23cfa1fe7a6f532902918a36c7406880b442917 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 11 Dec 2024 12:14:17 +0100 Subject: [PATCH 13/58] make style --- scripts/convert_hunyuan_video_to_diffusers.py | 2 +- src/diffusers/models/activations.py | 4 +-- src/diffusers/models/attention.py | 2 +- .../transformers/transformer_hunyuan_video.py | 26 +++++++------------ .../hunyuan_video/pipeline_hunyuan_video.py | 6 ++++- 5 files changed, 18 insertions(+), 22 deletions(-) diff --git a/scripts/convert_hunyuan_video_to_diffusers.py b/scripts/convert_hunyuan_video_to_diffusers.py index ebcf11389bc2..16fefa39ab25 100644 --- a/scripts/convert_hunyuan_video_to_diffusers.py +++ b/scripts/convert_hunyuan_video_to_diffusers.py @@ -26,7 +26,7 @@ def remap_norm_scale_shift_(key, state_dict): "final_layer.norm_final": "norm_out.norm", "final_layer.linear": "proj_out", "fc1": "net.0.proj", - "fc2": "net.2" + "fc2": "net.2", } TRANSFORMER_SPECIAL_KEYS_REMAP = { diff --git a/src/diffusers/models/activations.py b/src/diffusers/models/activations.py index 2a1be0b0c386..f03b9060b595 100644 --- a/src/diffusers/models/activations.py +++ b/src/diffusers/models/activations.py @@ -149,10 +149,10 @@ def forward(self, hidden_states): class SiLU(nn.Module): def __init__(self, dim_in: int, dim_out: int, bias: bool = True): super().__init__() - + self.proj = nn.Linear(dim_in, dim_out, bias=bias) self.activation = nn.SiLU() - + def forward(self, hidden_states): hidden_states = self.proj(hidden_states) return self.activation(hidden_states) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index bd79a5a5d4ff..2666ffe94528 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -19,7 +19,7 @@ from ..utils import deprecate, logging from ..utils.torch_utils import maybe_allow_in_graph -from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, SwiGLU, SiLU +from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, SiLU, SwiGLU from .attention_processor import Attention, JointAttnProcessor2_0 from .embeddings import SinusoidalPositionalEmbedding from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py index dcbb24146e72..c2f02718e165 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py @@ -13,7 +13,6 @@ # limitations under the License. import math -from functools import partial from typing import Callable, Dict, List, Optional, Tuple, Union import torch @@ -23,17 +22,12 @@ from ...configuration_utils import ConfigMixin, register_to_config from ..attention import FeedForward -from ..modeling_utils import ModelMixin from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNormContinuous -def attention( - q, - k, - v, - attn_mask=None -): +def attention(q, k, v, attn_mask=None): q = q.transpose(1, 2) k = k.transpose(1, 2) v = v.transpose(1, 2) @@ -327,7 +321,7 @@ def __init__( bias=True, ): super().__init__() - + patch_size = tuple(patch_size) self.flatten = flatten self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias) @@ -432,7 +426,6 @@ def __init__( super().__init__() self.heads_num = heads_num head_dim = hidden_size // heads_num - mlp_hidden_dim = int(hidden_size * mlp_width_ratio) self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6) self.self_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias) @@ -447,7 +440,7 @@ def __init__( self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6) act_layer = get_activation_layer(act_type) - + self.mlp = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="silu", dropout=mlp_drop_rate) self.adaLN_modulation = nn.Sequential( @@ -613,7 +606,6 @@ def __init__( hidden_size: int, heads_num: int, mlp_width_ratio: float, - mlp_act_type: str = "gelu_tanh", qk_norm: bool = True, qk_norm_type: str = "rms", qkv_bias: bool = False, @@ -623,7 +615,6 @@ def __init__( self.deterministic = False self.heads_num = heads_num head_dim = hidden_size // heads_num - mlp_hidden_dim = int(hidden_size * mlp_width_ratio) self.img_mod = ModulateDiT(hidden_size, factor=6, act_layer=get_activation_layer("silu")) self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) @@ -907,7 +898,6 @@ def __init__( inner_dim, num_attention_heads, mlp_width_ratio=mlp_width_ratio, - mlp_act_type=mlp_act_type, qk_norm=qk_norm, qk_norm_type=qk_norm_type, qkv_bias=qkv_bias, @@ -1000,11 +990,13 @@ def forward( hidden_states = self.norm_out(hidden_states, temb) hidden_states = self.proj_out(hidden_states) - hidden_states = hidden_states.reshape(batch_size, post_patch_num_frames, post_patch_height, post_patch_width, -1, p_t, p, p) + hidden_states = hidden_states.reshape( + batch_size, post_patch_num_frames, post_patch_height, post_patch_width, -1, p_t, p, p + ) hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7) hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) - + if not return_dict: return (hidden_states,) - + return Transformer2DModelOutput(sample=hidden_states) diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py index 475b228660b7..94732c3d8947 100644 --- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py @@ -523,7 +523,11 @@ def get_nd_rotary_pos_embed( return emb latents_size = [(video_length - 1) // 4 + 1, height // 8, width // 8] - rope_sizes = [latents_size[0] // self.transformer.config.patch_size_t, latents_size[1] // self.transformer.config.patch_size, latents_size[2] // self.transformer.config.patch_size] + rope_sizes = [ + latents_size[0] // self.transformer.config.patch_size_t, + latents_size[1] // self.transformer.config.patch_size, + latents_size[2] // self.transformer.config.patch_size, + ] freqs_cos, freqs_sin = get_nd_rotary_pos_embed( self.transformer.config.rope_dim_list, From ab319fedf1369f0fb13978aa22492bef8a637694 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 11 Dec 2024 14:39:54 +0100 Subject: [PATCH 14/58] add docs --- docs/source/en/_toctree.yml | 4 +++ .../models/autoencoder_kl_hunyuan_video.md | 30 +++++++++++++++++++ .../models/hunyuan_video_transformer_3d.md | 28 +++++++++++++++++ 3 files changed, 62 insertions(+) create mode 100644 docs/source/en/api/models/autoencoder_kl_hunyuan_video.md create mode 100644 docs/source/en/api/models/hunyuan_video_transformer_3d.md diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 47eb922f525e..cd47e2a77dfb 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -268,6 +268,8 @@ title: FluxTransformer2DModel - local: api/models/hunyuan_transformer2d title: HunyuanDiT2DModel + - local: api/models/hunyuan_video_transformer_3d + title: HunyuanVideoTransformer3DModel - local: api/models/latte_transformer3d title: LatteTransformer3DModel - local: api/models/lumina_nextdit2d @@ -310,6 +312,8 @@ title: AutoencoderKLAllegro - local: api/models/autoencoderkl_cogvideox title: AutoencoderKLCogVideoX + - local: api/models/autoencoder_kl_hunyuan_video + title: AutoencoderKLHunyuanVideo - local: api/models/autoencoderkl_mochi title: AutoencoderKLMochi - local: api/models/asymmetricautoencoderkl diff --git a/docs/source/en/api/models/autoencoder_kl_hunyuan_video.md b/docs/source/en/api/models/autoencoder_kl_hunyuan_video.md new file mode 100644 index 000000000000..89679422d664 --- /dev/null +++ b/docs/source/en/api/models/autoencoder_kl_hunyuan_video.md @@ -0,0 +1,30 @@ + + +# AutoencoderKLHunyuanVideo + +The 3D variational autoencoder (VAE) model with KL loss used in [HunyuanVideo](https://github.com/Tencent/HunyuanVideo/), which was introduced in [HunyuanVideo: A Systematic Framework For Large Video Generative Models](https://huggingface.co/papers/2412.03603) by Tencent. + +The model can be loaded with the following code snippet. + +```python +TODO +``` + +## AutoencoderKLMochi + +[[autodoc]] AutoencoderKLHunyuanVideo + - decode + - all + +## DecoderOutput + +[[autodoc]] models.autoencoders.vae.DecoderOutput diff --git a/docs/source/en/api/models/hunyuan_video_transformer_3d.md b/docs/source/en/api/models/hunyuan_video_transformer_3d.md new file mode 100644 index 000000000000..18fcb80e55ca --- /dev/null +++ b/docs/source/en/api/models/hunyuan_video_transformer_3d.md @@ -0,0 +1,28 @@ + + +# HunyuanVideoTransformer3DModel + +A Diffusion Transformer model for 3D video-like data was introduced in [HunyuanVideo: A Systematic Framework For Large Video Generative Models](https://huggingface.co/papers/2412.03603) by Tencent. + +The model can be loaded with the following code snippet. + +```python +TODO +``` + +## HunyuanVideoTransformer3DModel + +[[autodoc]] MochiTransformer3DModel + +## Transformer2DModelOutput + +[[autodoc]] models.modeling_outputs.Transformer2DModelOutput From e3abe385c5cee47b71ed5d59813bb5619336ce0c Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 11 Dec 2024 14:50:42 +0100 Subject: [PATCH 15/58] refactor layer names --- scripts/convert_hunyuan_video_to_diffusers.py | 6 + .../transformers/transformer_hunyuan_video.py | 239 +++++++++--------- 2 files changed, 121 insertions(+), 124 deletions(-) diff --git a/scripts/convert_hunyuan_video_to_diffusers.py b/scripts/convert_hunyuan_video_to_diffusers.py index 16fefa39ab25..54925c3f7d33 100644 --- a/scripts/convert_hunyuan_video_to_diffusers.py +++ b/scripts/convert_hunyuan_video_to_diffusers.py @@ -23,6 +23,12 @@ def remap_norm_scale_shift_(key, state_dict): # "vector_in.out_layer": "time_text_embed.text_embedder.linear_2", "double_blocks": "transformer_blocks", "single_blocks": "single_transformer_blocks", + "img_norm1": "norm1", + "img_norm2": "norm2", + "img_mlp": "ff", + "txt_norm1": "norm1_context", + "txt_norm2": "norm2_context", + "txt_mlp": "ff_context", "final_layer.norm_final": "norm_out.norm", "final_layer.linear": "proj_out", "fc1": "net.0.proj", diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py index c2f02718e165..d0c4d8640c11 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py @@ -24,7 +24,7 @@ from ..attention import FeedForward from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin -from ..normalization import AdaLayerNormContinuous +from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero def attention(q, k, v, attn_mask=None): @@ -594,13 +594,86 @@ def forward( return x -class HunyuanVideoDoubleStreamBlock(nn.Module): +class HunyuanVideoSingleTransformerBlock(nn.Module): """ - A multimodal dit block with seperate modulation for text and image/video, see more details (SD3): - https://arxiv.org/abs/2403.03206 - (Flux.1): https://github.com/black-forest-labs/flux + A DiT block with parallel linear layers as described in https://arxiv.org/abs/2302.05442 and adapted modulation + interface. Also refer to (SD3): https://arxiv.org/abs/2403.03206 + (Flux.1): https://github.com/black-forest-labs/flux """ + def __init__( + self, + hidden_size: int, + heads_num: int, + mlp_width_ratio: float = 4.0, + mlp_act_type: str = "gelu_tanh", + qk_norm: bool = True, + qk_norm_type: str = "rms", + ): + super().__init__() + + self.hidden_size = hidden_size + self.heads_num = heads_num + head_dim = hidden_size // heads_num + mlp_hidden_dim = int(hidden_size * mlp_width_ratio) + self.mlp_hidden_dim = mlp_hidden_dim + + # qkv and mlp_in + self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + mlp_hidden_dim) + # proj and mlp_out + self.linear2 = nn.Linear(hidden_size + mlp_hidden_dim, hidden_size) + + qk_norm_layer = get_norm_layer(qk_norm_type) + self.q_norm = qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() + self.k_norm = qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() + + self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + + self.mlp_act = get_activation_layer(mlp_act_type)() + self.modulation = ModulateDiT( + hidden_size, + factor=3, + act_layer=get_activation_layer("silu"), + ) + + def forward( + self, + x: torch.Tensor, + vec: torch.Tensor, + txt_len: int, + freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None, + ) -> torch.Tensor: + shift_msa, scale_msa, gate_msa = self.modulation(vec).chunk(3, dim=-1) + x_mod = modulate(self.pre_norm(x), shift=shift_msa, scale=scale_msa) + qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) + + q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num) + + # Apply QK-Norm if needed. + q = self.q_norm(q).to(v) + k = self.k_norm(k).to(v) + + # Apply RoPE if needed. + if freqs_cis is not None: + img_q, txt_q = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :] + img_k, txt_k = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :] + img_qq, img_kk = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False) + assert ( + img_qq.shape == img_q.shape and img_kk.shape == img_k.shape + ), f"img_kk: {img_qq.shape}, img_q: {img_q.shape}, img_kk: {img_kk.shape}, img_k: {img_k.shape}" + img_q, img_k = img_qq, img_kk + q = torch.cat((img_q, txt_q), dim=1) + k = torch.cat((img_k, txt_k), dim=1) + + attn = attention(q, k, v) + + # Compute activation in mlp stream, cat again and run second linear layer. + output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) + output = x + output * gate_msa.unsqueeze(1) + return output + + +class HunyuanVideoTransformerBlock(nn.Module): def __init__( self, hidden_size: int, @@ -617,31 +690,31 @@ def __init__( head_dim = hidden_size // heads_num self.img_mod = ModulateDiT(hidden_size, factor=6, act_layer=get_activation_layer("silu")) - self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.img_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias) qk_norm_layer = get_norm_layer(qk_norm_type) - self.img_attn_q_norm = qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() - self.img_attn_k_norm = qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() + self.img_attn_q_norm = qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6) + self.img_attn_k_norm = qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6) self.img_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias) - self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) - self.img_mlp = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="gelu-approximate") - self.txt_mod = ModulateDiT( hidden_size, factor=6, act_layer=get_activation_layer("silu"), ) - self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.norm1_context = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.txt_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias) - self.txt_attn_q_norm = qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() - self.txt_attn_k_norm = qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() + self.txt_attn_q_norm = qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6) + self.txt_attn_k_norm = qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6) self.txt_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias) - self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) - self.txt_mlp = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="gelu-approximate") + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.ff = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="gelu-approximate") + + self.norm2_context = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.ff_context = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="gelu-approximate") def forward( self, @@ -651,25 +724,25 @@ def forward( freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: ( - img_mod1_shift, - img_mod1_scale, - img_mod1_gate, - img_mod2_shift, - img_mod2_scale, - img_mod2_gate, + shift_msa, + scale_msa, + gate_msa, + shift_mlp, + scale_mlp, + gate_mlp, ) = self.img_mod(temb).chunk(6, dim=-1) ( - txt_mod1_shift, - txt_mod1_scale, - txt_mod1_gate, - txt_mod2_shift, - txt_mod2_scale, - txt_mod2_gate, + c_shift_msa, + c_scale_msa, + c_gate_msa, + c_shift_mlp, + c_scale_mlp, + c_gate_mlp, ) = self.txt_mod(temb).chunk(6, dim=-1) # Prepare image for attention. - img_modulated = self.img_norm1(hidden_states) - img_modulated = modulate(img_modulated, shift=img_mod1_shift, scale=img_mod1_scale) + img_modulated = self.norm1(hidden_states) + img_modulated = modulate(img_modulated, shift=shift_msa, scale=scale_msa) img_qkv = self.img_attn_qkv(img_modulated) img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num) # Apply QK-Norm if needed @@ -685,8 +758,8 @@ def forward( img_q, img_k = img_qq, img_kk # Prepare txt for attention. - txt_modulated = self.txt_norm1(encoder_hidden_states) - txt_modulated = modulate(txt_modulated, shift=txt_mod1_shift, scale=txt_mod1_scale) + txt_modulated = self.norm1_context(encoder_hidden_states) + txt_modulated = modulate(txt_modulated, shift=c_shift_msa, scale=c_scale_msa) txt_qkv = self.txt_attn_qkv(txt_modulated) txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num) # Apply QK-Norm if needed. @@ -702,102 +775,20 @@ def forward( img_attn, txt_attn = attn[:, : hidden_states.shape[1]], attn[:, hidden_states.shape[1] :] # Calculate the img bloks. - hidden_states = hidden_states + self.img_attn_proj(img_attn) * img_mod1_gate.unsqueeze(1) - hidden_states = hidden_states + self.img_mlp( - modulate(self.img_norm2(hidden_states), shift=img_mod2_shift, scale=img_mod2_scale) - ) * img_mod2_gate.unsqueeze(1) + hidden_states = hidden_states + self.img_attn_proj(img_attn) * gate_msa.unsqueeze(1) + hidden_states = hidden_states + self.ff( + modulate(self.norm2(hidden_states), shift=shift_mlp, scale=scale_mlp) + ) * gate_mlp.unsqueeze(1) # Calculate the txt bloks. - encoder_hidden_states = encoder_hidden_states + self.txt_attn_proj(txt_attn) * txt_mod1_gate.unsqueeze(1) - encoder_hidden_states = encoder_hidden_states + self.txt_mlp( - modulate(self.txt_norm2(encoder_hidden_states), shift=txt_mod2_shift, scale=txt_mod2_scale) - ) * txt_mod2_gate.unsqueeze(1) + encoder_hidden_states = encoder_hidden_states + self.txt_attn_proj(txt_attn) * c_gate_msa.unsqueeze(1) + encoder_hidden_states = encoder_hidden_states + self.ff_context( + modulate(self.norm2_context(encoder_hidden_states), shift=c_shift_mlp, scale=c_scale_mlp) + ) * c_gate_mlp.unsqueeze(1) return hidden_states, encoder_hidden_states -class HunyuanVideoSingleStreamBlock(nn.Module): - """ - A DiT block with parallel linear layers as described in https://arxiv.org/abs/2302.05442 and adapted modulation - interface. Also refer to (SD3): https://arxiv.org/abs/2403.03206 - (Flux.1): https://github.com/black-forest-labs/flux - """ - - def __init__( - self, - hidden_size: int, - heads_num: int, - mlp_width_ratio: float = 4.0, - mlp_act_type: str = "gelu_tanh", - qk_norm: bool = True, - qk_norm_type: str = "rms", - qk_scale: float = None, - ): - super().__init__() - - self.deterministic = False - self.hidden_size = hidden_size - self.heads_num = heads_num - head_dim = hidden_size // heads_num - mlp_hidden_dim = int(hidden_size * mlp_width_ratio) - self.mlp_hidden_dim = mlp_hidden_dim - self.scale = qk_scale or head_dim**-0.5 - - # qkv and mlp_in - self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + mlp_hidden_dim) - # proj and mlp_out - self.linear2 = nn.Linear(hidden_size + mlp_hidden_dim, hidden_size) - - qk_norm_layer = get_norm_layer(qk_norm_type) - self.q_norm = qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() - self.k_norm = qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() - - self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) - - self.mlp_act = get_activation_layer(mlp_act_type)() - self.modulation = ModulateDiT( - hidden_size, - factor=3, - act_layer=get_activation_layer("silu"), - ) - - def forward( - self, - x: torch.Tensor, - vec: torch.Tensor, - txt_len: int, - freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None, - ) -> torch.Tensor: - mod_shift, mod_scale, mod_gate = self.modulation(vec).chunk(3, dim=-1) - x_mod = modulate(self.pre_norm(x), shift=mod_shift, scale=mod_scale) - qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) - - q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num) - - # Apply QK-Norm if needed. - q = self.q_norm(q).to(v) - k = self.k_norm(k).to(v) - - # Apply RoPE if needed. - if freqs_cis is not None: - img_q, txt_q = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :] - img_k, txt_k = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :] - img_qq, img_kk = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False) - assert ( - img_qq.shape == img_q.shape and img_kk.shape == img_k.shape - ), f"img_kk: {img_qq.shape}, img_q: {img_q.shape}, img_kk: {img_kk.shape}, img_k: {img_k.shape}" - img_q, img_k = img_qq, img_kk - q = torch.cat((img_q, txt_q), dim=1) - k = torch.cat((img_k, txt_k), dim=1) - - attn = attention(q, k, v) - - # Compute activation in mlp stream, cat again and run second linear layer. - output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) - output = x + output * mod_gate.unsqueeze(1) - return output - - class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin): """ HunyuanVideo Transformer backbone @@ -894,7 +885,7 @@ def __init__( self.transformer_blocks = nn.ModuleList( [ - HunyuanVideoDoubleStreamBlock( + HunyuanVideoTransformerBlock( inner_dim, num_attention_heads, mlp_width_ratio=mlp_width_ratio, @@ -908,7 +899,7 @@ def __init__( self.single_transformer_blocks = nn.ModuleList( [ - HunyuanVideoSingleStreamBlock( + HunyuanVideoSingleTransformerBlock( inner_dim, num_attention_heads, mlp_width_ratio=mlp_width_ratio, From 43f62951c775ddf83f2294cf760741c5551ddd34 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 11 Dec 2024 15:17:25 +0100 Subject: [PATCH 16/58] refactor modulation --- scripts/convert_hunyuan_video_to_diffusers.py | 8 +- .../transformers/transformer_hunyuan_video.py | 167 +++++++----------- 2 files changed, 74 insertions(+), 101 deletions(-) diff --git a/scripts/convert_hunyuan_video_to_diffusers.py b/scripts/convert_hunyuan_video_to_diffusers.py index 54925c3f7d33..e73e29703f8f 100644 --- a/scripts/convert_hunyuan_video_to_diffusers.py +++ b/scripts/convert_hunyuan_video_to_diffusers.py @@ -23,12 +23,16 @@ def remap_norm_scale_shift_(key, state_dict): # "vector_in.out_layer": "time_text_embed.text_embedder.linear_2", "double_blocks": "transformer_blocks", "single_blocks": "single_transformer_blocks", - "img_norm1": "norm1", + "img_mod.linear": "norm1.linear", + "img_norm1": "norm1.norm", "img_norm2": "norm2", "img_mlp": "ff", - "txt_norm1": "norm1_context", + "txt_mod.linear": "norm1_context.linear", + "txt_norm1": "norm1.norm", "txt_norm2": "norm2_context", "txt_mlp": "ff_context", + "modulation.linear": "norm.linear", + "pre_norm": "norm.norm", "final_layer.norm_final": "norm_out.norm", "final_layer.linear": "proj_out", "fc1": "net.0.proj", diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py index d0c4d8640c11..8dae06d65cf4 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py @@ -24,7 +24,7 @@ from ..attention import FeedForward from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin -from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero +from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle def attention(q, k, v, attn_mask=None): @@ -267,23 +267,6 @@ def modulate(x, shift=None, scale=None): return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) -class ModulateDiT(nn.Module): - """Modulation layer for DiT.""" - - def __init__( - self, - hidden_size: int, - factor: int, - act_layer: Callable, - ): - super().__init__() - self.act = act_layer() - self.linear = nn.Linear(hidden_size, factor * hidden_size, bias=True) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.linear(self.act(x)) - - class MLPEmbedder(nn.Module): """copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py""" @@ -595,12 +578,6 @@ def forward( class HunyuanVideoSingleTransformerBlock(nn.Module): - """ - A DiT block with parallel linear layers as described in https://arxiv.org/abs/2302.05442 and adapted modulation - interface. Also refer to (SD3): https://arxiv.org/abs/2403.03206 - (Flux.1): https://github.com/black-forest-labs/flux - """ - def __init__( self, hidden_size: int, @@ -624,28 +601,22 @@ def __init__( self.linear2 = nn.Linear(hidden_size + mlp_hidden_dim, hidden_size) qk_norm_layer = get_norm_layer(qk_norm_type) - self.q_norm = qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() - self.k_norm = qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() - - self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.q_norm = qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6) + self.k_norm = qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6) self.mlp_act = get_activation_layer(mlp_act_type)() - self.modulation = ModulateDiT( - hidden_size, - factor=3, - act_layer=get_activation_layer("silu"), - ) + self.norm = AdaLayerNormZeroSingle(hidden_size, norm_type="layer_norm") def forward( self, - x: torch.Tensor, - vec: torch.Tensor, + hidden_states: torch.Tensor, + temb: torch.Tensor, txt_len: int, - freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> torch.Tensor: - shift_msa, scale_msa, gate_msa = self.modulation(vec).chunk(3, dim=-1) - x_mod = modulate(self.pre_norm(x), shift=shift_msa, scale=scale_msa) - qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) + norm_hidden_states, gate = self.norm(hidden_states, emb=temb) + + qkv, mlp = torch.split(self.linear1(norm_hidden_states), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num) @@ -654,13 +625,10 @@ def forward( k = self.k_norm(k).to(v) # Apply RoPE if needed. - if freqs_cis is not None: + if image_rotary_emb is not None: img_q, txt_q = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :] img_k, txt_k = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :] - img_qq, img_kk = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False) - assert ( - img_qq.shape == img_q.shape and img_kk.shape == img_k.shape - ), f"img_kk: {img_qq.shape}, img_q: {img_q.shape}, img_kk: {img_kk.shape}, img_k: {img_k.shape}" + img_qq, img_kk = apply_rotary_emb(img_q, img_k, image_rotary_emb, head_first=False) img_q, img_k = img_qq, img_kk q = torch.cat((img_q, txt_q), dim=1) k = torch.cat((img_k, txt_k), dim=1) @@ -669,7 +637,7 @@ def forward( # Compute activation in mlp stream, cat again and run second linear layer. output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) - output = x + output * gate_msa.unsqueeze(1) + output = hidden_states + output * gate.unsqueeze(1) return output @@ -685,12 +653,11 @@ def __init__( ): super().__init__() - self.deterministic = False self.heads_num = heads_num head_dim = hidden_size // heads_num - self.img_mod = ModulateDiT(hidden_size, factor=6, act_layer=get_activation_layer("silu")) - self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.norm1 = AdaLayerNormZero(hidden_size, norm_type="layer_norm") + self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm") self.img_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias) qk_norm_layer = get_norm_layer(qk_norm_type) @@ -698,13 +665,6 @@ def __init__( self.img_attn_k_norm = qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6) self.img_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias) - self.txt_mod = ModulateDiT( - hidden_size, - factor=6, - act_layer=get_activation_layer("silu"), - ) - self.norm1_context = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) - self.txt_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias) self.txt_attn_q_norm = qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6) self.txt_attn_k_norm = qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6) @@ -723,27 +683,33 @@ def forward( temb: torch.Tensor, freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: - ( - shift_msa, - scale_msa, - gate_msa, - shift_mlp, - scale_mlp, - gate_mlp, - ) = self.img_mod(temb).chunk(6, dim=-1) - ( - c_shift_msa, - c_scale_msa, - c_gate_msa, - c_shift_mlp, - c_scale_mlp, - c_gate_mlp, - ) = self.txt_mod(temb).chunk(6, dim=-1) - - # Prepare image for attention. - img_modulated = self.norm1(hidden_states) - img_modulated = modulate(img_modulated, shift=shift_msa, scale=scale_msa) - img_qkv = self.img_attn_qkv(img_modulated) + # ( + # shift_msa, + # scale_msa, + # gate_msa, + # shift_mlp, + # scale_mlp, + # gate_mlp, + # ) = self.img_mod(temb).chunk(6, dim=-1) + # ( + # c_shift_msa, + # c_scale_msa, + # c_gate_msa, + # c_shift_mlp, + # c_scale_mlp, + # c_gate_mlp, + # ) = self.txt_mod(temb).chunk(6, dim=-1) + + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) + norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( + encoder_hidden_states, emb=temb + ) + + # # Prepare image for attention. + # img_modulated = self.norm1(hidden_states) + # img_modulated = modulate(img_modulated, shift=shift_msa, scale=scale_msa) + + img_qkv = self.img_attn_qkv(norm_hidden_states) img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num) # Apply QK-Norm if needed img_q = self.img_attn_q_norm(img_q).to(img_v) @@ -758,9 +724,10 @@ def forward( img_q, img_k = img_qq, img_kk # Prepare txt for attention. - txt_modulated = self.norm1_context(encoder_hidden_states) - txt_modulated = modulate(txt_modulated, shift=c_shift_msa, scale=c_scale_msa) - txt_qkv = self.txt_attn_qkv(txt_modulated) + # txt_modulated = self.norm1_context(encoder_hidden_states) + # txt_modulated = modulate(txt_modulated, shift=c_shift_msa, scale=c_scale_msa) + + txt_qkv = self.txt_attn_qkv(norm_encoder_hidden_states) txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num) # Apply QK-Norm if needed. txt_q = self.txt_attn_q_norm(txt_q).to(txt_v) @@ -774,17 +741,19 @@ def forward( img_attn, txt_attn = attn[:, : hidden_states.shape[1]], attn[:, hidden_states.shape[1] :] - # Calculate the img bloks. hidden_states = hidden_states + self.img_attn_proj(img_attn) * gate_msa.unsqueeze(1) - hidden_states = hidden_states + self.ff( - modulate(self.norm2(hidden_states), shift=shift_mlp, scale=scale_mlp) - ) * gate_mlp.unsqueeze(1) - - # Calculate the txt bloks. encoder_hidden_states = encoder_hidden_states + self.txt_attn_proj(txt_attn) * c_gate_msa.unsqueeze(1) - encoder_hidden_states = encoder_hidden_states + self.ff_context( - modulate(self.norm2_context(encoder_hidden_states), shift=c_shift_mlp, scale=c_scale_mlp) - ) * c_gate_mlp.unsqueeze(1) + + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + ff_output = self.ff(norm_hidden_states) + ff_output = gate_mlp.unsqueeze(1) * ff_output + hidden_states = hidden_states + ff_output + + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + context_ff_output = self.ff_context(norm_encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output return hidden_states, encoder_hidden_states @@ -964,17 +933,17 @@ def forward( hidden_states, encoder_hidden_states = block(*double_block_args) - hidden_states = torch.cat((hidden_states, encoder_hidden_states), 1) - if len(self.single_transformer_blocks) > 0: - for _, block in enumerate(self.single_transformer_blocks): - single_block_args = [ - hidden_states, - temb, - txt_seq_len, - (freqs_cos, freqs_sin), - ] - - hidden_states = block(*single_block_args) + hidden_states = torch.cat((hidden_states, encoder_hidden_states), dim=1) + + for block in self.single_transformer_blocks: + single_block_args = [ + hidden_states, + temb, + txt_seq_len, + (freqs_cos, freqs_sin), + ] + + hidden_states = block(*single_block_args) hidden_states = hidden_states[:, :img_seq_len, ...] From bb6f023f22576df053fb93bb455b375d55fbe5d5 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 11 Dec 2024 15:18:29 +0100 Subject: [PATCH 17/58] cleanup --- .../transformers/transformer_hunyuan_video.py | 53 ------------------- 1 file changed, 53 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py index 8dae06d65cf4..bc4259215081 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py @@ -246,27 +246,6 @@ def get_norm_layer(norm_layer): raise NotImplementedError(f"Norm layer {norm_layer} is not implemented") -def modulate(x, shift=None, scale=None): - """modulate by shift and scale - - Args: - x (torch.Tensor): input tensor. - shift (torch.Tensor, optional): shift tensor. Defaults to None. - scale (torch.Tensor, optional): scale tensor. Defaults to None. - - Returns: - torch.Tensor: the output tensor after modulate. - """ - if scale is None and shift is None: - return x - elif shift is None: - return x * (1 + scale.unsqueeze(1)) - elif scale is None: - return x + shift.unsqueeze(1) - else: - return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) - - class MLPEmbedder(nn.Module): """copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py""" @@ -595,9 +574,7 @@ def __init__( mlp_hidden_dim = int(hidden_size * mlp_width_ratio) self.mlp_hidden_dim = mlp_hidden_dim - # qkv and mlp_in self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + mlp_hidden_dim) - # proj and mlp_out self.linear2 = nn.Linear(hidden_size + mlp_hidden_dim, hidden_size) qk_norm_layer = get_norm_layer(qk_norm_type) @@ -617,14 +594,12 @@ def forward( norm_hidden_states, gate = self.norm(hidden_states, emb=temb) qkv, mlp = torch.split(self.linear1(norm_hidden_states), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) - q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num) # Apply QK-Norm if needed. q = self.q_norm(q).to(v) k = self.k_norm(k).to(v) - # Apply RoPE if needed. if image_rotary_emb is not None: img_q, txt_q = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :] img_k, txt_k = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :] @@ -635,7 +610,6 @@ def forward( attn = attention(q, k, v) - # Compute activation in mlp stream, cat again and run second linear layer. output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) output = hidden_states + output * gate.unsqueeze(1) return output @@ -683,31 +657,10 @@ def forward( temb: torch.Tensor, freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: - # ( - # shift_msa, - # scale_msa, - # gate_msa, - # shift_mlp, - # scale_mlp, - # gate_mlp, - # ) = self.img_mod(temb).chunk(6, dim=-1) - # ( - # c_shift_msa, - # c_scale_msa, - # c_gate_msa, - # c_shift_mlp, - # c_scale_mlp, - # c_gate_mlp, - # ) = self.txt_mod(temb).chunk(6, dim=-1) - norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( encoder_hidden_states, emb=temb ) - - # # Prepare image for attention. - # img_modulated = self.norm1(hidden_states) - # img_modulated = modulate(img_modulated, shift=shift_msa, scale=scale_msa) img_qkv = self.img_attn_qkv(norm_hidden_states) img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num) @@ -722,18 +675,12 @@ def forward( img_qq.shape == img_q.shape and img_kk.shape == img_k.shape ), f"img_kk: {img_qq.shape}, img_q: {img_q.shape}, img_kk: {img_kk.shape}, img_k: {img_k.shape}" img_q, img_k = img_qq, img_kk - - # Prepare txt for attention. - # txt_modulated = self.norm1_context(encoder_hidden_states) - # txt_modulated = modulate(txt_modulated, shift=c_shift_msa, scale=c_scale_msa) txt_qkv = self.txt_attn_qkv(norm_encoder_hidden_states) txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num) - # Apply QK-Norm if needed. txt_q = self.txt_attn_q_norm(txt_q).to(txt_v) txt_k = self.txt_attn_k_norm(txt_k).to(txt_v) - # Run actual attention. q = torch.cat((img_q, txt_q), dim=1) k = torch.cat((img_k, txt_k), dim=1) v = torch.cat((img_v, txt_v), dim=1) From d72768469b2449b44876e46068815a8b79aecd91 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 11 Dec 2024 15:31:39 +0100 Subject: [PATCH 18/58] refactor norms --- .../autoencoder_kl_hunyuan_video.py | 15 +- .../transformers/transformer_hunyuan_video.py | 138 ++---------------- 2 files changed, 17 insertions(+), 136 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py index ea3867a1d017..ee1b958f404a 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py @@ -1042,17 +1042,16 @@ def __init__( self, in_channels: int = 3, out_channels: int = 3, - down_block_types: Tuple[str] = ("DownEncoderBlockCausal3D",), - up_block_types: Tuple[str] = ("UpDecoderBlockCausal3D",), - block_out_channels: Tuple[int] = (64,), - layers_per_block: int = 1, + latent_channels: int = 16, + down_block_types: Tuple[str, ...] = ("DownEncoderBlockCausal3D", "DownEncoderBlockCausal3D", "DownEncoderBlockCausal3D", "DownEncoderBlockCausal3D",), + up_block_types: Tuple[str, ...] = ("UpDecoderBlockCausal3D", "UpDecoderBlockCausal3D", "UpDecoderBlockCausal3D", "UpDecoderBlockCausal3D",), + block_out_channels: Tuple[int] = (128, 256, 512, 512), + layers_per_block: int = 2, act_fn: str = "silu", - latent_channels: int = 4, norm_num_groups: int = 32, - sample_size: int = 32, + sample_size: int = 256, sample_tsize: int = 64, - scaling_factor: float = 0.18215, - force_upcast: float = True, + scaling_factor: float = 0.476986, spatial_compression_ratio: int = 8, time_compression_ratio: int = 4, mid_block_add_attention: bool = True, diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py index bc4259215081..8b35a6fea75a 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py @@ -24,7 +24,7 @@ from ..attention import FeedForward from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin -from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle +from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle, RMSNorm def attention(q, k, v, attn_mask=None): @@ -174,78 +174,6 @@ def apply_rotary_emb( return xq_out, xk_out -class RMSNorm(nn.Module): - def __init__( - self, - dim: int, - elementwise_affine=True, - eps: float = 1e-6, - ): - """ - Initialize the RMSNorm normalization layer. - - Args: - dim (int): The dimension of the input tensor. - eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. - - Attributes: - eps (float): A small value added to the denominator for numerical stability. - weight (nn.Parameter): Learnable scaling parameter. - - """ - super().__init__() - self.eps = eps - if elementwise_affine: - self.weight = nn.Parameter(torch.ones(dim)) - - def _norm(self, x): - """ - Apply the RMSNorm normalization to the input tensor. - - Args: - x (torch.Tensor): The input tensor. - - Returns: - torch.Tensor: The normalized tensor. - - """ - return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) - - def forward(self, x): - """ - Forward pass through the RMSNorm layer. - - Args: - x (torch.Tensor): The input tensor. - - Returns: - torch.Tensor: The output tensor after applying RMSNorm. - - """ - output = self._norm(x.float()).type_as(x) - if hasattr(self, "weight"): - output = output * self.weight - return output - - -def get_norm_layer(norm_layer): - """ - Get the normalization layer. - - Args: - norm_layer (str): The type of normalization layer. - - Returns: - norm_layer (nn.Module): The normalization layer. - """ - if norm_layer == "layer": - return nn.LayerNorm - elif norm_layer == "rms": - return RMSNorm - else: - raise NotImplementedError(f"Norm layer {norm_layer} is not implemented") - - class MLPEmbedder(nn.Module): """copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py""" @@ -260,19 +188,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class PatchEmbed(nn.Module): - """2D Image to Patch Embedding - - Image to Patch Embedding using Conv2d - - A convolution based approach to patchifying a 2D image w/ embedding projection. - - Based on the impl in https://github.com/google-research/vision_transformer - - Hacked together by / Copyright 2020 Ross Wightman - - Remove the _assert function in forward function to be compatible with multi-resolution images. - """ - def __init__( self, patch_size=16, @@ -298,12 +213,6 @@ def forward(self, x): class TextProjection(nn.Module): - """ - Projects text embeddings. Also handles dropout for classifier-free guidance. - - Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py - """ - def __init__(self, in_channels, hidden_size, act_layer): super().__init__() self.linear_1 = nn.Linear(in_features=in_channels, out_features=hidden_size, bias=True) @@ -381,8 +290,6 @@ def __init__( mlp_width_ratio: str = 4.0, mlp_drop_rate: float = 0.0, act_type: str = "silu", - qk_norm: bool = False, - qk_norm_type: str = "layer", qkv_bias: bool = True, ): super().__init__() @@ -391,13 +298,6 @@ def __init__( self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6) self.self_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias) - qk_norm_layer = get_norm_layer(qk_norm_type) - self.self_attn_q_norm = ( - qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() - ) - self.self_attn_k_norm = ( - qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() - ) self.self_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias) self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6) @@ -421,9 +321,6 @@ def forward( norm_x = self.norm1(x) qkv = self.self_attn_qkv(norm_x) q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num) - # Apply QK-Norm if needed - q = self.self_attn_q_norm(q).to(v) - k = self.self_attn_k_norm(k).to(v) # Self-Attention attn = attention(q, k, v, attn_mask=attn_mask) @@ -445,8 +342,6 @@ def __init__( mlp_width_ratio: float = 4.0, mlp_drop_rate: float = 0.0, act_type: str = "silu", - qk_norm: bool = False, - qk_norm_type: str = "layer", qkv_bias: bool = True, ): super().__init__() @@ -458,8 +353,6 @@ def __init__( mlp_width_ratio=mlp_width_ratio, mlp_drop_rate=mlp_drop_rate, act_type=act_type, - qk_norm=qk_norm, - qk_norm_type=qk_norm_type, qkv_bias=qkv_bias, ) for _ in range(depth) @@ -505,8 +398,6 @@ def __init__( mlp_width_ratio: float = 4.0, mlp_drop_rate: float = 0.0, act_type: str = "silu", - qk_norm: bool = False, - qk_norm_type: str = "layer", qkv_bias: bool = True, ): super().__init__() @@ -526,8 +417,6 @@ def __init__( mlp_width_ratio=mlp_width_ratio, mlp_drop_rate=mlp_drop_rate, act_type=act_type, - qk_norm=qk_norm, - qk_norm_type=qk_norm_type, qkv_bias=qkv_bias, ) @@ -563,8 +452,7 @@ def __init__( heads_num: int, mlp_width_ratio: float = 4.0, mlp_act_type: str = "gelu_tanh", - qk_norm: bool = True, - qk_norm_type: str = "rms", + qk_norm: str = "rms_norm", ): super().__init__() @@ -577,9 +465,8 @@ def __init__( self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + mlp_hidden_dim) self.linear2 = nn.Linear(hidden_size + mlp_hidden_dim, hidden_size) - qk_norm_layer = get_norm_layer(qk_norm_type) - self.q_norm = qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6) - self.k_norm = qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6) + self.q_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6) + self.k_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6) self.mlp_act = get_activation_layer(mlp_act_type)() self.norm = AdaLayerNormZeroSingle(hidden_size, norm_type="layer_norm") @@ -621,8 +508,7 @@ def __init__( hidden_size: int, heads_num: int, mlp_width_ratio: float, - qk_norm: bool = True, - qk_norm_type: str = "rms", + qk_norm: str = "rms_norm", qkv_bias: bool = False, ): super().__init__() @@ -634,14 +520,13 @@ def __init__( self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm") self.img_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias) - qk_norm_layer = get_norm_layer(qk_norm_type) - self.img_attn_q_norm = qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6) - self.img_attn_k_norm = qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6) + self.img_attn_q_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6) + self.img_attn_k_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6) self.img_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias) self.txt_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias) - self.txt_attn_q_norm = qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6) - self.txt_attn_k_norm = qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6) + self.txt_attn_q_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6) + self.txt_attn_k_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6) self.txt_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias) self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) @@ -768,8 +653,7 @@ def __init__( mm_single_blocks_depth: int = 40, rope_dim_list: List[int] = [16, 56, 56], qkv_bias: bool = True, - qk_norm: bool = True, - qk_norm_type: str = "rms", + qk_norm: str = "rms_norm", guidance_embed: bool = True, text_states_dim: int = 4096, text_states_dim_2: int = 768, @@ -806,7 +690,6 @@ def __init__( num_attention_heads, mlp_width_ratio=mlp_width_ratio, qk_norm=qk_norm, - qk_norm_type=qk_norm_type, qkv_bias=qkv_bias, ) for _ in range(mm_double_blocks_depth) @@ -821,7 +704,6 @@ def __init__( mlp_width_ratio=mlp_width_ratio, mlp_act_type=mlp_act_type, qk_norm=qk_norm, - qk_norm_type=qk_norm_type, ) for _ in range(mm_single_blocks_depth) ] From a247ca64352a508f0a1e17cb41c488590993d2ab Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 11 Dec 2024 15:37:15 +0100 Subject: [PATCH 19/58] refactor activations --- .../transformers/transformer_hunyuan_video.py | 47 +++---------------- 1 file changed, 7 insertions(+), 40 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py index 8b35a6fea75a..a331a63b28a5 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py @@ -42,28 +42,6 @@ def attention(q, k, v, attn_mask=None): return out -def get_activation_layer(act_type): - """get activation layer - - Args: - act_type (str): the activation type - - Returns: - torch.nn.functional: the activation layer - """ - if act_type == "gelu": - return lambda: nn.GELU() - elif act_type == "gelu_tanh": - # Approximate `tanh` requires torch >= 1.13 - return lambda: nn.GELU(approximate="tanh") - elif act_type == "relu": - return nn.ReLU - elif act_type == "silu": - return nn.SiLU - else: - raise ValueError(f"Unknown activation type: {act_type}") - - def reshape_for_broadcast( freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], x: torch.Tensor, @@ -289,24 +267,21 @@ def __init__( heads_num, mlp_width_ratio: str = 4.0, mlp_drop_rate: float = 0.0, - act_type: str = "silu", qkv_bias: bool = True, ): super().__init__() self.heads_num = heads_num - head_dim = hidden_size // heads_num self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6) self.self_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias) self.self_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias) self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6) - act_layer = get_activation_layer(act_type) self.mlp = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="silu", dropout=mlp_drop_rate) self.adaLN_modulation = nn.Sequential( - act_layer(), + nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True), ) @@ -341,7 +316,6 @@ def __init__( depth, mlp_width_ratio: float = 4.0, mlp_drop_rate: float = 0.0, - act_type: str = "silu", qkv_bias: bool = True, ): super().__init__() @@ -352,7 +326,6 @@ def __init__( heads_num=heads_num, mlp_width_ratio=mlp_width_ratio, mlp_drop_rate=mlp_drop_rate, - act_type=act_type, qkv_bias=qkv_bias, ) for _ in range(depth) @@ -397,18 +370,16 @@ def __init__( depth, mlp_width_ratio: float = 4.0, mlp_drop_rate: float = 0.0, - act_type: str = "silu", qkv_bias: bool = True, ): super().__init__() self.input_embedder = nn.Linear(in_channels, hidden_size, bias=True) - act_layer = get_activation_layer(act_type) # Build timestep embedding layer - self.t_embedder = TimestepEmbedder(hidden_size, act_layer) + self.t_embedder = TimestepEmbedder(hidden_size, nn.SiLU) # Build context embedding layer - self.c_embedder = TextProjection(in_channels, hidden_size, act_layer) + self.c_embedder = TextProjection(in_channels, hidden_size, nn.SiLU) self.individual_token_refiner = IndividualTokenRefiner( hidden_size=hidden_size, @@ -416,7 +387,6 @@ def __init__( depth=depth, mlp_width_ratio=mlp_width_ratio, mlp_drop_rate=mlp_drop_rate, - act_type=act_type, qkv_bias=qkv_bias, ) @@ -451,7 +421,6 @@ def __init__( hidden_size: int, heads_num: int, mlp_width_ratio: float = 4.0, - mlp_act_type: str = "gelu_tanh", qk_norm: str = "rms_norm", ): super().__init__() @@ -468,7 +437,7 @@ def __init__( self.q_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6) self.k_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6) - self.mlp_act = get_activation_layer(mlp_act_type)() + self.act_mlp = nn.GELU(approximate="tanh") self.norm = AdaLayerNormZeroSingle(hidden_size, norm_type="layer_norm") def forward( @@ -497,7 +466,7 @@ def forward( attn = attention(q, k, v) - output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) + output = self.linear2(torch.cat((attn, self.act_mlp(mlp)), 2)) output = hidden_states + output * gate.unsqueeze(1) return output @@ -648,7 +617,6 @@ def __init__( num_attention_heads: int = 24, attention_head_dim: int = 128, mlp_width_ratio: float = 4.0, - mlp_act_type: str = "gelu_tanh", mm_double_blocks_depth: int = 20, mm_single_blocks_depth: int = 40, rope_dim_list: List[int] = [16, 56, 56], @@ -675,13 +643,13 @@ def __init__( self.txt_in = SingleTokenRefiner(text_states_dim, inner_dim, num_attention_heads, depth=2) # time modulation - self.time_in = TimestepEmbedder(inner_dim, get_activation_layer("silu")) + self.time_in = TimestepEmbedder(inner_dim, nn.SiLU) # text modulation self.vector_in = MLPEmbedder(text_states_dim_2, inner_dim) # guidance modulation - self.guidance_in = TimestepEmbedder(inner_dim, get_activation_layer("silu")) if guidance_embed else None + self.guidance_in = TimestepEmbedder(inner_dim, nn.SiLU) self.transformer_blocks = nn.ModuleList( [ @@ -702,7 +670,6 @@ def __init__( inner_dim, num_attention_heads, mlp_width_ratio=mlp_width_ratio, - mlp_act_type=mlp_act_type, qk_norm=qk_norm, ) for _ in range(mm_single_blocks_depth) From 7ba46091a67263db019a9960eb86d200407c6c22 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 11 Dec 2024 16:50:24 +0100 Subject: [PATCH 20/58] refactor single blocks attention --- scripts/convert_hunyuan_video_to_diffusers.py | 33 ++- src/diffusers/models/attention_processor.py | 8 + .../transformers/transformer_hunyuan_video.py | 208 ++++++++++++++---- 3 files changed, 205 insertions(+), 44 deletions(-) diff --git a/scripts/convert_hunyuan_video_to_diffusers.py b/scripts/convert_hunyuan_video_to_diffusers.py index e73e29703f8f..c66cb50cde2a 100644 --- a/scripts/convert_hunyuan_video_to_diffusers.py +++ b/scripts/convert_hunyuan_video_to_diffusers.py @@ -14,6 +14,37 @@ def remap_norm_scale_shift_(key, state_dict): state_dict[key.replace("final_layer.adaLN_modulation.1", "norm_out.linear")] = new_weight +def remap_single_transformer_blocks_(key, state_dict): + hidden_size = 3072 + + if "linear1.weight" in key: + linear1_weight = state_dict.pop(key) + split_size = (hidden_size, hidden_size, hidden_size, linear1_weight.size(0) - 3 * hidden_size) + q, k, v, mlp = torch.split(linear1_weight, split_size, dim=0) + new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.weight") + state_dict[f"{new_key}.attn.to_q.weight"] = q + state_dict[f"{new_key}.attn.to_k.weight"] = k + state_dict[f"{new_key}.attn.to_v.weight"] = v + state_dict[f"{new_key}.proj_mlp.weight"] = mlp + + elif "linear1.bias" in key: + linear1_bias = state_dict.pop(key) + split_size = (hidden_size, hidden_size, hidden_size, linear1_bias.size(0) - 3 * hidden_size) + q_bias, k_bias, v_bias, mlp_bias = torch.split(linear1_bias, split_size, dim=0) + new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.bias") + state_dict[f"{new_key}.attn.to_q.bias"] = q_bias + state_dict[f"{new_key}.attn.to_k.bias"] = k_bias + state_dict[f"{new_key}.attn.to_v.bias"] = v_bias + state_dict[f"{new_key}.proj_mlp.bias"] = mlp_bias + + else: + new_key = key.replace("single_blocks", "single_transformer_blocks") + new_key = new_key.replace("linear2", "proj_out") + new_key = new_key.replace("q_norm", "attn.norm_q") + new_key = new_key.replace("k_norm", "attn.norm_k") + state_dict[new_key] = state_dict.pop(key) + + TRANSFORMER_KEYS_RENAME_DICT = { # "time_in.mlp.0": "time_text_embed.timestep_embedder.linear_1", # "time_in.mlp.2": "time_text_embed.timestep_embedder.linear_2", @@ -22,7 +53,6 @@ def remap_norm_scale_shift_(key, state_dict): # "vector_in.in_layer": "time_text_embed.text_embedder.linear_1", # "vector_in.out_layer": "time_text_embed.text_embedder.linear_2", "double_blocks": "transformer_blocks", - "single_blocks": "single_transformer_blocks", "img_mod.linear": "norm1.linear", "img_norm1": "norm1.norm", "img_norm2": "norm2", @@ -41,6 +71,7 @@ def remap_norm_scale_shift_(key, state_dict): TRANSFORMER_SPECIAL_KEYS_REMAP = { "final_layer.adaLN_modulation.1": remap_norm_scale_shift_, + "single_blocks": remap_single_transformer_blocks_, } VAE_KEYS_RENAME_DICT = {} diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index faacc431c386..46d884616602 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -250,14 +250,22 @@ def __init__( self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias) if self.context_pre_only is not None: self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + else: + self.add_q_proj = None + self.add_k_proj = None + self.add_v_proj = None if not self.pre_only: self.to_out = nn.ModuleList([]) self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) self.to_out.append(nn.Dropout(dropout)) + else: + self.to_out = None if self.context_pre_only is not None and not self.context_pre_only: self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias) + else: + self.to_add_out = None if qk_norm is not None and added_kv_proj_dim is not None: if qk_norm == "fp32_layer_norm": diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py index a331a63b28a5..beb22244acea 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py @@ -13,7 +13,7 @@ # limitations under the License. import math -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -22,6 +22,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ..attention import FeedForward +from ..attention_processor import Attention from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle, RMSNorm @@ -152,6 +153,102 @@ def apply_rotary_emb( return xq_out, xk_out +class HunyuanVideoAttnProcessor2_0: + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("HunyuanVideoAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + batch_size, _, _ = hidden_states.shape + + if attn.add_q_proj is None and encoder_hidden_states is not None: + hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1) + + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + if image_rotary_emb is not None: + from ..embeddings import apply_rotary_emb + + if attn.add_q_proj is None and encoder_hidden_states is not None: + query = torch.cat([ + apply_rotary_emb(query[:, :, :-encoder_hidden_states.shape[1]], image_rotary_emb), + query[:, :, -encoder_hidden_states.shape[1]:], + ], dim=2) + key = torch.cat([ + apply_rotary_emb(key[:, :, :-encoder_hidden_states.shape[1]], image_rotary_emb), + key[:, :, -encoder_hidden_states.shape[1]:], + ], dim=2) + else: + query = apply_rotary_emb(query, image_rotary_emb) + key = apply_rotary_emb(key, image_rotary_emb) + + if attn.add_q_proj is not None and encoder_hidden_states is not None: + encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + + if attn.norm_added_q is not None: + encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) + if attn.norm_added_k is not None: + encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) + + query = torch.cat([query, encoder_hidden_states_query_proj], dim=2) + key = torch.cat([key, encoder_hidden_states_key_proj], dim=2) + value = torch.cat([value, encoder_hidden_states_value_proj], dim=2) + + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + hidden_states, encoder_hidden_states = ( + hidden_states[:, :-encoder_hidden_states.shape[1]], + hidden_states[:, -encoder_hidden_states.shape[1]:], + ) + + if not attn.pre_only: + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + + if attn.context_pre_only is not None and not attn.context_pre_only: + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + + class MLPEmbedder(nn.Module): """copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py""" @@ -418,57 +515,88 @@ def forward( class HunyuanVideoSingleTransformerBlock(nn.Module): def __init__( self, - hidden_size: int, - heads_num: int, + num_attention_heads: int, + attention_head_dim: int, mlp_width_ratio: float = 4.0, qk_norm: str = "rms_norm", ): super().__init__() - self.hidden_size = hidden_size - self.heads_num = heads_num - head_dim = hidden_size // heads_num + hidden_size = num_attention_heads * attention_head_dim mlp_hidden_dim = int(hidden_size * mlp_width_ratio) + + self.hidden_size = hidden_size + self.heads_num = num_attention_heads self.mlp_hidden_dim = mlp_hidden_dim - self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + mlp_hidden_dim) - self.linear2 = nn.Linear(hidden_size + mlp_hidden_dim, hidden_size) - - self.q_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6) - self.k_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6) + self.attn = Attention( + query_dim=hidden_size, + cross_attention_dim=None, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=hidden_size, + bias=True, + processor=HunyuanVideoAttnProcessor2_0(), + qk_norm="rms_norm", + eps=1e-6, + pre_only=True, + ) - self.act_mlp = nn.GELU(approximate="tanh") self.norm = AdaLayerNormZeroSingle(hidden_size, norm_type="layer_norm") + self.proj_mlp = nn.Linear(hidden_size, self.mlp_hidden_dim) + self.act_mlp = nn.GELU(approximate="tanh") + self.proj_out = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size) def forward( self, hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, temb: torch.Tensor, txt_len: int, image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> torch.Tensor: + text_seq_length = encoder_hidden_states.shape[1] + hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1) + + residual = hidden_states + norm_hidden_states, gate = self.norm(hidden_states, emb=temb) + mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) - qkv, mlp = torch.split(self.linear1(norm_hidden_states), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) - q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num) - - # Apply QK-Norm if needed. - q = self.q_norm(q).to(v) - k = self.k_norm(k).to(v) - - if image_rotary_emb is not None: - img_q, txt_q = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :] - img_k, txt_k = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :] - img_qq, img_kk = apply_rotary_emb(img_q, img_k, image_rotary_emb, head_first=False) - img_q, img_k = img_qq, img_kk - q = torch.cat((img_q, txt_q), dim=1) - k = torch.cat((img_k, txt_k), dim=1) + norm_hidden_states, norm_encoder_hidden_states = norm_hidden_states[:, :-text_seq_length, :], norm_hidden_states[:, -text_seq_length:, :] + + # qkv, mlp = torch.split(self.linear1(norm_hidden_states), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) + # q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num) + + # # Apply QK-Norm if needed. + # q = self.q_norm(q).to(v) + # k = self.k_norm(k).to(v) + + # if image_rotary_emb is not None: + # img_q, txt_q = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :] + # img_k, txt_k = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :] + # img_qq, img_kk = apply_rotary_emb(img_q, img_k, image_rotary_emb, head_first=False) + # img_q, img_k = img_qq, img_kk + # q = torch.cat((img_q, txt_q), dim=1) + # k = torch.cat((img_k, txt_k), dim=1) + + # attn = attention(q, k, v) + # output = self.linear2(torch.cat((attn, self.act_mlp(mlp)), 2)) + # output = hidden_states + output * gate.unsqueeze(1) + + attn_output, context_attn_output = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + ) + attn_output = torch.cat([attn_output, context_attn_output], dim=1) - attn = attention(q, k, v) + hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2) + hidden_states = gate.unsqueeze(1) * self.proj_out(hidden_states) + hidden_states = hidden_states + residual - output = self.linear2(torch.cat((attn, self.act_mlp(mlp)), 2)) - output = hidden_states + output * gate.unsqueeze(1) - return output + hidden_states, encoder_hidden_states = hidden_states[:, :-text_seq_length, :], hidden_states[:, -text_seq_length:, :] + return hidden_states, encoder_hidden_states class HunyuanVideoTransformerBlock(nn.Module): @@ -478,7 +606,6 @@ def __init__( heads_num: int, mlp_width_ratio: float, qk_norm: str = "rms_norm", - qkv_bias: bool = False, ): super().__init__() @@ -488,15 +615,15 @@ def __init__( self.norm1 = AdaLayerNormZero(hidden_size, norm_type="layer_norm") self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm") - self.img_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias) + self.img_attn_qkv = nn.Linear(hidden_size, hidden_size * 3) self.img_attn_q_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6) self.img_attn_k_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6) - self.img_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias) + self.img_attn_proj = nn.Linear(hidden_size, hidden_size) - self.txt_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias) + self.txt_attn_qkv = nn.Linear(hidden_size, hidden_size * 3) self.txt_attn_q_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6) self.txt_attn_k_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6) - self.txt_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias) + self.txt_attn_proj = nn.Linear(hidden_size, hidden_size) self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.ff = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="gelu-approximate") @@ -620,7 +747,6 @@ def __init__( mm_double_blocks_depth: int = 20, mm_single_blocks_depth: int = 40, rope_dim_list: List[int] = [16, 56, 56], - qkv_bias: bool = True, qk_norm: str = "rms_norm", guidance_embed: bool = True, text_states_dim: int = 4096, @@ -658,7 +784,6 @@ def __init__( num_attention_heads, mlp_width_ratio=mlp_width_ratio, qk_norm=qk_norm, - qkv_bias=qkv_bias, ) for _ in range(mm_double_blocks_depth) ] @@ -667,8 +792,8 @@ def __init__( self.single_transformer_blocks = nn.ModuleList( [ HunyuanVideoSingleTransformerBlock( - inner_dim, num_attention_heads, + attention_head_dim, mlp_width_ratio=mlp_width_ratio, qk_norm=qk_norm, ) @@ -728,20 +853,17 @@ def forward( ] hidden_states, encoder_hidden_states = block(*double_block_args) - - hidden_states = torch.cat((hidden_states, encoder_hidden_states), dim=1) for block in self.single_transformer_blocks: single_block_args = [ hidden_states, + encoder_hidden_states, temb, txt_seq_len, (freqs_cos, freqs_sin), ] - hidden_states = block(*single_block_args) - - hidden_states = hidden_states[:, :img_seq_len, ...] + hidden_states, encoder_hidden_states = block(*single_block_args) hidden_states = self.norm_out(hidden_states, temb) hidden_states = self.proj_out(hidden_states) From cb4fc37602ee52223d4141885a9ba6b7a2ae6671 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 11 Dec 2024 16:55:55 +0100 Subject: [PATCH 21/58] refactor attention processor --- .../transformers/transformer_hunyuan_video.py | 43 +++++++------------ 1 file changed, 16 insertions(+), 27 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py index beb22244acea..77de3c01b8e4 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py @@ -166,8 +166,6 @@ def __call__( attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, ) -> torch.Tensor: - batch_size, _, _ = hidden_states.shape - if attn.add_q_proj is None and encoder_hidden_states is not None: hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1) @@ -175,12 +173,9 @@ def __call__( key = attn.to_k(hidden_states) value = attn.to_v(hidden_states) - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2) + key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2) + value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2) if attn.norm_q is not None: query = attn.norm_q(query) @@ -204,33 +199,27 @@ def __call__( key = apply_rotary_emb(key, image_rotary_emb) if attn.add_q_proj is not None and encoder_hidden_states is not None: - encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) - encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) - encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) - - encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) + encoder_query = attn.add_q_proj(encoder_hidden_states) + encoder_key = attn.add_k_proj(encoder_hidden_states) + encoder_value = attn.add_v_proj(encoder_hidden_states) + + encoder_query = encoder_query.unflatten(2, (attn.heads, -1)).transpose(1, 2) + encoder_key = encoder_key.unflatten(2, (attn.heads, -1)).transpose(1, 2) + encoder_value = encoder_value.unflatten(2, (attn.heads, -1)).transpose(1, 2) if attn.norm_added_q is not None: - encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) + encoder_query = attn.norm_added_q(encoder_query) if attn.norm_added_k is not None: - encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) + encoder_key = attn.norm_added_k(encoder_key) - query = torch.cat([query, encoder_hidden_states_query_proj], dim=2) - key = torch.cat([key, encoder_hidden_states_key_proj], dim=2) - value = torch.cat([value, encoder_hidden_states_value_proj], dim=2) + query = torch.cat([query, encoder_query], dim=2) + key = torch.cat([key, encoder_key], dim=2) + value = torch.cat([value, encoder_value], dim=2) hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) hidden_states = hidden_states.to(query.dtype) if encoder_hidden_states is not None: From 1e80f7c2e4217269cdc0e91c04ce3d1993a4a002 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 11 Dec 2024 17:05:23 +0100 Subject: [PATCH 22/58] make style --- scripts/convert_hunyuan_video_to_diffusers.py | 4 +- .../autoencoder_kl_hunyuan_video.py | 14 ++++- .../transformers/transformer_hunyuan_video.py | 59 +++++++++++-------- 3 files changed, 50 insertions(+), 27 deletions(-) diff --git a/scripts/convert_hunyuan_video_to_diffusers.py b/scripts/convert_hunyuan_video_to_diffusers.py index c66cb50cde2a..3efcd95e9791 100644 --- a/scripts/convert_hunyuan_video_to_diffusers.py +++ b/scripts/convert_hunyuan_video_to_diffusers.py @@ -26,7 +26,7 @@ def remap_single_transformer_blocks_(key, state_dict): state_dict[f"{new_key}.attn.to_k.weight"] = k state_dict[f"{new_key}.attn.to_v.weight"] = v state_dict[f"{new_key}.proj_mlp.weight"] = mlp - + elif "linear1.bias" in key: linear1_bias = state_dict.pop(key) split_size = (hidden_size, hidden_size, hidden_size, linear1_bias.size(0) - 3 * hidden_size) @@ -36,7 +36,7 @@ def remap_single_transformer_blocks_(key, state_dict): state_dict[f"{new_key}.attn.to_k.bias"] = k_bias state_dict[f"{new_key}.attn.to_v.bias"] = v_bias state_dict[f"{new_key}.proj_mlp.bias"] = mlp_bias - + else: new_key = key.replace("single_blocks", "single_transformer_blocks") new_key = new_key.replace("linear2", "proj_out") diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py index ee1b958f404a..dd4502115896 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py @@ -1043,8 +1043,18 @@ def __init__( in_channels: int = 3, out_channels: int = 3, latent_channels: int = 16, - down_block_types: Tuple[str, ...] = ("DownEncoderBlockCausal3D", "DownEncoderBlockCausal3D", "DownEncoderBlockCausal3D", "DownEncoderBlockCausal3D",), - up_block_types: Tuple[str, ...] = ("UpDecoderBlockCausal3D", "UpDecoderBlockCausal3D", "UpDecoderBlockCausal3D", "UpDecoderBlockCausal3D",), + down_block_types: Tuple[str, ...] = ( + "DownEncoderBlockCausal3D", + "DownEncoderBlockCausal3D", + "DownEncoderBlockCausal3D", + "DownEncoderBlockCausal3D", + ), + up_block_types: Tuple[str, ...] = ( + "UpDecoderBlockCausal3D", + "UpDecoderBlockCausal3D", + "UpDecoderBlockCausal3D", + "UpDecoderBlockCausal3D", + ), block_out_channels: Tuple[int] = (128, 256, 512, 512), layers_per_block: int = 2, act_fn: str = "silu", diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py index 77de3c01b8e4..3ffc428856c6 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py @@ -156,7 +156,9 @@ def apply_rotary_emb( class HunyuanVideoAttnProcessor2_0: def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError("HunyuanVideoAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.") + raise ImportError( + "HunyuanVideoAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0." + ) def __call__( self, @@ -186,18 +188,24 @@ def __call__( from ..embeddings import apply_rotary_emb if attn.add_q_proj is None and encoder_hidden_states is not None: - query = torch.cat([ - apply_rotary_emb(query[:, :, :-encoder_hidden_states.shape[1]], image_rotary_emb), - query[:, :, -encoder_hidden_states.shape[1]:], - ], dim=2) - key = torch.cat([ - apply_rotary_emb(key[:, :, :-encoder_hidden_states.shape[1]], image_rotary_emb), - key[:, :, -encoder_hidden_states.shape[1]:], - ], dim=2) + query = torch.cat( + [ + apply_rotary_emb(query[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb), + query[:, :, -encoder_hidden_states.shape[1] :], + ], + dim=2, + ) + key = torch.cat( + [ + apply_rotary_emb(key[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb), + key[:, :, -encoder_hidden_states.shape[1] :], + ], + dim=2, + ) else: query = apply_rotary_emb(query, image_rotary_emb) key = apply_rotary_emb(key, image_rotary_emb) - + if attn.add_q_proj is not None and encoder_hidden_states is not None: encoder_query = attn.add_q_proj(encoder_hidden_states) encoder_key = attn.add_k_proj(encoder_hidden_states) @@ -224,8 +232,8 @@ def __call__( if encoder_hidden_states is not None: hidden_states, encoder_hidden_states = ( - hidden_states[:, :-encoder_hidden_states.shape[1]], - hidden_states[:, -encoder_hidden_states.shape[1]:], + hidden_states[:, : -encoder_hidden_states.shape[1]], + hidden_states[:, -encoder_hidden_states.shape[1] :], ) if not attn.pre_only: @@ -513,7 +521,7 @@ def __init__( hidden_size = num_attention_heads * attention_head_dim mlp_hidden_dim = int(hidden_size * mlp_width_ratio) - + self.hidden_size = hidden_size self.heads_num = num_attention_heads self.mlp_hidden_dim = mlp_hidden_dim @@ -546,14 +554,17 @@ def forward( ) -> torch.Tensor: text_seq_length = encoder_hidden_states.shape[1] hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1) - + residual = hidden_states - + norm_hidden_states, gate = self.norm(hidden_states, emb=temb) mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) - - norm_hidden_states, norm_encoder_hidden_states = norm_hidden_states[:, :-text_seq_length, :], norm_hidden_states[:, -text_seq_length:, :] - + + norm_hidden_states, norm_encoder_hidden_states = ( + norm_hidden_states[:, :-text_seq_length, :], + norm_hidden_states[:, -text_seq_length:, :], + ) + # qkv, mlp = torch.split(self.linear1(norm_hidden_states), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) # q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num) @@ -584,7 +595,10 @@ def forward( hidden_states = gate.unsqueeze(1) * self.proj_out(hidden_states) hidden_states = hidden_states + residual - hidden_states, encoder_hidden_states = hidden_states[:, :-text_seq_length, :], hidden_states[:, -text_seq_length:, :] + hidden_states, encoder_hidden_states = ( + hidden_states[:, :-text_seq_length, :], + hidden_states[:, -text_seq_length:, :], + ) return hidden_states, encoder_hidden_states @@ -631,7 +645,7 @@ def forward( norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( encoder_hidden_states, emb=temb ) - + img_qkv = self.img_attn_qkv(norm_hidden_states) img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num) # Apply QK-Norm if needed @@ -645,7 +659,7 @@ def forward( img_qq.shape == img_q.shape and img_kk.shape == img_k.shape ), f"img_kk: {img_qq.shape}, img_q: {img_q.shape}, img_kk: {img_kk.shape}, img_k: {img_k.shape}" img_q, img_k = img_qq, img_kk - + txt_qkv = self.txt_attn_qkv(norm_encoder_hidden_states) txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num) txt_q = self.txt_attn_q_norm(txt_q).to(txt_v) @@ -830,7 +844,6 @@ def forward( encoder_hidden_states = self.txt_in(encoder_hidden_states, timestep, encoder_attention_mask) txt_seq_len = encoder_hidden_states.shape[1] - img_seq_len = hidden_states.shape[1] freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None for _, block in enumerate(self.transformer_blocks): @@ -842,7 +855,7 @@ def forward( ] hidden_states, encoder_hidden_states = block(*double_block_args) - + for block in self.single_transformer_blocks: single_block_args = [ hidden_states, From a9bd457864b7cf2bed707c4d7c2cf1e7c8b2f794 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 11 Dec 2024 17:25:36 +0100 Subject: [PATCH 23/58] cleanup a bit --- .../transformers/transformer_hunyuan_video.py | 214 ++++++++---------- 1 file changed, 99 insertions(+), 115 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py index 3ffc428856c6..4ef8144c0772 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py @@ -13,6 +13,7 @@ # limitations under the License. import math +from functools import partial from typing import Dict, List, Optional, Tuple, Union import torch @@ -21,8 +22,9 @@ from einops import rearrange from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import is_torch_version from ..attention import FeedForward -from ..attention_processor import Attention +from ..attention_processor import Attention, AttentionProcessor from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle, RMSNorm @@ -514,13 +516,13 @@ def __init__( self, num_attention_heads: int, attention_head_dim: int, - mlp_width_ratio: float = 4.0, + mlp_ratio: float = 4.0, qk_norm: str = "rms_norm", ): super().__init__() hidden_size = num_attention_heads * attention_head_dim - mlp_hidden_dim = int(hidden_size * mlp_width_ratio) + mlp_hidden_dim = int(hidden_size * mlp_ratio) self.hidden_size = hidden_size self.heads_num = num_attention_heads @@ -549,7 +551,6 @@ def forward( hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor, - txt_len: int, image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> torch.Tensor: text_seq_length = encoder_hidden_states.shape[1] @@ -565,25 +566,6 @@ def forward( norm_hidden_states[:, -text_seq_length:, :], ) - # qkv, mlp = torch.split(self.linear1(norm_hidden_states), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) - # q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num) - - # # Apply QK-Norm if needed. - # q = self.q_norm(q).to(v) - # k = self.k_norm(k).to(v) - - # if image_rotary_emb is not None: - # img_q, txt_q = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :] - # img_k, txt_k = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :] - # img_qq, img_kk = apply_rotary_emb(img_q, img_k, image_rotary_emb, head_first=False) - # img_q, img_k = img_qq, img_kk - # q = torch.cat((img_q, txt_q), dim=1) - # k = torch.cat((img_k, txt_k), dim=1) - - # attn = attention(q, k, v) - # output = self.linear2(torch.cat((attn, self.act_mlp(mlp)), 2)) - # output = hidden_states + output * gate.unsqueeze(1) - attn_output, context_attn_output = self.attn( hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, @@ -607,7 +589,7 @@ def __init__( self, hidden_size: int, heads_num: int, - mlp_width_ratio: float, + mlp_ratio: float, qk_norm: str = "rms_norm", ): super().__init__() @@ -629,10 +611,10 @@ def __init__( self.txt_attn_proj = nn.Linear(hidden_size, hidden_size) self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) - self.ff = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="gelu-approximate") + self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate") self.norm2_context = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) - self.ff_context = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="gelu-approximate") + self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate") def forward( self, @@ -690,70 +672,23 @@ def forward( class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin): - """ - HunyuanVideo Transformer backbone - - Inherited from ModelMixin and ConfigMixin for compatibility with diffusers' sampler StableDiffusionPipeline. - - Reference: [1] Flux.1: https://github.com/black-forest-labs/flux [2] MMDiT: http://arxiv.org/abs/2403.03206 - - Parameters ---------- args: argparse.Namespace - The arguments parsed by argparse. - patch_size: list - The size of the patch. - in_channels: int - The number of input channels. - out_channels: int - The number of output channels. - hidden_size: int - The hidden size of the transformer backbone. - heads_num: int - The number of attention heads. - mlp_width_ratio: float - The ratio of the hidden size of the MLP in the transformer block. - mlp_act_type: str - The activation function of the MLP in the transformer block. - depth_double_blocks: int - The number of transformer blocks in the double blocks. - depth_single_blocks: int - The number of transformer blocks in the single blocks. - rope_dim_list: list - The dimension of the rotary embedding for t, h, w. - qkv_bias: bool - Whether to use bias in the qkv linear layer. - qk_norm: bool - Whether to use qk norm. - qk_norm_type: str - The type of qk norm. - guidance_embed: bool - Whether to use guidance embedding for distillation. - text_projection: str - The type of the text projection, default is single_refiner. - use_attention_mask: bool - Whether to use attention mask for text encoder. - dtype: torch.dtype - The dtype of the model. - device: torch.device - The device of the model. - """ - @register_to_config def __init__( self, - patch_size: int = 2, - patch_size_t: int = 1, in_channels: int = 16, out_channels: int = 16, num_attention_heads: int = 24, attention_head_dim: int = 128, - mlp_width_ratio: float = 4.0, - mm_double_blocks_depth: int = 20, - mm_single_blocks_depth: int = 40, + num_layers: int = 20, + num_single_layers: int = 40, + mlp_ratio: float = 4.0, + patch_size: int = 2, + patch_size_t: int = 1, rope_dim_list: List[int] = [16, 56, 56], qk_norm: str = "rms_norm", guidance_embed: bool = True, - text_states_dim: int = 4096, - text_states_dim_2: int = 768, + text_embed_dim: int = 4096, + text_embed_dim_2: int = 768, ) -> None: super().__init__() @@ -762,51 +697,106 @@ def __init__( self.guidance_embed = guidance_embed self.rope_dim_list = rope_dim_list - if sum(rope_dim_list) != attention_head_dim: - raise ValueError(f"Got {rope_dim_list} but expected positional dim {attention_head_dim}") - # image projection self.img_in = PatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim) # text projection - self.txt_in = SingleTokenRefiner(text_states_dim, inner_dim, num_attention_heads, depth=2) + self.txt_in = SingleTokenRefiner(text_embed_dim, inner_dim, num_attention_heads, depth=2) # time modulation self.time_in = TimestepEmbedder(inner_dim, nn.SiLU) # text modulation - self.vector_in = MLPEmbedder(text_states_dim_2, inner_dim) + self.vector_in = MLPEmbedder(text_embed_dim_2, inner_dim) # guidance modulation self.guidance_in = TimestepEmbedder(inner_dim, nn.SiLU) self.transformer_blocks = nn.ModuleList( [ - HunyuanVideoTransformerBlock( - inner_dim, - num_attention_heads, - mlp_width_ratio=mlp_width_ratio, - qk_norm=qk_norm, - ) - for _ in range(mm_double_blocks_depth) + HunyuanVideoTransformerBlock(inner_dim, num_attention_heads, mlp_ratio=mlp_ratio, qk_norm=qk_norm) + for _ in range(num_layers) ] ) self.single_transformer_blocks = nn.ModuleList( [ HunyuanVideoSingleTransformerBlock( - num_attention_heads, - attention_head_dim, - mlp_width_ratio=mlp_width_ratio, - qk_norm=qk_norm, + num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm ) - for _ in range(mm_single_blocks_depth) + for _ in range(num_single_layers) ] ) self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6) self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels) + self.gradient_checkpointing = False + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + def forward( self, hidden_states: torch.Tensor, @@ -843,29 +833,23 @@ def forward( hidden_states = self.img_in(hidden_states) encoder_hidden_states = self.txt_in(encoder_hidden_states, timestep, encoder_attention_mask) - txt_seq_len = encoder_hidden_states.shape[1] + use_reentrant = is_torch_version(">=", "1.11.0") + block_forward = ( + partial(torch.utils.checkpoint.checkpoint, use_reentrant=use_reentrant) + if torch.is_grad_enabled() and self.gradient_checkpointing + else lambda x: x + ) freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None for _, block in enumerate(self.transformer_blocks): - double_block_args = [ - hidden_states, - encoder_hidden_states, - temb, - freqs_cis, - ] - - hidden_states, encoder_hidden_states = block(*double_block_args) + hidden_states, encoder_hidden_states = block_forward(block)( + hidden_states, encoder_hidden_states, temb, freqs_cis + ) for block in self.single_transformer_blocks: - single_block_args = [ - hidden_states, - encoder_hidden_states, - temb, - txt_seq_len, - (freqs_cos, freqs_sin), - ] - - hidden_states, encoder_hidden_states = block(*single_block_args) + hidden_states, encoder_hidden_states = block_forward(block)( + hidden_states, encoder_hidden_states, temb, freqs_cis + ) hidden_states = self.norm_out(hidden_states, temb) hidden_states = self.proj_out(hidden_states) From f637479b60c62c7e79dc6cca76b199175540feae Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 11 Dec 2024 17:45:11 +0100 Subject: [PATCH 24/58] refactor double transformer block attention --- scripts/convert_hunyuan_video_to_diffusers.py | 24 +++ .../transformers/transformer_hunyuan_video.py | 154 +++++++----------- 2 files changed, 84 insertions(+), 94 deletions(-) diff --git a/scripts/convert_hunyuan_video_to_diffusers.py b/scripts/convert_hunyuan_video_to_diffusers.py index 3efcd95e9791..67e154c0cb9a 100644 --- a/scripts/convert_hunyuan_video_to_diffusers.py +++ b/scripts/convert_hunyuan_video_to_diffusers.py @@ -14,6 +14,22 @@ def remap_norm_scale_shift_(key, state_dict): state_dict[key.replace("final_layer.adaLN_modulation.1", "norm_out.linear")] = new_weight +def remap_img_attn_qkv_(key, state_dict): + weight = state_dict.pop(key) + to_q, to_k, to_v = weight.chunk(3, dim=0) + state_dict[key.replace("img_attn_qkv", "attn.to_q")] = to_q + state_dict[key.replace("img_attn_qkv", "attn.to_k")] = to_k + state_dict[key.replace("img_attn_qkv", "attn.to_v")] = to_v + + +def remap_txt_attn_qkv_(key, state_dict): + weight = state_dict.pop(key) + to_q, to_k, to_v = weight.chunk(3, dim=0) + state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = to_q + state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = to_k + state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = to_v + + def remap_single_transformer_blocks_(key, state_dict): hidden_size = 3072 @@ -53,6 +69,12 @@ def remap_single_transformer_blocks_(key, state_dict): # "vector_in.in_layer": "time_text_embed.text_embedder.linear_1", # "vector_in.out_layer": "time_text_embed.text_embedder.linear_2", "double_blocks": "transformer_blocks", + "img_attn_q_norm": "attn.norm_q", + "img_attn_k_norm": "attn.norm_k", + "img_attn_proj": "attn.to_out.0", + "txt_attn_q_norm": "attn.norm_added_q", + "txt_attn_k_norm": "attn.norm_added_k", + "txt_attn_proj": "attn.to_add_out", "img_mod.linear": "norm1.linear", "img_norm1": "norm1.norm", "img_norm2": "norm2", @@ -71,6 +93,8 @@ def remap_single_transformer_blocks_(key, state_dict): TRANSFORMER_SPECIAL_KEYS_REMAP = { "final_layer.adaLN_modulation.1": remap_norm_scale_shift_, + "img_attn_qkv": remap_img_attn_qkv_, + "txt_attn_qkv": remap_txt_attn_qkv_, "single_blocks": remap_single_transformer_blocks_, } diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py index 4ef8144c0772..ee304a40d0c1 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py @@ -360,13 +360,13 @@ class IndividualTokenRefinerBlock(nn.Module): def __init__( self, hidden_size, - heads_num, + num_attention_heads: int, mlp_width_ratio: str = 4.0, mlp_drop_rate: float = 0.0, qkv_bias: bool = True, - ): + ) -> None: super().__init__() - self.heads_num = heads_num + self.heads_num = num_attention_heads self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6) self.self_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias) @@ -383,25 +383,25 @@ def __init__( def forward( self, - x: torch.Tensor, - c: torch.Tensor, - attn_mask: torch.Tensor = None, - ): - gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1) + hidden_states: torch.Tensor, + temb: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + gate_msa, gate_mlp = self.adaLN_modulation(temb).chunk(2, dim=1) - norm_x = self.norm1(x) + norm_x = self.norm1(hidden_states) qkv = self.self_attn_qkv(norm_x) q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num) # Self-Attention - attn = attention(q, k, v, attn_mask=attn_mask) + attn = attention(q, k, v, attn_mask=attention_mask) - x = x + self.self_attn_proj(attn) * gate_msa.unsqueeze(1) + hidden_states = hidden_states + self.self_attn_proj(attn) * gate_msa.unsqueeze(1) # FFN Layer - x = x + self.mlp(self.norm2(x)) * gate_mlp.unsqueeze(1) + hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) * gate_mlp.unsqueeze(1) - return x + return hidden_states class IndividualTokenRefiner(nn.Module): @@ -419,7 +419,7 @@ def __init__( [ IndividualTokenRefinerBlock( hidden_size=hidden_size, - heads_num=heads_num, + num_attention_heads=heads_num, mlp_width_ratio=mlp_width_ratio, mlp_drop_rate=mlp_drop_rate, qkv_bias=qkv_bias, @@ -430,41 +430,34 @@ def __init__( def forward( self, - x: torch.Tensor, - c: torch.LongTensor, - mask: Optional[torch.Tensor] = None, + hidden_states: torch.Tensor, + temb: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, ): self_attn_mask = None - if mask is not None: - batch_size = mask.shape[0] - seq_len = mask.shape[1] - mask = mask.to(x.device).bool() - # batch_size x 1 x seq_len x seq_len - self_attn_mask_1 = mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1) - # batch_size x 1 x seq_len x seq_len + if attention_mask is not None: + batch_size = attention_mask.shape[0] + seq_len = attention_mask.shape[1] + attention_mask = attention_mask.to(hidden_states.device).bool() + self_attn_mask_1 = attention_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1) self_attn_mask_2 = self_attn_mask_1.transpose(2, 3) - # batch_size x 1 x seq_len x seq_len, 1 for broadcasting of heads_num self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool() - # avoids self-attention weight being NaN for padding tokens self_attn_mask[:, :, :, 0] = True for block in self.blocks: - x = block(x, c, self_attn_mask) - return x + hidden_states = block(hidden_states, temb, self_attn_mask) + + return hidden_states class SingleTokenRefiner(nn.Module): - """ - A single token refiner block for llm text embedding refine. - """ - def __init__( self, - in_channels, - hidden_size, - num_attention_heads, - depth, - mlp_width_ratio: float = 4.0, + in_channels: int, + hidden_size: int, + num_attention_heads: int, + depth: int, + mlp_ratio: float = 4.0, mlp_drop_rate: float = 0.0, qkv_bias: bool = True, ): @@ -481,7 +474,7 @@ def __init__( hidden_size=hidden_size, heads_num=num_attention_heads, depth=depth, - mlp_width_ratio=mlp_width_ratio, + mlp_width_ratio=mlp_ratio, mlp_drop_rate=mlp_drop_rate, qkv_bias=qkv_bias, ) @@ -587,28 +580,31 @@ def forward( class HunyuanVideoTransformerBlock(nn.Module): def __init__( self, - hidden_size: int, - heads_num: int, + num_attention_heads: int, + attention_head_dim: int, mlp_ratio: float, qk_norm: str = "rms_norm", - ): + ) -> None: super().__init__() - self.heads_num = heads_num - head_dim = hidden_size // heads_num + hidden_size = num_attention_heads * attention_head_dim self.norm1 = AdaLayerNormZero(hidden_size, norm_type="layer_norm") self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm") - self.img_attn_qkv = nn.Linear(hidden_size, hidden_size * 3) - self.img_attn_q_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6) - self.img_attn_k_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6) - self.img_attn_proj = nn.Linear(hidden_size, hidden_size) - - self.txt_attn_qkv = nn.Linear(hidden_size, hidden_size * 3) - self.txt_attn_q_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6) - self.txt_attn_k_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6) - self.txt_attn_proj = nn.Linear(hidden_size, hidden_size) + self.attn = Attention( + query_dim=hidden_size, + cross_attention_dim=None, + added_kv_proj_dim=hidden_size, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=hidden_size, + context_pre_only=False, + bias=True, + processor=HunyuanVideoAttnProcessor2_0(), + qk_norm=qk_norm, + eps=1e-6, + ) self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate") @@ -627,35 +623,15 @@ def forward( norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( encoder_hidden_states, emb=temb ) - - img_qkv = self.img_attn_qkv(norm_hidden_states) - img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num) - # Apply QK-Norm if needed - img_q = self.img_attn_q_norm(img_q).to(img_v) - img_k = self.img_attn_k_norm(img_k).to(img_v) - - # Apply RoPE if needed. - if freqs_cis is not None: - img_qq, img_kk = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False) - assert ( - img_qq.shape == img_q.shape and img_kk.shape == img_k.shape - ), f"img_kk: {img_qq.shape}, img_q: {img_q.shape}, img_kk: {img_kk.shape}, img_k: {img_k.shape}" - img_q, img_k = img_qq, img_kk - - txt_qkv = self.txt_attn_qkv(norm_encoder_hidden_states) - txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num) - txt_q = self.txt_attn_q_norm(txt_q).to(txt_v) - txt_k = self.txt_attn_k_norm(txt_k).to(txt_v) - - q = torch.cat((img_q, txt_q), dim=1) - k = torch.cat((img_k, txt_k), dim=1) - v = torch.cat((img_v, txt_v), dim=1) - attn = attention(q, k, v) - - img_attn, txt_attn = attn[:, : hidden_states.shape[1]], attn[:, hidden_states.shape[1] :] - - hidden_states = hidden_states + self.img_attn_proj(img_attn) * gate_msa.unsqueeze(1) - encoder_hidden_states = encoder_hidden_states + self.txt_attn_proj(txt_attn) * c_gate_msa.unsqueeze(1) + + img_attn, txt_attn = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + image_rotary_emb=freqs_cis, + ) + + hidden_states = hidden_states + img_attn * gate_msa.unsqueeze(1) + encoder_hidden_states = encoder_hidden_states + txt_attn * c_gate_msa.unsqueeze(1) norm_hidden_states = self.norm2(hidden_states) norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] @@ -686,7 +662,7 @@ def __init__( patch_size_t: int = 1, rope_dim_list: List[int] = [16, 56, 56], qk_norm: str = "rms_norm", - guidance_embed: bool = True, + guidance_embeds: bool = True, text_embed_dim: int = 4096, text_embed_dim_2: int = 768, ) -> None: @@ -694,7 +670,6 @@ def __init__( inner_dim = num_attention_heads * attention_head_dim out_channels = out_channels or in_channels - self.guidance_embed = guidance_embed self.rope_dim_list = rope_dim_list # image projection @@ -714,7 +689,7 @@ def __init__( self.transformer_blocks = nn.ModuleList( [ - HunyuanVideoTransformerBlock(inner_dim, num_attention_heads, mlp_ratio=mlp_ratio, qk_norm=qk_norm) + HunyuanVideoTransformerBlock(num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm) for _ in range(num_layers) ] ) @@ -816,18 +791,9 @@ def forward( post_patch_height = height // p post_patch_width = width // p - # Prepare modulation vectors. temb = self.time_in(timestep) - - # text modulation temb = temb + self.vector_in(encoder_hidden_states_2) - - # guidance modulation - if self.guidance_embed: - if guidance is None: - raise ValueError("Didn't get guidance strength for guidance distilled model.") - - temb = temb + self.guidance_in(guidance) + temb = temb + self.guidance_in(guidance) # Embed image and text. hidden_states = self.img_in(hidden_states) From 19b2d565c9a8d6aeb8fa771ab59e9ae16e1e2e38 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 11 Dec 2024 17:47:13 +0100 Subject: [PATCH 25/58] update mochi attn proc --- src/diffusers/models/attention_processor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 46d884616602..59f2c432f319 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -3847,7 +3847,7 @@ def apply_rotary_emb(x, freqs_cos, freqs_sin): # dropout hidden_states = attn.to_out[1](hidden_states) - if hasattr(attn, "to_add_out"): + if attn.context_pre_only is not None and not attn.context_pre_only: encoder_hidden_states = attn.to_add_out(encoder_hidden_states) return hidden_states, encoder_hidden_states From c1faf0d217fa8c5b70d15522c4bab961d46dab63 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 11 Dec 2024 18:13:26 +0100 Subject: [PATCH 26/58] use diffusers attention implementation in all modules; checkpoint for all values matching original --- scripts/convert_hunyuan_video_to_diffusers.py | 11 + .../transformers/transformer_hunyuan_video.py | 239 +++++------------- 2 files changed, 73 insertions(+), 177 deletions(-) diff --git a/scripts/convert_hunyuan_video_to_diffusers.py b/scripts/convert_hunyuan_video_to_diffusers.py index 67e154c0cb9a..2fd54a0fa3b2 100644 --- a/scripts/convert_hunyuan_video_to_diffusers.py +++ b/scripts/convert_hunyuan_video_to_diffusers.py @@ -30,6 +30,14 @@ def remap_txt_attn_qkv_(key, state_dict): state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = to_v +def remap_self_attn_qkv_(key, state_dict): + weight = state_dict.pop(key) + to_q, to_k, to_v = weight.chunk(3, dim=0) + state_dict[key.replace("self_attn_qkv", "attn.to_q")] = to_q + state_dict[key.replace("self_attn_qkv", "attn.to_k")] = to_k + state_dict[key.replace("self_attn_qkv", "attn.to_v")] = to_v + + def remap_single_transformer_blocks_(key, state_dict): hidden_size = 3072 @@ -69,6 +77,7 @@ def remap_single_transformer_blocks_(key, state_dict): # "vector_in.in_layer": "time_text_embed.text_embedder.linear_1", # "vector_in.out_layer": "time_text_embed.text_embedder.linear_2", "double_blocks": "transformer_blocks", + "individual_token_refiner.blocks": "token_refiner.refiner_blocks", "img_attn_q_norm": "attn.norm_q", "img_attn_k_norm": "attn.norm_k", "img_attn_proj": "attn.to_out.0", @@ -83,6 +92,7 @@ def remap_single_transformer_blocks_(key, state_dict): "txt_norm1": "norm1.norm", "txt_norm2": "norm2_context", "txt_mlp": "ff_context", + "self_attn_proj": "attn.to_out.0", "modulation.linear": "norm.linear", "pre_norm": "norm.norm", "final_layer.norm_final": "norm_out.norm", @@ -95,6 +105,7 @@ def remap_single_transformer_blocks_(key, state_dict): "final_layer.adaLN_modulation.1": remap_norm_scale_shift_, "img_attn_qkv": remap_img_attn_qkv_, "txt_attn_qkv": remap_txt_attn_qkv_, + "self_attn_qkv": remap_self_attn_qkv_, "single_blocks": remap_single_transformer_blocks_, } diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py index ee304a40d0c1..83aac1e68e9e 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py @@ -19,7 +19,6 @@ import torch import torch.nn as nn import torch.nn.functional as F -from einops import rearrange from ...configuration_utils import ConfigMixin, register_to_config from ...utils import is_torch_version @@ -27,132 +26,7 @@ from ..attention_processor import Attention, AttentionProcessor from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin -from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle, RMSNorm - - -def attention(q, k, v, attn_mask=None): - q = q.transpose(1, 2) - k = k.transpose(1, 2) - v = v.transpose(1, 2) - - if attn_mask is not None and attn_mask.dtype != torch.bool: - attn_mask = attn_mask.to(q.dtype) - x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0, is_causal=False) - - x = x.transpose(1, 2) - b, s, a, d = x.shape - out = x.reshape(b, s, -1) - return out - - -def reshape_for_broadcast( - freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], - x: torch.Tensor, - head_first=False, -): - """ - Reshape frequency tensor for broadcasting it with another tensor. - - This function reshapes the frequency tensor to have the same shape as the target tensor 'x' for the purpose of - broadcasting the frequency tensor during element-wise operations. - - Notes: - When using FlashMHAModified, head_first should be False. When using Attention, head_first should be True. - - Args: - freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Frequency tensor to be reshaped. - x (torch.Tensor): Target tensor for broadcasting compatibility. - head_first (bool): head dimension first (except batch dim) or not. - - Returns: - torch.Tensor: Reshaped frequency tensor. - - Raises: - AssertionError: If the frequency tensor doesn't match the expected shape. AssertionError: If the target tensor - 'x' doesn't have the expected number of dimensions. - """ - ndim = x.ndim - assert 0 <= 1 < ndim - - if isinstance(freqs_cis, tuple): - # freqs_cis: (cos, sin) in real space - if head_first: - assert freqs_cis[0].shape == ( - x.shape[-2], - x.shape[-1], - ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}" - shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] - else: - assert freqs_cis[0].shape == ( - x.shape[1], - x.shape[-1], - ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}" - shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] - return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape) - else: - # freqs_cis: values in complex space - if head_first: - assert freqs_cis.shape == ( - x.shape[-2], - x.shape[-1], - ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}" - shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] - else: - assert freqs_cis.shape == ( - x.shape[1], - x.shape[-1], - ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}" - shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] - return freqs_cis.view(*shape) - - -def rotate_half(x): - x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2] - return torch.stack([-x_imag, x_real], dim=-1).flatten(3) - - -def apply_rotary_emb( - xq: torch.Tensor, - xk: torch.Tensor, - freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], - head_first: bool = False, -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Apply rotary embeddings to input tensors using the given frequency tensor. - - This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided frequency - tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor is reshaped for - broadcasting compatibility. The resulting tensors contain rotary embeddings and are returned as real tensors. - - Args: - xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D] - xk (torch.Tensor): Key tensor to apply rotary embeddings. [B, S, H, D] - freqs_cis (torch.Tensor or tuple): Precomputed frequency tensor for complex exponential. - head_first (bool): head dimension first (except batch dim) or not. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. - - """ - xk_out = None - if isinstance(freqs_cis, tuple): - cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) # [S, D] - cos, sin = cos.to(xq.device), sin.to(xq.device) - # real * cos - imag * sin - # imag * cos + real * sin - xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq) - xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk) - else: - # view_as_complex will pack [..., D/2, 2](real) to [..., D/2](complex) - xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # [B, S, H, D//2] - freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(xq.device) # [S, D//2] --> [1, S, 1, D//2] - # (real, imag) * (cos, sin) = (real * cos - imag * sin, imag * cos + real * sin) - # view_as_real will expand [..., D/2](complex) to [..., D/2, 2](real) - xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq) - xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) # [B, S, H, D//2] - xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk) - - return xq_out, xk_out +from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle class HunyuanVideoAttnProcessor2_0: @@ -359,18 +233,25 @@ def forward(self, t): class IndividualTokenRefinerBlock(nn.Module): def __init__( self, - hidden_size, num_attention_heads: int, + attention_head_dim: int, mlp_width_ratio: str = 4.0, mlp_drop_rate: float = 0.0, qkv_bias: bool = True, ) -> None: super().__init__() - self.heads_num = num_attention_heads + + hidden_size = num_attention_heads * attention_head_dim self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6) - self.self_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias) - self.self_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias) + + self.attn = Attention( + query_dim=hidden_size, + cross_attention_dim=None, + heads=num_attention_heads, + dim_head=attention_head_dim, + bias=True, + ) self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6) @@ -389,16 +270,15 @@ def forward( ) -> torch.Tensor: gate_msa, gate_mlp = self.adaLN_modulation(temb).chunk(2, dim=1) - norm_x = self.norm1(hidden_states) - qkv = self.self_attn_qkv(norm_x) - q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num) + norm_hidden_states = self.norm1(hidden_states) - # Self-Attention - attn = attention(q, k, v, attn_mask=attention_mask) - - hidden_states = hidden_states + self.self_attn_proj(attn) * gate_msa.unsqueeze(1) + attn_output = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=None, + attention_mask=attention_mask, + ) + hidden_states = hidden_states + attn_output * gate_msa.unsqueeze(1) - # FFN Layer hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) * gate_mlp.unsqueeze(1) return hidden_states @@ -407,24 +287,25 @@ def forward( class IndividualTokenRefiner(nn.Module): def __init__( self, - hidden_size, - heads_num, - depth, + num_attention_heads: int, + attention_head_dim: int, + num_layers: int, mlp_width_ratio: float = 4.0, mlp_drop_rate: float = 0.0, qkv_bias: bool = True, ): super().__init__() - self.blocks = nn.ModuleList( + + self.refiner_blocks = nn.ModuleList( [ IndividualTokenRefinerBlock( - hidden_size=hidden_size, - num_attention_heads=heads_num, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, mlp_width_ratio=mlp_width_ratio, mlp_drop_rate=mlp_drop_rate, qkv_bias=qkv_bias, ) - for _ in range(depth) + for _ in range(num_layers) ] ) @@ -444,9 +325,9 @@ def forward( self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool() self_attn_mask[:, :, :, 0] = True - for block in self.blocks: + for block in self.refiner_blocks: hidden_states = block(hidden_states, temb, self_attn_mask) - + return hidden_states @@ -454,26 +335,25 @@ class SingleTokenRefiner(nn.Module): def __init__( self, in_channels: int, - hidden_size: int, num_attention_heads: int, - depth: int, + attention_head_dim: int, + num_layers: int, mlp_ratio: float = 4.0, mlp_drop_rate: float = 0.0, qkv_bias: bool = True, ): super().__init__() - self.input_embedder = nn.Linear(in_channels, hidden_size, bias=True) + hidden_size = num_attention_heads * attention_head_dim - # Build timestep embedding layer + self.input_embedder = nn.Linear(in_channels, hidden_size, bias=True) self.t_embedder = TimestepEmbedder(hidden_size, nn.SiLU) - # Build context embedding layer self.c_embedder = TextProjection(in_channels, hidden_size, nn.SiLU) - self.individual_token_refiner = IndividualTokenRefiner( - hidden_size=hidden_size, - heads_num=num_attention_heads, - depth=depth, + self.token_refiner = IndividualTokenRefiner( + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + num_layers=num_layers, mlp_width_ratio=mlp_ratio, mlp_drop_rate=mlp_drop_rate, qkv_bias=qkv_bias, @@ -481,27 +361,27 @@ def __init__( def forward( self, - x: torch.Tensor, - t: torch.LongTensor, - mask: Optional[torch.LongTensor] = None, - ): - original_dtype = x.dtype - timestep_aware_representations = self.t_embedder(t) + hidden_states: torch.Tensor, + timestep: torch.LongTensor, + attention_mask: Optional[torch.LongTensor] = None, + ) -> torch.Tensor: + original_dtype = hidden_states.dtype + temb = self.t_embedder(timestep) - if mask is None: - context_aware_representations = x.mean(dim=1) + if attention_mask is None: + pooled_projections = hidden_states.mean(dim=1) else: - mask_float = mask.float().unsqueeze(-1) # [b, s1, 1] - context_aware_representations = (x * mask_float).sum(dim=1) / mask_float.sum(dim=1) - context_aware_representations = context_aware_representations.to(original_dtype) + mask_float = attention_mask.float().unsqueeze(-1) # [b, s1, 1] + pooled_projections = (hidden_states * mask_float).sum(dim=1) / mask_float.sum(dim=1) + pooled_projections = pooled_projections.to(original_dtype) - context_aware_representations = self.c_embedder(context_aware_representations) - c = timestep_aware_representations + context_aware_representations + pooled_projections = self.c_embedder(pooled_projections) + emb = temb + pooled_projections - x = self.input_embedder(x) - x = self.individual_token_refiner(x, c, mask) + hidden_states = self.input_embedder(hidden_states) + hidden_states = self.token_refiner(hidden_states, emb, attention_mask) - return x + return hidden_states class HunyuanVideoSingleTransformerBlock(nn.Module): @@ -623,13 +503,13 @@ def forward( norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( encoder_hidden_states, emb=temb ) - + img_attn, txt_attn = self.attn( hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, image_rotary_emb=freqs_cis, ) - + hidden_states = hidden_states + img_attn * gate_msa.unsqueeze(1) encoder_hidden_states = encoder_hidden_states + txt_attn * c_gate_msa.unsqueeze(1) @@ -657,6 +537,7 @@ def __init__( attention_head_dim: int = 128, num_layers: int = 20, num_single_layers: int = 40, + num_refiner_layers: int = 2, mlp_ratio: float = 4.0, patch_size: int = 2, patch_size_t: int = 1, @@ -676,7 +557,9 @@ def __init__( self.img_in = PatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim) # text projection - self.txt_in = SingleTokenRefiner(text_embed_dim, inner_dim, num_attention_heads, depth=2) + self.txt_in = SingleTokenRefiner( + text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers + ) # time modulation self.time_in = TimestepEmbedder(inner_dim, nn.SiLU) @@ -689,7 +572,9 @@ def __init__( self.transformer_blocks = nn.ModuleList( [ - HunyuanVideoTransformerBlock(num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm) + HunyuanVideoTransformerBlock( + num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm + ) for _ in range(num_layers) ] ) From 6915d62ca8b9d003ef963395843080e98686ab7a Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 12 Dec 2024 20:40:38 +0100 Subject: [PATCH 27/58] remove helper functions in vae --- scripts/convert_hunyuan_video_to_diffusers.py | 37 ++- .../autoencoder_kl_hunyuan_video.py | 263 ++++++------------ .../transformers/transformer_hunyuan_video.py | 4 +- .../hunyuan_video/pipeline_hunyuan_video.py | 26 +- 4 files changed, 116 insertions(+), 214 deletions(-) diff --git a/scripts/convert_hunyuan_video_to_diffusers.py b/scripts/convert_hunyuan_video_to_diffusers.py index 2fd54a0fa3b2..d1ab60a767ce 100644 --- a/scripts/convert_hunyuan_video_to_diffusers.py +++ b/scripts/convert_hunyuan_video_to_diffusers.py @@ -4,7 +4,7 @@ import torch from accelerate import init_empty_weights -from diffusers import HunyuanVideoTransformer3DModel +from diffusers import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel def remap_norm_scale_shift_(key, state_dict): @@ -109,7 +109,9 @@ def remap_single_transformer_blocks_(key, state_dict): "single_blocks": remap_single_transformer_blocks_, } -VAE_KEYS_RENAME_DICT = {} +VAE_KEYS_RENAME_DICT = { + +} VAE_SPECIAL_KEYS_REMAP = {} @@ -151,14 +153,37 @@ def convert_transformer(ckpt_path: str): return transformer +def convert_vae(ckpt_path: str): + original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=True)) + + with init_empty_weights(): + vae = AutoencoderKLHunyuanVideo() + + for key in list(original_state_dict.keys()): + new_key = key[:] + for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items(): + new_key = new_key.replace(replace_key, rename_key) + update_state_dict_(original_state_dict, key, new_key) + + for key in list(original_state_dict.keys()): + for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items(): + if special_key not in key: + continue + handler_fn_inplace(key, original_state_dict) + + vae.load_state_dict(original_state_dict, strict=True, assign=True) + return vae + + def get_args(): parser = argparse.ArgumentParser() parser.add_argument( "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint" ) + parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original VAE checkpoint") parser.add_argument("--save_pipeline", action="store_true") parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved") - parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the model in.") + parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the transformer in.") return parser.parse_args() @@ -180,5 +205,11 @@ def get_args(): if args.transformer_ckpt_path is not None: transformer = convert_transformer(args.transformer_ckpt_path) + transformer = transformer.to(dtype=dtype) if not args.save_pipeline: transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") + + if args.vae_ckpt_path is not None: + vae = convert_vae(args.vae_ckpt_path) + if not args.save_pipeline: + vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py index dd4502115896..6965c7fcc15c 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py @@ -22,7 +22,7 @@ from einops import rearrange from ...configuration_utils import ConfigMixin, register_to_config -from ...utils import logging +from ...utils import logging, is_torch_version from ...utils.accelerate_utils import apply_forward_hook from ..activations import get_activation from ..attention_processor import Attention, SpatialNorm @@ -141,10 +141,10 @@ def __init__( def forward( self, - hidden_states: torch.FloatTensor, + hidden_states: torch.Tensor, output_size: Optional[int] = None, scale: float = 1.0, - ) -> torch.FloatTensor: + ) -> torch.Tensor: assert hidden_states.shape[1] == self.channels if self.norm is not None: @@ -246,7 +246,7 @@ def __init__( else: self.conv = conv - def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor: + def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor: assert hidden_states.shape[1] == self.channels if self.norm is not None: @@ -259,114 +259,6 @@ def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch return hidden_states -def get_down_block3d( - down_block_type: str, - num_layers: int, - in_channels: int, - out_channels: int, - temb_channels: int, - add_downsample: bool, - downsample_stride: int, - resnet_eps: float, - resnet_act_fn: str, - transformer_layers_per_block: int = 1, - num_attention_heads: Optional[int] = None, - resnet_groups: Optional[int] = None, - cross_attention_dim: Optional[int] = None, - downsample_padding: Optional[int] = None, - dual_cross_attention: bool = False, - use_linear_projection: bool = False, - only_cross_attention: bool = False, - upcast_attention: bool = False, - resnet_time_scale_shift: str = "default", - attention_type: str = "default", - resnet_skip_time_act: bool = False, - resnet_out_scale_factor: float = 1.0, - cross_attention_norm: Optional[str] = None, - attention_head_dim: Optional[int] = None, - downsample_type: Optional[str] = None, - dropout: float = 0.0, -): - # If attn head dim is not defined, we default it to the number of heads - if attention_head_dim is None: - logger.warn( - f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}." - ) - attention_head_dim = num_attention_heads - - down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type - if down_block_type == "DownEncoderBlockCausal3D": - return DownEncoderBlockCausal3D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - dropout=dropout, - add_downsample=add_downsample, - downsample_stride=downsample_stride, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - downsample_padding=downsample_padding, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - raise ValueError(f"{down_block_type} does not exist.") - - -def get_up_block3d( - up_block_type: str, - num_layers: int, - in_channels: int, - out_channels: int, - prev_output_channel: int, - temb_channels: int, - add_upsample: bool, - upsample_scale_factor: Tuple, - resnet_eps: float, - resnet_act_fn: str, - resolution_idx: Optional[int] = None, - transformer_layers_per_block: int = 1, - num_attention_heads: Optional[int] = None, - resnet_groups: Optional[int] = None, - cross_attention_dim: Optional[int] = None, - dual_cross_attention: bool = False, - use_linear_projection: bool = False, - only_cross_attention: bool = False, - upcast_attention: bool = False, - resnet_time_scale_shift: str = "default", - attention_type: str = "default", - resnet_skip_time_act: bool = False, - resnet_out_scale_factor: float = 1.0, - cross_attention_norm: Optional[str] = None, - attention_head_dim: Optional[int] = None, - upsample_type: Optional[str] = None, - dropout: float = 0.0, -) -> nn.Module: - # If attn head dim is not defined, we default it to the number of heads - if attention_head_dim is None: - logger.warn( - f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}." - ) - attention_head_dim = num_attention_heads - - up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type - if up_block_type == "UpDecoderBlockCausal3D": - return UpDecoderBlockCausal3D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - resolution_idx=resolution_idx, - dropout=dropout, - add_upsample=add_upsample, - upsample_scale_factor=upsample_scale_factor, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - resnet_time_scale_shift=resnet_time_scale_shift, - temb_channels=temb_channels, - ) - raise ValueError(f"{up_block_type} does not exist.") - - class ResnetBlockCausal3D(nn.Module): r""" A Resnet block. @@ -388,7 +280,7 @@ def __init__( skip_time_act: bool = False, # default, scale_shift, ada_group, spatial time_embedding_norm: str = "default", - kernel: Optional[torch.FloatTensor] = None, + kernel: Optional[torch.Tensor] = None, output_scale_factor: float = 1.0, use_in_shortcut: Optional[bool] = None, up: bool = False, @@ -468,10 +360,10 @@ def __init__( def forward( self, - input_tensor: torch.FloatTensor, - temb: torch.FloatTensor, + input_tensor: torch.Tensor, + temb: torch.Tensor, scale: float = 1.0, - ) -> torch.FloatTensor: + ) -> torch.Tensor: hidden_states = input_tensor if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial": @@ -614,7 +506,7 @@ def __init__( self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) - def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: + def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): if attn is not None: @@ -685,7 +577,7 @@ def __init__( else: self.downsamplers = None - def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor: + def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor: for resnet in self.resnets: hidden_states = resnet(hidden_states, temb=None, scale=scale) @@ -754,8 +646,8 @@ def __init__( self.resolution_idx = resolution_idx def forward( - self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0 - ) -> torch.FloatTensor: + self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, scale: float = 1.0 + ) -> torch.Tensor: for resnet in self.resnets: hidden_states = resnet(hidden_states, temb=temb, scale=scale) @@ -768,15 +660,15 @@ def forward( class EncoderCausal3D(nn.Module): r""" - The `EncoderCausal3D` layer of a variational autoencoder that encodes its input into a latent representation. + Causal encoder for 3D video-like data introduced in [Hunyuan Video](https://huggingface.co/papers/2412.03603). """ - + def __init__( self, in_channels: int = 3, out_channels: int = 3, - down_block_types: Tuple[str, ...] = ("DownEncoderBlockCausal3D",), - block_out_channels: Tuple[int, ...] = (64,), + down_block_types: Tuple[str, ...] = ("DownEncoderBlockCausal3D", "DownEncoderBlockCausal3D", "DownEncoderBlockCausal3D", "DownEncoderBlockCausal3D"), + block_out_channels: Tuple[int, ...] = (128, 256, 512, 512), layers_per_block: int = 2, norm_num_groups: int = 32, act_fn: str = "silu", @@ -784,15 +676,13 @@ def __init__( mid_block_add_attention=True, time_compression_ratio: int = 4, spatial_compression_ratio: int = 8, - ): + ) -> None: super().__init__() - self.layers_per_block = layers_per_block - + self.conv_in = CausalConv3d(in_channels, block_out_channels[0], kernel_size=3, stride=1) self.mid_block = None self.down_blocks = nn.ModuleList([]) - # down output_channel = block_out_channels[0] for i, down_block_type in enumerate(down_block_types): input_channel = output_channel @@ -815,23 +705,21 @@ def __init__( downsample_stride_HW = (2, 2) if add_spatial_downsample else (1, 1) downsample_stride_T = (2,) if add_time_downsample else (1,) downsample_stride = tuple(downsample_stride_T + downsample_stride_HW) - down_block = get_down_block3d( - down_block_type, - num_layers=self.layers_per_block, + + down_block = DownEncoderBlockCausal3D( + num_layers=layers_per_block, in_channels=input_channel, out_channels=output_channel, add_downsample=bool(add_spatial_downsample or add_time_downsample), downsample_stride=downsample_stride, resnet_eps=1e-6, - downsample_padding=0, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, - attention_head_dim=output_channel, - temb_channels=None, + downsample_padding=0, ) + self.down_blocks.append(down_block) - # mid self.mid_block = UNetMidBlockCausal3D( in_channels=block_out_channels[-1], resnet_eps=1e-6, @@ -844,50 +732,58 @@ def __init__( add_attention=mid_block_add_attention, ) - # out self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6) self.conv_act = nn.SiLU() conv_out_channels = 2 * out_channels if double_z else out_channels self.conv_out = CausalConv3d(block_out_channels[-1], conv_out_channels, kernel_size=3) - def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor: - r"""The forward method of the `EncoderCausal3D` class.""" - assert len(sample.shape) == 5, "The input tensor should have 5 dimensions" + self.gradient_checkpointing = False - sample = self.conv_in(sample) + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # 1. Input layer + hidden_states = self.conv_in(hidden_states) - # down + use_reentrant = is_torch_version("<=", "1.11.0") + + def create_block_forward(block): + if torch.is_grad_enabled() and self.gradient_checkpointing: + return lambda *inputs: torch.utils.checkpoint.checkpoint( + lambda *x: block(*x), *inputs, use_reentrant=use_reentrant + ) + else: + return block + + # 2. Down blocks for down_block in self.down_blocks: - sample = down_block(sample) + hidden_states = create_block_forward(down_block)(hidden_states) - # middle - sample = self.mid_block(sample) + # 3. Mid block + hidden_states = self.mid_block(hidden_states) - # post-process - sample = self.conv_norm_out(sample) - sample = self.conv_act(sample) - sample = self.conv_out(sample) + # 4. Output layers + hidden_states = self.conv_norm_out(hidden_states) + hidden_states = self.conv_act(hidden_states) + hidden_states = self.conv_out(hidden_states) - return sample + return hidden_states class DecoderCausal3D(nn.Module): r""" - The `DecoderCausal3D` layer of a variational autoencoder that decodes its latent representation into an output - sample. + Causal decoder for 3D video-like data introduced in [Hunyuan Video](https://huggingface.co/papers/2412.03603). """ def __init__( self, in_channels: int = 3, out_channels: int = 3, - up_block_types: Tuple[str, ...] = ("UpDecoderBlockCausal3D",), - block_out_channels: Tuple[int, ...] = (64,), + up_block_types: Tuple[str, ...] = ("UpDecoderBlockCausal3D", "UpDecoderBlockCausal3D", "UpDecoderBlockCausal3D", "UpDecoderBlockCausal3D"), + block_out_channels: Tuple[int, ...] = (128, 256, 512, 512), layers_per_block: int = 2, norm_num_groups: int = 32, act_fn: str = "silu", - norm_type: str = "group", # group, spatial + norm_type: str = "group", mid_block_add_attention=True, time_compression_ratio: int = 4, spatial_compression_ratio: int = 8, @@ -935,21 +831,20 @@ def __init__( upsample_scale_factor_HW = (2, 2) if add_spatial_upsample else (1, 1) upsample_scale_factor_T = (2,) if add_time_upsample else (1,) upsample_scale_factor = tuple(upsample_scale_factor_T + upsample_scale_factor_HW) - up_block = get_up_block3d( - up_block_type, + + up_block = UpDecoderBlockCausal3D( num_layers=self.layers_per_block + 1, in_channels=prev_output_channel, out_channels=output_channel, - prev_output_channel=None, add_upsample=bool(add_spatial_upsample or add_time_upsample), upsample_scale_factor=upsample_scale_factor, resnet_eps=1e-6, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, - attention_head_dim=output_channel, - temb_channels=temb_channels, resnet_time_scale_shift=norm_type, + temb_channels=temb_channels, ) + self.up_blocks.append(up_block) prev_output_channel = output_channel @@ -965,9 +860,9 @@ def __init__( def forward( self, - sample: torch.FloatTensor, - latent_embeds: Optional[torch.FloatTensor] = None, - ) -> torch.FloatTensor: + sample: torch.Tensor, + latent_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: r"""The forward method of the `DecoderCausal3D` class.""" assert len(sample.shape) == 5, "The input tensor should have 5 dimensions" @@ -1022,14 +917,14 @@ def custom_forward(*inputs): @dataclass class DecoderOutput2(BaseOutput): - sample: torch.FloatTensor + sample: torch.Tensor posterior: Optional[DiagonalGaussianDistribution] = None class AutoencoderKLHunyuanVideo(ModelMixin, ConfigMixin): r""" - A VAE model with KL loss for encoding images/videos into latents and decoding latent representations into - images/videos. + A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Introduced + in [HunyuanVideo](https://huggingface.co/papers/2412.03603). This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented for all models (such as downloading or saving). @@ -1076,12 +971,12 @@ def __init__( down_block_types=down_block_types, block_out_channels=block_out_channels, layers_per_block=layers_per_block, - act_fn=act_fn, norm_num_groups=norm_num_groups, + act_fn=act_fn, double_z=True, + mid_block_add_attention=mid_block_add_attention, time_compression_ratio=time_compression_ratio, spatial_compression_ratio=spatial_compression_ratio, - mid_block_add_attention=mid_block_add_attention, ) self.decoder = DecoderCausal3D( @@ -1166,13 +1061,13 @@ def disable_slicing(self): @apply_forward_hook def encode( - self, x: torch.FloatTensor, return_dict: bool = True + self, x: torch.Tensor, return_dict: bool = True ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: """ Encode a batch of images/videos into latents. Args: - x (`torch.FloatTensor`): Input batch of images/videos. + x (`torch.Tensor`): Input batch of images/videos. return_dict (`bool`, *optional*, defaults to `True`): Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. @@ -1204,7 +1099,7 @@ def encode( return AutoencoderKLOutput(latent_dist=posterior) - def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: + def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: assert len(z.shape) == 5, "The input tensor should have 5 dimensions" if self.use_temporal_tiling and z.shape[2] > self.tile_latent_min_tsize: @@ -1225,13 +1120,13 @@ def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[Decod @apply_forward_hook def decode( - self, z: torch.FloatTensor, return_dict: bool = True, generator=None - ) -> Union[DecoderOutput, torch.FloatTensor]: + self, z: torch.Tensor, return_dict: bool = True + ) -> Union[DecoderOutput, torch.Tensor]: """ Decode a batch of images/videos. Args: - z (`torch.FloatTensor`): Input batch of latent vectors. + z (`torch.Tensor`): Input batch of latent vectors. return_dict (`bool`, *optional*, defaults to `True`): Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. @@ -1277,7 +1172,7 @@ def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch. return b def spatial_tiled_encode( - self, x: torch.FloatTensor, return_dict: bool = True, return_moments: bool = False + self, x: torch.Tensor, return_dict: bool = True, return_moments: bool = False ) -> AutoencoderKLOutput: r"""Encode a batch of images/videos using a tiled encoder. @@ -1288,7 +1183,7 @@ def spatial_tiled_encode( changes in the output, but they should be much less noticeable. Args: - x (`torch.FloatTensor`): Input batch of images/videos. + x (`torch.Tensor`): Input batch of images/videos. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. @@ -1335,13 +1230,13 @@ def spatial_tiled_encode( return AutoencoderKLOutput(latent_dist=posterior) def spatial_tiled_decode( - self, z: torch.FloatTensor, return_dict: bool = True - ) -> Union[DecoderOutput, torch.FloatTensor]: + self, z: torch.Tensor, return_dict: bool = True + ) -> Union[DecoderOutput, torch.Tensor]: r""" Decode a batch of images/videos using a tiled decoder. Args: - z (`torch.FloatTensor`): Input batch of latent vectors. + z (`torch.Tensor`): Input batch of latent vectors. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. @@ -1384,7 +1279,7 @@ def spatial_tiled_decode( return DecoderOutput(sample=dec) - def temporal_tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput: + def temporal_tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> AutoencoderKLOutput: B, C, T, H, W = x.shape overlap_size = int(self.tile_sample_min_tsize * (1 - self.tile_overlap_factor)) blend_extent = int(self.tile_latent_min_tsize * self.tile_overlap_factor) @@ -1421,8 +1316,8 @@ def temporal_tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) return AutoencoderKLOutput(latent_dist=posterior) def temporal_tiled_decode( - self, z: torch.FloatTensor, return_dict: bool = True - ) -> Union[DecoderOutput, torch.FloatTensor]: + self, z: torch.Tensor, return_dict: bool = True + ) -> Union[DecoderOutput, torch.Tensor]: # Split z into overlapping tiles and decode them separately. B, C, T, H, W = z.shape @@ -1459,15 +1354,15 @@ def temporal_tiled_decode( def forward( self, - sample: torch.FloatTensor, + sample: torch.Tensor, sample_posterior: bool = False, return_dict: bool = True, return_posterior: bool = False, generator: Optional[torch.Generator] = None, - ) -> Union[DecoderOutput2, torch.FloatTensor]: + ) -> Union[DecoderOutput2, torch.Tensor]: r""" Args: - sample (`torch.FloatTensor`): Input sample. + sample (`torch.Tensor`): Input sample. sample_posterior (`bool`, *optional*, defaults to `False`): Whether to sample from the posterior. return_dict (`bool`, *optional*, defaults to `True`): diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py index 83aac1e68e9e..50f8254648d1 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py @@ -371,7 +371,7 @@ def forward( if attention_mask is None: pooled_projections = hidden_states.mean(dim=1) else: - mask_float = attention_mask.float().unsqueeze(-1) # [b, s1, 1] + mask_float = attention_mask.float().unsqueeze(-1) pooled_projections = (hidden_states * mask_float).sum(dim=1) / mask_float.sum(dim=1) pooled_projections = pooled_projections.to(original_dtype) @@ -409,7 +409,7 @@ def __init__( out_dim=hidden_size, bias=True, processor=HunyuanVideoAttnProcessor2_0(), - qk_norm="rms_norm", + qk_norm=qk_norm, eps=1e-6, pre_only=True, ) diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py index 94732c3d8947..4dbe14afb2dd 100644 --- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py @@ -582,7 +582,6 @@ def __call__( Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], - enable_tiling: bool = False, ): r""" The call function to the pipeline for generation. @@ -804,33 +803,10 @@ def __call__( latents = latents.to(vae_dtype) if not output_type == "latent": - expand_temporal_dim = False - if len(latents.shape) == 4: - latents = latents.unsqueeze(2) - expand_temporal_dim = True - elif len(latents.shape) == 5: - pass - else: - raise ValueError( - f"Only support latents with shape (b, c, h, w) or (b, c, f, h, w), but got {latents.shape}." - ) - - if hasattr(self.vae.config, "shift_factor") and self.vae.config.shift_factor: - latents = latents / self.vae.config.scaling_factor + self.vae.config.shift_factor - else: - latents = latents / self.vae.config.scaling_factor - - if enable_tiling: - self.vae.enable_tiling() - image = self.vae.decode(latents, return_dict=False, generator=generator)[0] - else: - image = self.vae.decode(latents, return_dict=False, generator=generator)[0] + image = self.vae.decode(latents, return_dict=False)[0] torch.save(image, "diffusers_latents_decoded.pt") - if expand_temporal_dim or image.shape[2] == 1: - image = image.squeeze(2) - else: image = latents From 1b27c3a23757ce6148c766eeb5510c68604666d4 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 12 Dec 2024 20:57:53 +0100 Subject: [PATCH 28/58] refactor upsample --- scripts/convert_hunyuan_video_to_diffusers.py | 6 +- .../autoencoder_kl_hunyuan_video.py | 191 ++++-------------- 2 files changed, 44 insertions(+), 153 deletions(-) diff --git a/scripts/convert_hunyuan_video_to_diffusers.py b/scripts/convert_hunyuan_video_to_diffusers.py index d1ab60a767ce..61ffe0474087 100644 --- a/scripts/convert_hunyuan_video_to_diffusers.py +++ b/scripts/convert_hunyuan_video_to_diffusers.py @@ -109,9 +109,7 @@ def remap_single_transformer_blocks_(key, state_dict): "single_blocks": remap_single_transformer_blocks_, } -VAE_KEYS_RENAME_DICT = { - -} +VAE_KEYS_RENAME_DICT = {} VAE_SPECIAL_KEYS_REMAP = {} @@ -208,7 +206,7 @@ def get_args(): transformer = transformer.to(dtype=dtype) if not args.save_pipeline: transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") - + if args.vae_ckpt_path is not None: vae = convert_vae(args.vae_ckpt_path) if not args.save_pipeline: diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py index 6965c7fcc15c..354e1508e4e0 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py @@ -22,7 +22,7 @@ from einops import rearrange from ...configuration_utils import ConfigMixin, register_to_config -from ...utils import logging, is_torch_version +from ...utils import is_torch_version, logging from ...utils.accelerate_utils import apply_forward_hook from ..activations import get_activation from ..attention_processor import Attention, SpatialNorm @@ -83,116 +83,35 @@ def forward(self, x): class UpsampleCausal3D(nn.Module): - """ - A 3D upsampling layer with an optional convolution. - """ - def __init__( self, - channels: int, - use_conv: bool = False, - use_conv_transpose: bool = False, + in_channels: int, out_channels: Optional[int] = None, - name: str = "conv", - kernel_size: Optional[int] = None, - padding=1, - norm_type=None, - eps=None, - elementwise_affine=None, - bias=True, - interpolate=True, - upsample_factor=(2, 2, 2), - ): + bias: bool = True, + upsample_factor: Tuple[float, float, float] = (2, 2, 2), + ) -> None: super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.use_conv_transpose = use_conv_transpose - self.name = name - self.interpolate = interpolate - self.upsample_factor = upsample_factor - if norm_type == "ln_norm": - self.norm = nn.LayerNorm(channels, eps, elementwise_affine) - elif norm_type == "rms_norm": - self.norm = RMSNorm(channels, eps, elementwise_affine) - elif norm_type is None: - self.norm = None - else: - raise ValueError(f"unknown norm_type: {norm_type}") - - conv = None - if use_conv_transpose: - assert False, "Not Implement yet" - if kernel_size is None: - kernel_size = 4 - conv = nn.ConvTranspose2d( - channels, self.out_channels, kernel_size=kernel_size, stride=2, padding=padding, bias=bias - ) - elif use_conv: - if kernel_size is None: - kernel_size = 3 - conv = CausalConv3d(self.channels, self.out_channels, kernel_size=kernel_size, bias=bias) - - if name == "conv": - self.conv = conv - else: - self.Conv2d_0 = conv - - def forward( - self, - hidden_states: torch.Tensor, - output_size: Optional[int] = None, - scale: float = 1.0, - ) -> torch.Tensor: - assert hidden_states.shape[1] == self.channels + out_channels = out_channels or in_channels + self.upsample_factor = upsample_factor - if self.norm is not None: - assert False, "Not Implement yet" - hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + self.conv = CausalConv3d(in_channels, out_channels, 3, 1, bias=bias) - if self.use_conv_transpose: - return self.conv(hidden_states) - - # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 - dtype = hidden_states.dtype - if dtype == torch.bfloat16: - hidden_states = hidden_states.to(torch.float32) - - # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 - if hidden_states.shape[0] >= 64: - hidden_states = hidden_states.contiguous() - - # if `output_size` is passed we force the interpolation output - # size and do not make use of `scale_factor=2` - if self.interpolate: - B, C, T, H, W = hidden_states.shape - first_h, other_h = hidden_states.split((1, T - 1), dim=2) - if output_size is None: - if T > 1: - other_h = F.interpolate(other_h, scale_factor=self.upsample_factor, mode="nearest") - - first_h = first_h.squeeze(2) - first_h = F.interpolate(first_h, scale_factor=self.upsample_factor[1:], mode="nearest") - first_h = first_h.unsqueeze(2) - else: - assert False, "Not Implement yet" - other_h = F.interpolate(other_h, size=output_size, mode="nearest") + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + num_frames = hidden_states.size(2) + first_frame, other_frames = hidden_states.split((1, num_frames - 1), dim=2) - if T > 1: - hidden_states = torch.cat((first_h, other_h), dim=2) - else: - hidden_states = first_h + first_frame = F.interpolate( + first_frame.squeeze(2), scale_factor=self.upsample_factor[1:], mode="nearest" + ).unsqueeze(2) - # If the input is bfloat16, we cast back to bfloat16 - if dtype == torch.bfloat16: - hidden_states = hidden_states.to(dtype) + if num_frames > 1: + other_frames = F.interpolate(other_frames, scale_factor=self.upsample_factor, mode="nearest") + hidden_states = torch.cat((first_frame, other_frames), dim=2) + else: + hidden_states = first_frame - if self.use_conv: - if self.name == "conv": - hidden_states = self.conv(hidden_states) - else: - hidden_states = self.Conv2d_0(hidden_states) + hidden_states = self.conv(hidden_states) return hidden_states @@ -278,13 +197,10 @@ def __init__( eps: float = 1e-6, non_linearity: str = "swish", skip_time_act: bool = False, - # default, scale_shift, ada_group, spatial time_embedding_norm: str = "default", kernel: Optional[torch.Tensor] = None, output_scale_factor: float = 1.0, use_in_shortcut: Optional[bool] = None, - up: bool = False, - down: bool = False, conv_shortcut_bias: bool = True, conv_3d_out_channels: Optional[int] = None, ): @@ -295,8 +211,6 @@ def __init__( out_channels = in_channels if out_channels is None else out_channels self.out_channels = out_channels self.use_conv_shortcut = conv_shortcut - self.up = up - self.down = down self.output_scale_factor = output_scale_factor self.time_embedding_norm = time_embedding_norm self.skip_time_act = skip_time_act @@ -340,12 +254,6 @@ def __init__( self.nonlinearity = get_activation(non_linearity) - self.upsample = self.downsample = None - if self.up: - self.upsample = UpsampleCausal3D(in_channels, use_conv=False) - elif self.down: - self.downsample = DownsampleCausal3D(in_channels, use_conv=False, name="op") - self.use_in_shortcut = self.in_channels != conv_3d_out_channels if use_in_shortcut is None else use_in_shortcut self.conv_shortcut = None @@ -372,18 +280,6 @@ def forward( hidden_states = self.norm1(hidden_states) hidden_states = self.nonlinearity(hidden_states) - - if self.upsample is not None: - # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 - if hidden_states.shape[0] >= 64: - input_tensor = input_tensor.contiguous() - hidden_states = hidden_states.contiguous() - input_tensor = self.upsample(input_tensor, scale=scale) - hidden_states = self.upsample(hidden_states, scale=scale) - elif self.downsample is not None: - input_tensor = self.downsample(input_tensor, scale=scale) - hidden_states = self.downsample(hidden_states, scale=scale) - hidden_states = self.conv1(hidden_states) if self.time_emb_proj is not None: @@ -461,12 +357,6 @@ def __init__( ] attentions = [] - if attention_head_dim is None: - logger.warn( - f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}." - ) - attention_head_dim = in_channels - for _ in range(num_layers): if self.add_attention: # assert False, "Not implemented yet" @@ -634,7 +524,6 @@ def __init__( [ UpsampleCausal3D( out_channels, - use_conv=True, out_channels=out_channels, upsample_factor=upsample_scale_factor, ) @@ -662,12 +551,17 @@ class EncoderCausal3D(nn.Module): r""" Causal encoder for 3D video-like data introduced in [Hunyuan Video](https://huggingface.co/papers/2412.03603). """ - + def __init__( self, in_channels: int = 3, out_channels: int = 3, - down_block_types: Tuple[str, ...] = ("DownEncoderBlockCausal3D", "DownEncoderBlockCausal3D", "DownEncoderBlockCausal3D", "DownEncoderBlockCausal3D"), + down_block_types: Tuple[str, ...] = ( + "DownEncoderBlockCausal3D", + "DownEncoderBlockCausal3D", + "DownEncoderBlockCausal3D", + "DownEncoderBlockCausal3D", + ), block_out_channels: Tuple[int, ...] = (128, 256, 512, 512), layers_per_block: int = 2, norm_num_groups: int = 32, @@ -678,7 +572,7 @@ def __init__( spatial_compression_ratio: int = 8, ) -> None: super().__init__() - + self.conv_in = CausalConv3d(in_channels, block_out_channels[0], kernel_size=3, stride=1) self.mid_block = None self.down_blocks = nn.ModuleList([]) @@ -717,7 +611,7 @@ def __init__( resnet_groups=norm_num_groups, downsample_padding=0, ) - + self.down_blocks.append(down_block) self.mid_block = UNetMidBlockCausal3D( @@ -778,7 +672,12 @@ def __init__( self, in_channels: int = 3, out_channels: int = 3, - up_block_types: Tuple[str, ...] = ("UpDecoderBlockCausal3D", "UpDecoderBlockCausal3D", "UpDecoderBlockCausal3D", "UpDecoderBlockCausal3D"), + up_block_types: Tuple[str, ...] = ( + "UpDecoderBlockCausal3D", + "UpDecoderBlockCausal3D", + "UpDecoderBlockCausal3D", + "UpDecoderBlockCausal3D", + ), block_out_channels: Tuple[int, ...] = (128, 256, 512, 512), layers_per_block: int = 2, norm_num_groups: int = 32, @@ -831,7 +730,7 @@ def __init__( upsample_scale_factor_HW = (2, 2) if add_spatial_upsample else (1, 1) upsample_scale_factor_T = (2,) if add_time_upsample else (1,) upsample_scale_factor = tuple(upsample_scale_factor_T + upsample_scale_factor_HW) - + up_block = UpDecoderBlockCausal3D( num_layers=self.layers_per_block + 1, in_channels=prev_output_channel, @@ -844,7 +743,7 @@ def __init__( resnet_time_scale_shift=norm_type, temb_channels=temb_channels, ) - + self.up_blocks.append(up_block) prev_output_channel = output_channel @@ -923,8 +822,8 @@ class DecoderOutput2(BaseOutput): class AutoencoderKLHunyuanVideo(ModelMixin, ConfigMixin): r""" - A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Introduced - in [HunyuanVideo](https://huggingface.co/papers/2412.03603). + A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. + Introduced in [HunyuanVideo](https://huggingface.co/papers/2412.03603). This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented for all models (such as downloading or saving). @@ -1119,9 +1018,7 @@ def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOut return DecoderOutput(sample=dec) @apply_forward_hook - def decode( - self, z: torch.Tensor, return_dict: bool = True - ) -> Union[DecoderOutput, torch.Tensor]: + def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: """ Decode a batch of images/videos. @@ -1229,9 +1126,7 @@ def spatial_tiled_encode( return AutoencoderKLOutput(latent_dist=posterior) - def spatial_tiled_decode( - self, z: torch.Tensor, return_dict: bool = True - ) -> Union[DecoderOutput, torch.Tensor]: + def spatial_tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: r""" Decode a batch of images/videos using a tiled decoder. @@ -1315,9 +1210,7 @@ def temporal_tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> Au return AutoencoderKLOutput(latent_dist=posterior) - def temporal_tiled_decode( - self, z: torch.Tensor, return_dict: bool = True - ) -> Union[DecoderOutput, torch.Tensor]: + def temporal_tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: # Split z into overlapping tiles and decode them separately. B, C, T, H, W = z.shape From bea9e1b67bd4b2c7c50fd61d02640f4c8630ffdf Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 12 Dec 2024 21:06:13 +0100 Subject: [PATCH 29/58] refactor causal conv --- .../autoencoder_kl_hunyuan_video.py | 90 +++++-------------- 1 file changed, 24 insertions(+), 66 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py index 354e1508e4e0..86e8cb04adf2 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py @@ -28,7 +28,7 @@ from ..attention_processor import Attention, SpatialNorm from ..modeling_outputs import AutoencoderKLOutput from ..modeling_utils import ModelMixin -from ..normalization import AdaGroupNorm, RMSNorm +from ..normalization import AdaGroupNorm from .vae import BaseOutput, DecoderOutput, DiagonalGaussianDistribution @@ -47,39 +47,36 @@ def prepare_causal_attention_mask(n_frame: int, n_hw: int, dtype, device, batch_ class CausalConv3d(nn.Module): - """ - Implements a causal 3D convolution layer where each position only depends on previous timesteps and current spatial - locations. This maintains temporal causality in video generation tasks. - """ - def __init__( self, - chan_in, - chan_out, - kernel_size: Union[int, Tuple[int, int, int]], + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int, int]] = 3, stride: Union[int, Tuple[int, int, int]] = 1, + padding: Union[int, Tuple[int, int, int]] = 0, dilation: Union[int, Tuple[int, int, int]] = 1, - pad_mode="replicate", - **kwargs, - ): + bias: bool = True, + pad_mode: str = "replicate", + ) -> None: super().__init__() + kernel_size = (kernel_size, kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size + self.pad_mode = pad_mode - padding = ( - kernel_size // 2, - kernel_size // 2, - kernel_size // 2, - kernel_size // 2, - kernel_size - 1, + self.time_causal_padding = ( + kernel_size[0] // 2, + kernel_size[0] // 2, + kernel_size[1] // 2, + kernel_size[1] // 2, + kernel_size[2] - 1, 0, - ) # W, H, T - self.time_causal_padding = padding + ) - self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs) + self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias) - def forward(self, x): - x = F.pad(x, self.time_causal_padding, mode=self.pad_mode) - return self.conv(x) + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = F.pad(hidden_states, self.time_causal_padding, mode=self.pad_mode) + return self.conv(hidden_states) class UpsampleCausal3D(nn.Module): @@ -117,62 +114,25 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class DownsampleCausal3D(nn.Module): - """ - A 3D downsampling layer with an optional convolution. - """ - def __init__( self, channels: int, - use_conv: bool = False, out_channels: Optional[int] = None, padding: int = 1, - name: str = "conv", kernel_size=3, - norm_type=None, - eps=None, - elementwise_affine=None, bias=True, stride=2, ): super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.padding = padding - stride = stride - self.name = name - - if norm_type == "ln_norm": - self.norm = nn.LayerNorm(channels, eps, elementwise_affine) - elif norm_type == "rms_norm": - self.norm = RMSNorm(channels, eps, elementwise_affine) - elif norm_type is None: - self.norm = None - else: - raise ValueError(f"unknown norm_type: {norm_type}") - - if use_conv: - conv = CausalConv3d(self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, bias=bias) - else: - raise NotImplementedError - if name == "conv": - self.Conv2d_0 = conv - self.conv = conv - elif name == "Conv2d_0": - self.conv = conv - else: - self.conv = conv + out_channels = out_channels or channels - def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor: - assert hidden_states.shape[1] == self.channels + self.conv = CausalConv3d(channels, out_channels, kernel_size=kernel_size, stride=stride, bias=bias) + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if self.norm is not None: hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) - assert hidden_states.shape[1] == self.channels - hidden_states = self.conv(hidden_states) return hidden_states @@ -456,10 +416,8 @@ def __init__( [ DownsampleCausal3D( out_channels, - use_conv=True, out_channels=out_channels, padding=downsample_padding, - name="op", stride=downsample_stride, ) ] From d6c16ef13af29f67819018faa117694e588f068b Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 12 Dec 2024 21:42:00 +0100 Subject: [PATCH 30/58] refactor resnet --- .../autoencoder_kl_hunyuan_video.py | 248 ++++-------------- 1 file changed, 45 insertions(+), 203 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py index 86e8cb04adf2..db28ff43bade 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py @@ -25,10 +25,9 @@ from ...utils import is_torch_version, logging from ...utils.accelerate_utils import apply_forward_hook from ..activations import get_activation -from ..attention_processor import Attention, SpatialNorm +from ..attention_processor import Attention from ..modeling_outputs import AutoencoderKLOutput from ..modeling_utils import ModelMixin -from ..normalization import AdaGroupNorm from .vae import BaseOutput, DecoderOutput, DiagonalGaussianDistribution @@ -84,6 +83,8 @@ def __init__( self, in_channels: int, out_channels: Optional[int] = None, + kernel_size: int = 3, + stride: int = 1, bias: bool = True, upsample_factor: Tuple[float, float, float] = (2, 2, 2), ) -> None: @@ -92,12 +93,12 @@ def __init__( out_channels = out_channels or in_channels self.upsample_factor = upsample_factor - self.conv = CausalConv3d(in_channels, out_channels, 3, 1, bias=bias) + self.conv = CausalConv3d(in_channels, out_channels, kernel_size, stride, bias=bias) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: num_frames = hidden_states.size(2) - first_frame, other_frames = hidden_states.split((1, num_frames - 1), dim=2) + first_frame, other_frames = hidden_states.split((1, num_frames - 1), dim=2) first_frame = F.interpolate( first_frame.squeeze(2), scale_factor=self.upsample_factor[1:], mode="nearest" ).unsqueeze(2) @@ -109,7 +110,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = first_frame hidden_states = self.conv(hidden_states) - return hidden_states @@ -119,216 +119,103 @@ def __init__( channels: int, out_channels: Optional[int] = None, padding: int = 1, - kernel_size=3, - bias=True, + kernel_size: int = 3, + bias: bool = True, stride=2, - ): + ) -> None: super().__init__() - out_channels = out_channels or channels - self.conv = CausalConv3d(channels, out_channels, kernel_size=kernel_size, stride=stride, bias=bias) + self.conv = CausalConv3d(channels, out_channels, kernel_size, stride, padding, bias=bias) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - if self.norm is not None: - hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) - hidden_states = self.conv(hidden_states) - return hidden_states class ResnetBlockCausal3D(nn.Module): - r""" - A Resnet block. - """ - def __init__( self, - *, in_channels: int, out_channels: Optional[int] = None, - conv_shortcut: bool = False, dropout: float = 0.0, - temb_channels: int = 512, groups: int = 32, - groups_out: Optional[int] = None, - pre_norm: bool = True, eps: float = 1e-6, non_linearity: str = "swish", - skip_time_act: bool = False, - time_embedding_norm: str = "default", - kernel: Optional[torch.Tensor] = None, - output_scale_factor: float = 1.0, - use_in_shortcut: Optional[bool] = None, - conv_shortcut_bias: bool = True, - conv_3d_out_channels: Optional[int] = None, - ): + ) -> None: super().__init__() - self.pre_norm = pre_norm - self.pre_norm = True - self.in_channels = in_channels - out_channels = in_channels if out_channels is None else out_channels - self.out_channels = out_channels - self.use_conv_shortcut = conv_shortcut - self.output_scale_factor = output_scale_factor - self.time_embedding_norm = time_embedding_norm - self.skip_time_act = skip_time_act - - linear_cls = nn.Linear - - if groups_out is None: - groups_out = groups - - if self.time_embedding_norm == "ada_group": - self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps) - elif self.time_embedding_norm == "spatial": - self.norm1 = SpatialNorm(in_channels, temb_channels) - else: - self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) - - self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, stride=1) + out_channels = out_channels or in_channels - if temb_channels is not None: - if self.time_embedding_norm == "default": - self.time_emb_proj = linear_cls(temb_channels, out_channels) - elif self.time_embedding_norm == "scale_shift": - self.time_emb_proj = linear_cls(temb_channels, 2 * out_channels) - elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial": - self.time_emb_proj = None - else: - raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ") - else: - self.time_emb_proj = None + self.nonlinearity = get_activation(non_linearity) - if self.time_embedding_norm == "ada_group": - self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps) - elif self.time_embedding_norm == "spatial": - self.norm2 = SpatialNorm(out_channels, temb_channels) - else: - self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) + self.norm1 = nn.GroupNorm(groups, in_channels, eps=eps, affine=True) + self.conv1 = CausalConv3d(in_channels, out_channels, 3, 1, 0) + self.norm2 = nn.GroupNorm(groups, out_channels, eps=eps, affine=True) self.dropout = torch.nn.Dropout(dropout) - conv_3d_out_channels = conv_3d_out_channels or out_channels - self.conv2 = CausalConv3d(out_channels, conv_3d_out_channels, kernel_size=3, stride=1) - - self.nonlinearity = get_activation(non_linearity) - - self.use_in_shortcut = self.in_channels != conv_3d_out_channels if use_in_shortcut is None else use_in_shortcut + self.conv2 = CausalConv3d(out_channels, out_channels, 3, 1, 0) self.conv_shortcut = None - if self.use_in_shortcut: - self.conv_shortcut = CausalConv3d( - in_channels, - conv_3d_out_channels, - kernel_size=1, - stride=1, - bias=conv_shortcut_bias, - ) + if in_channels != out_channels: + self.conv_shortcut = CausalConv3d(in_channels, out_channels, 1, 1, 0) - def forward( - self, - input_tensor: torch.Tensor, - temb: torch.Tensor, - scale: float = 1.0, - ) -> torch.Tensor: - hidden_states = input_tensor - - if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial": - hidden_states = self.norm1(hidden_states, temb) - else: - hidden_states = self.norm1(hidden_states) + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + residual = hidden_states + hidden_states = self.norm1(hidden_states) hidden_states = self.nonlinearity(hidden_states) hidden_states = self.conv1(hidden_states) - if self.time_emb_proj is not None: - if not self.skip_time_act: - temb = self.nonlinearity(temb) - temb = self.time_emb_proj(temb, scale)[:, :, None, None] - - if temb is not None and self.time_embedding_norm == "default": - hidden_states = hidden_states + temb - - if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial": - hidden_states = self.norm2(hidden_states, temb) - else: - hidden_states = self.norm2(hidden_states) - - if temb is not None and self.time_embedding_norm == "scale_shift": - scale, shift = torch.chunk(temb, 2, dim=1) - hidden_states = hidden_states * (1 + scale) + shift - + hidden_states = self.norm2(hidden_states) hidden_states = self.nonlinearity(hidden_states) - hidden_states = self.dropout(hidden_states) hidden_states = self.conv2(hidden_states) if self.conv_shortcut is not None: - input_tensor = self.conv_shortcut(input_tensor) - - output_tensor = (input_tensor + hidden_states) / self.output_scale_factor + residual = self.conv_shortcut(residual) - return output_tensor + hidden_states = hidden_states + residual + return hidden_states class UNetMidBlockCausal3D(nn.Module): - """ - A 3D UNet mid-block [`UNetMidBlockCausal3D`] with multiple residual blocks and optional attention blocks. - """ - def __init__( self, in_channels: int, - temb_channels: int, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", # default, spatial resnet_act_fn: str = "swish", resnet_groups: int = 32, - attn_groups: Optional[int] = None, - resnet_pre_norm: bool = True, add_attention: bool = True, attention_head_dim: int = 1, - output_scale_factor: float = 1.0, - ): + ) -> None: super().__init__() resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) self.add_attention = add_attention - if attn_groups is None: - attn_groups = resnet_groups if resnet_time_scale_shift == "default" else None - - # there is always at least one resnet + # There is always at least one resnet resnets = [ ResnetBlockCausal3D( in_channels=in_channels, out_channels=in_channels, - temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, ) ] attentions = [] for _ in range(num_layers): if self.add_attention: - # assert False, "Not implemented yet" attentions.append( Attention( in_channels, heads=in_channels // attention_head_dim, dim_head=attention_head_dim, - rescale_output_factor=output_scale_factor, eps=resnet_eps, - norm_num_groups=attn_groups, - spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None, + norm_num_groups=resnet_groups, residual_connection=True, bias=True, upcast_softmax=True, @@ -342,22 +229,18 @@ def __init__( ResnetBlockCausal3D( in_channels=in_channels, out_channels=in_channels, - temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, ) ) self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) - def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: - hidden_states = self.resnets[0](hidden_states, temb) + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.resnets[0](hidden_states) for attn, resnet in zip(self.attentions, self.resnets[1:]): if attn is not None: B, C, T, H, W = hidden_states.shape @@ -365,9 +248,9 @@ def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = No attention_mask = prepare_causal_attention_mask( T, H * W, hidden_states.dtype, hidden_states.device, batch_size=B ) - hidden_states = attn(hidden_states, temb=temb, attention_mask=attention_mask) + hidden_states = attn(hidden_states, attention_mask=attention_mask) hidden_states = rearrange(hidden_states, "b (f h w) c -> b c f h w", f=T, h=H, w=W) - hidden_states = resnet(hidden_states, temb) + hidden_states = resnet(hidden_states) return hidden_states @@ -380,11 +263,8 @@ def __init__( dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_groups: int = 32, - resnet_pre_norm: bool = True, - output_scale_factor: float = 1.0, add_downsample: bool = True, downsample_stride: int = 2, downsample_padding: int = 1, @@ -398,14 +278,10 @@ def __init__( ResnetBlockCausal3D( in_channels=in_channels, out_channels=out_channels, - temb_channels=None, eps=resnet_eps, groups=resnet_groups, dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, ) ) @@ -425,13 +301,13 @@ def __init__( else: self.downsamplers = None - def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: for resnet in self.resnets: - hidden_states = resnet(hidden_states, temb=None, scale=scale) + hidden_states = resnet(hidden_states) if self.downsamplers is not None: for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states, scale) + hidden_states = downsampler(hidden_states) return hidden_states @@ -445,14 +321,10 @@ def __init__( dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", # default, spatial resnet_act_fn: str = "swish", resnet_groups: int = 32, - resnet_pre_norm: bool = True, - output_scale_factor: float = 1.0, add_upsample: bool = True, upsample_scale_factor=(2, 2, 2), - temb_channels: Optional[int] = None, ): super().__init__() resnets = [] @@ -464,14 +336,10 @@ def __init__( ResnetBlockCausal3D( in_channels=input_channels, out_channels=out_channels, - temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, ) ) @@ -492,11 +360,9 @@ def __init__( self.resolution_idx = resolution_idx - def forward( - self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, scale: float = 1.0 - ) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: for resnet in self.resnets: - hidden_states = resnet(hidden_states, temb=temb, scale=scale) + hidden_states = resnet(hidden_states) if self.upsamplers is not None: for upsampler in self.upsamplers: @@ -563,10 +429,10 @@ def __init__( in_channels=input_channel, out_channels=output_channel, add_downsample=bool(add_spatial_downsample or add_time_downsample), - downsample_stride=downsample_stride, resnet_eps=1e-6, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, + downsample_stride=downsample_stride, downsample_padding=0, ) @@ -576,11 +442,8 @@ def __init__( in_channels=block_out_channels[-1], resnet_eps=1e-6, resnet_act_fn=act_fn, - output_scale_factor=1, - resnet_time_scale_shift="default", attention_head_dim=block_out_channels[-1], resnet_groups=norm_num_groups, - temb_channels=None, add_attention=mid_block_add_attention, ) @@ -652,18 +515,13 @@ def __init__( self.mid_block = None self.up_blocks = nn.ModuleList([]) - temb_channels = in_channels if norm_type == "spatial" else None - # mid self.mid_block = UNetMidBlockCausal3D( in_channels=block_out_channels[-1], resnet_eps=1e-6, resnet_act_fn=act_fn, - output_scale_factor=1, - resnet_time_scale_shift="default" if norm_type == "group" else norm_type, attention_head_dim=block_out_channels[-1], resnet_groups=norm_num_groups, - temb_channels=temb_channels, add_attention=mid_block_add_attention, ) @@ -698,29 +556,19 @@ def __init__( resnet_eps=1e-6, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, - resnet_time_scale_shift=norm_type, - temb_channels=temb_channels, ) self.up_blocks.append(up_block) prev_output_channel = output_channel # out - if norm_type == "spatial": - self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels) - else: - self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6) + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6) self.conv_act = nn.SiLU() self.conv_out = CausalConv3d(block_out_channels[0], out_channels, kernel_size=3) self.gradient_checkpointing = False - def forward( - self, - sample: torch.Tensor, - latent_embeds: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - r"""The forward method of the `DecoderCausal3D` class.""" + def forward(self, sample: torch.Tensor) -> torch.Tensor: assert len(sample.shape) == 5, "The input tensor should have 5 dimensions" sample = self.conv_in(sample) @@ -739,33 +587,27 @@ def custom_forward(*inputs): sample = torch.utils.checkpoint.checkpoint( create_custom_forward(up_block), sample, - latent_embeds, use_reentrant=False, ) else: # middle - sample = torch.utils.checkpoint.checkpoint( - create_custom_forward(self.mid_block), sample, latent_embeds - ) + sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample) sample = sample.to(upscale_dtype) # up for up_block in self.up_blocks: - sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds) + sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample) else: # middle - sample = self.mid_block(sample, latent_embeds) + sample = self.mid_block(sample) sample = sample.to(upscale_dtype) # up for up_block in self.up_blocks: - sample = up_block(sample, latent_embeds) + sample = up_block(sample) # post-process - if latent_embeds is None: - sample = self.conv_norm_out(sample) - else: - sample = self.conv_norm_out(sample, latent_embeds) + sample = self.conv_norm_out(sample) sample = self.conv_act(sample) sample = self.conv_out(sample) From d0036ff833b5fd5f4bddd6ea6886b5c365460917 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 12 Dec 2024 22:53:48 +0100 Subject: [PATCH 31/58] refactor --- .../autoencoder_kl_hunyuan_video.py | 130 ++++++++---------- 1 file changed, 56 insertions(+), 74 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py index db28ff43bade..5866db573708 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py @@ -12,14 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass from typing import Optional, Tuple, Union import numpy as np import torch import torch.nn as nn import torch.nn.functional as F -from einops import rearrange from ...configuration_utils import ConfigMixin, register_to_config from ...utils import is_torch_version, logging @@ -28,18 +26,20 @@ from ..attention_processor import Attention from ..modeling_outputs import AutoencoderKLOutput from ..modeling_utils import ModelMixin -from .vae import BaseOutput, DecoderOutput, DiagonalGaussianDistribution +from .vae import DecoderOutput, DiagonalGaussianDistribution logger = logging.get_logger(__name__) # pylint: disable=invalid-name -def prepare_causal_attention_mask(n_frame: int, n_hw: int, dtype, device, batch_size: int = None): - seq_len = n_frame * n_hw +def prepare_causal_attention_mask( + num_frames: int, height_width: int, dtype: torch.dtype, device: torch.device, batch_size: int = None +): + seq_len = num_frames * height_width mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device) for i in range(seq_len): - i_frame = i // n_hw - mask[i, : (i_frame + 1) * n_hw] = 0 + i_frame = i // height_width + mask[i, : (i_frame + 1) * height_width] = 0 if batch_size is not None: mask = mask.unsqueeze(0).expand(batch_size, -1, -1) return mask @@ -178,7 +178,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return hidden_states -class UNetMidBlockCausal3D(nn.Module): +class HunyuanVideoMidBlock3D(nn.Module): def __init__( self, in_channels: int, @@ -243,19 +243,20 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.resnets[0](hidden_states) for attn, resnet in zip(self.attentions, self.resnets[1:]): if attn is not None: - B, C, T, H, W = hidden_states.shape - hidden_states = rearrange(hidden_states, "b c f h w -> b (f h w) c") + batch_size, num_channels, num_frames, height, width = hidden_states.shape + hidden_states = hidden_states.permute(0, 2, 3, 4, 1).flatten(1, 3) attention_mask = prepare_causal_attention_mask( - T, H * W, hidden_states.dtype, hidden_states.device, batch_size=B + num_frames, height * width, hidden_states.dtype, hidden_states.device, batch_size=batch_size ) hidden_states = attn(hidden_states, attention_mask=attention_mask) - hidden_states = rearrange(hidden_states, "b (f h w) c -> b c f h w", f=T, h=H, w=W) + hidden_states = hidden_states.unflatten(1, (num_frames, height, width)).permute(0, 4, 1, 2, 3) + hidden_states = resnet(hidden_states) return hidden_states -class DownEncoderBlockCausal3D(nn.Module): +class HunyuanVideoDownBlock3D(nn.Module): def __init__( self, in_channels: int, @@ -268,7 +269,7 @@ def __init__( add_downsample: bool = True, downsample_stride: int = 2, downsample_padding: int = 1, - ): + ) -> None: super().__init__() resnets = [] @@ -312,20 +313,19 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return hidden_states -class UpDecoderBlockCausal3D(nn.Module): +class HunyuanVideoUpBlock3D(nn.Module): def __init__( self, in_channels: int, out_channels: int, - resolution_idx: Optional[int] = None, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, resnet_act_fn: str = "swish", resnet_groups: int = 32, add_upsample: bool = True, - upsample_scale_factor=(2, 2, 2), - ): + upsample_scale_factor: Tuple[int, int, int] = (2, 2, 2), + ) -> None: super().__init__() resnets = [] @@ -358,8 +358,6 @@ def __init__( else: self.upsamplers = None - self.resolution_idx = resolution_idx - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: for resnet in self.resnets: hidden_states = resnet(hidden_states) @@ -381,10 +379,10 @@ def __init__( in_channels: int = 3, out_channels: int = 3, down_block_types: Tuple[str, ...] = ( - "DownEncoderBlockCausal3D", - "DownEncoderBlockCausal3D", - "DownEncoderBlockCausal3D", - "DownEncoderBlockCausal3D", + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", ), block_out_channels: Tuple[int, ...] = (128, 256, 512, 512), layers_per_block: int = 2, @@ -424,7 +422,7 @@ def __init__( downsample_stride_T = (2,) if add_time_downsample else (1,) downsample_stride = tuple(downsample_stride_T + downsample_stride_HW) - down_block = DownEncoderBlockCausal3D( + down_block = HunyuanVideoDownBlock3D( num_layers=layers_per_block, in_channels=input_channel, out_channels=output_channel, @@ -438,7 +436,7 @@ def __init__( self.down_blocks.append(down_block) - self.mid_block = UNetMidBlockCausal3D( + self.mid_block = HunyuanVideoMidBlock3D( in_channels=block_out_channels[-1], resnet_eps=1e-6, resnet_act_fn=act_fn, @@ -494,10 +492,10 @@ def __init__( in_channels: int = 3, out_channels: int = 3, up_block_types: Tuple[str, ...] = ( - "UpDecoderBlockCausal3D", - "UpDecoderBlockCausal3D", - "UpDecoderBlockCausal3D", - "UpDecoderBlockCausal3D", + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", ), block_out_channels: Tuple[int, ...] = (128, 256, 512, 512), layers_per_block: int = 2, @@ -516,7 +514,7 @@ def __init__( self.up_blocks = nn.ModuleList([]) # mid - self.mid_block = UNetMidBlockCausal3D( + self.mid_block = HunyuanVideoMidBlock3D( in_channels=block_out_channels[-1], resnet_eps=1e-6, resnet_act_fn=act_fn, @@ -547,7 +545,7 @@ def __init__( upsample_scale_factor_T = (2,) if add_time_upsample else (1,) upsample_scale_factor = tuple(upsample_scale_factor_T + upsample_scale_factor_HW) - up_block = UpDecoderBlockCausal3D( + up_block = HunyuanVideoUpBlock3D( num_layers=self.layers_per_block + 1, in_channels=prev_output_channel, out_channels=output_channel, @@ -568,10 +566,8 @@ def __init__( self.gradient_checkpointing = False - def forward(self, sample: torch.Tensor) -> torch.Tensor: - assert len(sample.shape) == 5, "The input tensor should have 5 dimensions" - - sample = self.conv_in(sample) + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.conv_in(hidden_states) upscale_dtype = next(iter(self.up_blocks.parameters())).dtype if self.training and self.gradient_checkpointing: @@ -584,40 +580,34 @@ def custom_forward(*inputs): # up for up_block in self.up_blocks: - sample = torch.utils.checkpoint.checkpoint( + hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(up_block), - sample, + hidden_states, use_reentrant=False, ) else: # middle - sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample) - sample = sample.to(upscale_dtype) + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), hidden_states) + hidden_states = hidden_states.to(upscale_dtype) # up for up_block in self.up_blocks: - sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample) + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), hidden_states) else: # middle - sample = self.mid_block(sample) - sample = sample.to(upscale_dtype) + hidden_states = self.mid_block(hidden_states) + hidden_states = hidden_states.to(upscale_dtype) # up for up_block in self.up_blocks: - sample = up_block(sample) + hidden_states = up_block(hidden_states) # post-process - sample = self.conv_norm_out(sample) - sample = self.conv_act(sample) - sample = self.conv_out(sample) - - return sample - + hidden_states = self.conv_norm_out(hidden_states) + hidden_states = self.conv_act(hidden_states) + hidden_states = self.conv_out(hidden_states) -@dataclass -class DecoderOutput2(BaseOutput): - sample: torch.Tensor - posterior: Optional[DiagonalGaussianDistribution] = None + return hidden_states class AutoencoderKLHunyuanVideo(ModelMixin, ConfigMixin): @@ -638,16 +628,16 @@ def __init__( out_channels: int = 3, latent_channels: int = 16, down_block_types: Tuple[str, ...] = ( - "DownEncoderBlockCausal3D", - "DownEncoderBlockCausal3D", - "DownEncoderBlockCausal3D", - "DownEncoderBlockCausal3D", + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", ), up_block_types: Tuple[str, ...] = ( - "UpDecoderBlockCausal3D", - "UpDecoderBlockCausal3D", - "UpDecoderBlockCausal3D", - "UpDecoderBlockCausal3D", + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", ), block_out_channels: Tuple[int] = (128, 256, 512, 512), layers_per_block: int = 2, @@ -1050,9 +1040,8 @@ def forward( sample: torch.Tensor, sample_posterior: bool = False, return_dict: bool = True, - return_posterior: bool = False, generator: Optional[torch.Generator] = None, - ) -> Union[DecoderOutput2, torch.Tensor]: + ) -> Union[DecoderOutput, torch.Tensor]: r""" Args: sample (`torch.Tensor`): Input sample. @@ -1067,14 +1056,7 @@ def forward( z = posterior.sample(generator=generator) else: z = posterior.mode() - dec = self.decode(z).sample - + dec = self.decode(z) if not return_dict: - if return_posterior: - return (dec, posterior) - else: - return (dec,) - if return_posterior: - return DecoderOutput2(sample=dec, posterior=posterior) - else: - return DecoderOutput2(sample=dec) + return (dec,) + return dec From da536201672553c69e01dca9398bc1dc1f2ece8c Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 12 Dec 2024 23:18:24 +0100 Subject: [PATCH 32/58] refactor --- .../autoencoder_kl_hunyuan_video.py | 276 +++++++++--------- 1 file changed, 146 insertions(+), 130 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py index 5866db573708..c52f8edede34 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py @@ -390,7 +390,7 @@ def __init__( act_fn: str = "silu", double_z: bool = True, mid_block_add_attention=True, - time_compression_ratio: int = 4, + temporal_compression_ratio: int = 4, spatial_compression_ratio: int = 8, ) -> None: super().__init__() @@ -405,18 +405,18 @@ def __init__( output_channel = block_out_channels[i] is_final_block = i == len(block_out_channels) - 1 num_spatial_downsample_layers = int(np.log2(spatial_compression_ratio)) - num_time_downsample_layers = int(np.log2(time_compression_ratio)) + num_time_downsample_layers = int(np.log2(temporal_compression_ratio)) - if time_compression_ratio == 4: + if temporal_compression_ratio == 4: add_spatial_downsample = bool(i < num_spatial_downsample_layers) add_time_downsample = bool( i >= (len(block_out_channels) - 1 - num_time_downsample_layers) and not is_final_block ) - elif time_compression_ratio == 8: + elif temporal_compression_ratio == 8: add_spatial_downsample = bool(i < num_spatial_downsample_layers) add_time_downsample = bool(i < num_time_downsample_layers) else: - raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}") + raise ValueError(f"Unsupported time_compression_ratio: {temporal_compression_ratio}") downsample_stride_HW = (2, 2) if add_spatial_downsample else (1, 1) downsample_stride_T = (2,) if add_time_downsample else (1,) @@ -643,16 +643,15 @@ def __init__( layers_per_block: int = 2, act_fn: str = "silu", norm_num_groups: int = 32, - sample_size: int = 256, sample_tsize: int = 64, scaling_factor: float = 0.476986, spatial_compression_ratio: int = 8, - time_compression_ratio: int = 4, + temporal_compression_ratio: int = 4, mid_block_add_attention: bool = True, - ): + ) -> None: super().__init__() - self.time_compression_ratio = time_compression_ratio + self.time_compression_ratio = temporal_compression_ratio self.encoder = EncoderCausal3D( in_channels=in_channels, @@ -664,7 +663,7 @@ def __init__( act_fn=act_fn, double_z=True, mid_block_add_attention=mid_block_add_attention, - time_compression_ratio=time_compression_ratio, + temporal_compression_ratio=temporal_compression_ratio, spatial_compression_ratio=spatial_compression_ratio, ) @@ -676,7 +675,7 @@ def __init__( layers_per_block=layers_per_block, norm_num_groups=norm_num_groups, act_fn=act_fn, - time_compression_ratio=time_compression_ratio, + time_compression_ratio=temporal_compression_ratio, spatial_compression_ratio=spatial_compression_ratio, mid_block_add_attention=mid_block_add_attention, ) @@ -684,120 +683,141 @@ def __init__( self.quant_conv = nn.Conv3d(2 * latent_channels, 2 * latent_channels, kernel_size=1) self.post_quant_conv = nn.Conv3d(latent_channels, latent_channels, kernel_size=1) + self.spatial_compression_ratio = spatial_compression_ratio + self.temporal_compression_ratio = temporal_compression_ratio + + # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension + # to perform decoding of a single video latent at a time. self.use_slicing = False - self.use_spatial_tiling = False - self.use_temporal_tiling = False + + # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent + # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the + # intermediate tiles together, the memory requirement can be lowered. + self.use_tiling = False + + # When decoding temporally long video latents, the memory requirement is very high. By decoding latent frames + # at a fixed frame batch size (based on `self.num_latent_frames_batch_sizes`), the memory requirement can be lowered. + self.use_framewise_encoding = True + self.use_framewise_decoding = True # only relevant if vae tiling is enabled self.tile_sample_min_tsize = sample_tsize - self.tile_latent_min_tsize = sample_tsize // time_compression_ratio + self.tile_latent_min_tsize = sample_tsize // temporal_compression_ratio - self.tile_sample_min_size = self.config.sample_size - sample_size = ( - self.config.sample_size[0] - if isinstance(self.config.sample_size, (list, tuple)) - else self.config.sample_size - ) - self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1))) - self.tile_overlap_factor = 0.25 + # The minimal tile height and width for spatial tiling to be used + self.tile_sample_min_height = 256 + self.tile_sample_min_width = 256 + + # The minimal distance between two spatial tiles + self.tile_sample_stride_height = 192 + self.tile_sample_stride_width = 192 def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, (EncoderCausal3D, DecoderCausal3D)): module.gradient_checkpointing = value - def enable_temporal_tiling(self, use_tiling: bool = True): - self.use_temporal_tiling = use_tiling - - def disable_temporal_tiling(self): - self.enable_temporal_tiling(False) - - def enable_spatial_tiling(self, use_tiling: bool = True): - self.use_spatial_tiling = use_tiling - - def disable_spatial_tiling(self): - self.enable_spatial_tiling(False) - - def enable_tiling(self, use_tiling: bool = True): + def enable_tiling( + self, + tile_sample_min_height: Optional[int] = None, + tile_sample_min_width: Optional[int] = None, + tile_sample_stride_height: Optional[float] = None, + tile_sample_stride_width: Optional[float] = None, + ) -> None: r""" Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow - processing larger videos. + processing larger images. + + Args: + tile_sample_min_height (`int`, *optional*): + The minimum height required for a sample to be separated into tiles across the height dimension. + tile_sample_min_width (`int`, *optional*): + The minimum width required for a sample to be separated into tiles across the width dimension. + tile_sample_stride_height (`int`, *optional*): + The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are + no tiling artifacts produced across the height dimension. + tile_sample_stride_width (`int`, *optional*): + The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling + artifacts produced across the width dimension. """ - self.enable_spatial_tiling(use_tiling) - self.enable_temporal_tiling(use_tiling) + self.use_tiling = True + self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height + self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width + self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height + self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width - def disable_tiling(self): + def disable_tiling(self) -> None: r""" Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing decoding in one step. """ - self.disable_spatial_tiling() - self.disable_temporal_tiling() + self.use_tiling = False - def enable_slicing(self): + def enable_slicing(self) -> None: r""" Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. """ self.use_slicing = True - def disable_slicing(self): + def disable_slicing(self) -> None: r""" Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing decoding in one step. """ self.use_slicing = False + def _encode(self, x: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = x.shape + + if self.use_framewise_decoding and num_frames > self.tile_sample_min_tsize: + return self.temporal_tiled_encode(x) + + if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height): + return self.tiled_encode(x) + + x = self.encoder(x) + enc = self.quant_conv(x) + return enc + @apply_forward_hook def encode( self, x: torch.Tensor, return_dict: bool = True ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: - """ - Encode a batch of images/videos into latents. + r""" + Encode a batch of images into latents. Args: - x (`torch.Tensor`): Input batch of images/videos. + x (`torch.Tensor`): Input batch of images. return_dict (`bool`, *optional*, defaults to `True`): Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. Returns: - The latent representations of the encoded images/videos. If `return_dict` is True, a + The latent representations of the encoded videos. If `return_dict` is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. """ - assert len(x.shape) == 5, "The input tensor should have 5 dimensions" - - if self.use_temporal_tiling and x.shape[2] > self.tile_sample_min_tsize: - return self.temporal_tiled_encode(x, return_dict=return_dict) - - if self.use_spatial_tiling and ( - x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size - ): - return self.spatial_tiled_encode(x, return_dict=return_dict) - if self.use_slicing and x.shape[0] > 1: - encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)] + encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] h = torch.cat(encoded_slices) else: - h = self.encoder(x) + h = self._encode(x) - moments = self.quant_conv(h) - posterior = DiagonalGaussianDistribution(moments) + posterior = DiagonalGaussianDistribution(h) if not return_dict: return (posterior,) - return AutoencoderKLOutput(latent_dist=posterior) def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: - assert len(z.shape) == 5, "The input tensor should have 5 dimensions" + batch_size, num_channels, num_frames, height, width = z.shape + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_stride_width // self.spatial_compression_ratio - if self.use_temporal_tiling and z.shape[2] > self.tile_latent_min_tsize: + if self.use_framewise_decoding and num_frames > self.tile_latent_min_tsize: return self.temporal_tiled_decode(z, return_dict=return_dict) - if self.use_spatial_tiling and ( - z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size - ): - return self.spatial_tiled_decode(z, return_dict=return_dict) + if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height): + return self.tiled_decode(z, return_dict=return_dict) z = self.post_quant_conv(z) dec = self.decoder(z) @@ -809,8 +829,8 @@ def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOut @apply_forward_hook def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: - """ - Decode a batch of images/videos. + r""" + Decode a batch of images. Args: z (`torch.Tensor`): Input batch of latent vectors. @@ -821,7 +841,6 @@ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutp [`~models.vae.DecoderOutput`] or `tuple`: If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is returned. - """ if self.use_slicing and z.shape[0] > 1: decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] @@ -858,41 +877,40 @@ def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch. ) return b - def spatial_tiled_encode( - self, x: torch.Tensor, return_dict: bool = True, return_moments: bool = False - ) -> AutoencoderKLOutput: - r"""Encode a batch of images/videos using a tiled encoder. - - When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several - steps. This is useful to keep memory use constant regardless of image/videos size. The end result of tiled - encoding is different from non-tiled encoding because each tile uses a different encoder. To avoid tiling - artifacts, the tiles overlap and are blended together to form a smooth output. You may still see tile-sized - changes in the output, but they should be much less noticeable. + def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput: + r"""Encode a batch of images using a tiled encoder. Args: - x (`torch.Tensor`): Input batch of images/videos. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + x (`torch.Tensor`): Input batch of videos. Returns: - [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`: - If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain - `tuple` is returned. + `torch.Tensor`: + The latent representation of the encoded videos. """ - overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) - blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) - row_limit = self.tile_latent_min_size - blend_extent + batch_size, num_channels, num_frames, height, width = x.shape + latent_height = height // self.spatial_compression_ratio + latent_width = width // self.spatial_compression_ratio - # Split video into tiles and encode them separately. + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + + blend_height = tile_latent_min_height - tile_latent_stride_height + blend_width = tile_latent_min_width - tile_latent_stride_width + + # Split x into overlapping tiles and encode them separately. + # The tiles have an overlap to avoid seams between tiles. rows = [] - for i in range(0, x.shape[-2], overlap_size): + for i in range(0, height, self.tile_sample_stride_height): row = [] - for j in range(0, x.shape[-1], overlap_size): + for j in range(0, width, self.tile_sample_stride_width): tile = x[:, :, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size] tile = self.encoder(tile) tile = self.quant_conv(tile) row.append(tile) rows.append(row) + result_rows = [] for i, row in enumerate(rows): result_row = [] @@ -900,25 +918,18 @@ def spatial_tiled_encode( # blend the above tile and the left tile # to the current tile and add the current tile to the result row if i > 0: - tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + tile = self.blend_v(rows[i - 1][j], tile, blend_height) if j > 0: - tile = self.blend_h(row[j - 1], tile, blend_extent) - result_row.append(tile[:, :, :, :row_limit, :row_limit]) + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width]) result_rows.append(torch.cat(result_row, dim=-1)) - moments = torch.cat(result_rows, dim=-2) - if return_moments: - return moments - - posterior = DiagonalGaussianDistribution(moments) - if not return_dict: - return (posterior,) + enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width] + return enc - return AutoencoderKLOutput(latent_dist=posterior) - - def spatial_tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: r""" - Decode a batch of images/videos using a tiled decoder. + Decode a batch of images using a tiled decoder. Args: z (`torch.Tensor`): Input batch of latent vectors. @@ -930,21 +941,31 @@ def spatial_tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Uni If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is returned. """ - overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor)) - blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor) - row_limit = self.tile_sample_min_size - blend_extent + + batch_size, num_channels, num_frames, height, width = z.shape + sample_height = height * self.spatial_compression_ratio + sample_width = width * self.spatial_compression_ratio + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + + blend_height = self.tile_sample_min_height - self.tile_sample_stride_height + blend_width = self.tile_sample_min_width - self.tile_sample_stride_width # Split z into overlapping tiles and decode them separately. # The tiles have an overlap to avoid seams between tiles. rows = [] - for i in range(0, z.shape[-2], overlap_size): + for i in range(0, height, tile_latent_stride_height): row = [] - for j in range(0, z.shape[-1], overlap_size): - tile = z[:, :, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size] + for j in range(0, width, tile_latent_stride_width): + tile = z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width] tile = self.post_quant_conv(tile) decoded = self.decoder(tile) row.append(decoded) rows.append(row) + result_rows = [] for i, row in enumerate(rows): result_row = [] @@ -952,19 +973,19 @@ def spatial_tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Uni # blend the above tile and the left tile # to the current tile and add the current tile to the result row if i > 0: - tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + tile = self.blend_v(rows[i - 1][j], tile, blend_height) if j > 0: - tile = self.blend_h(row[j - 1], tile, blend_extent) - result_row.append(tile[:, :, :, :row_limit, :row_limit]) + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width]) result_rows.append(torch.cat(result_row, dim=-1)) - dec = torch.cat(result_rows, dim=-2) + dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width] + if not return_dict: return (dec,) - return DecoderOutput(sample=dec) - def temporal_tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> AutoencoderKLOutput: + def temporal_tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput: B, C, T, H, W = x.shape overlap_size = int(self.tile_sample_min_tsize * (1 - self.tile_overlap_factor)) blend_extent = int(self.tile_latent_min_tsize * self.tile_overlap_factor) @@ -974,10 +995,10 @@ def temporal_tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> Au row = [] for i in range(0, T, overlap_size): tile = x[:, :, i : i + self.tile_sample_min_tsize + 1, :, :] - if self.use_spatial_tiling and ( + if self.use_tiling and ( tile.shape[-1] > self.tile_sample_min_size or tile.shape[-2] > self.tile_sample_min_size ): - tile = self.spatial_tiled_encode(tile, return_moments=True) + tile = self.tiled_encode(tile, return_moments=True) else: tile = self.encoder(tile) tile = self.quant_conv(tile) @@ -992,13 +1013,8 @@ def temporal_tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> Au else: result_row.append(tile[:, :, : t_limit + 1, :, :]) - moments = torch.cat(result_row, dim=2) - posterior = DiagonalGaussianDistribution(moments) - - if not return_dict: - return (posterior,) - - return AutoencoderKLOutput(latent_dist=posterior) + enc = torch.cat(result_row, dim=2) + return enc def temporal_tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: # Split z into overlapping tiles and decode them separately. @@ -1011,10 +1027,10 @@ def temporal_tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Un row = [] for i in range(0, T, overlap_size): tile = z[:, :, i : i + self.tile_latent_min_tsize + 1, :, :] - if self.use_spatial_tiling and ( + if self.use_tiling and ( tile.shape[-1] > self.tile_latent_min_size or tile.shape[-2] > self.tile_latent_min_size ): - decoded = self.spatial_tiled_decode(tile, return_dict=True).sample + decoded = self.tiled_decode(tile, return_dict=True).sample else: tile = self.post_quant_conv(tile) decoded = self.decoder(tile) From f143b023cb51af8aff9379322cdc917894139c8b Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 12 Dec 2024 23:22:15 +0100 Subject: [PATCH 33/58] refactor --- .../autoencoder_kl_hunyuan_video.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py index c52f8edede34..b6a2363a7de4 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py @@ -700,13 +700,13 @@ def __init__( self.use_framewise_encoding = True self.use_framewise_decoding = True - # only relevant if vae tiling is enabled - self.tile_sample_min_tsize = sample_tsize - self.tile_latent_min_tsize = sample_tsize // temporal_compression_ratio # The minimal tile height and width for spatial tiling to be used self.tile_sample_min_height = 256 self.tile_sample_min_width = 256 + + # The minimal tile temporal batch size for temporal tiling to be used + self.tile_sample_min_tsize = 64 # The minimal distance between two spatial tiles self.tile_sample_stride_height = 192 @@ -812,8 +812,9 @@ def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOut batch_size, num_channels, num_frames, height, width = z.shape tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio tile_latent_min_width = self.tile_sample_stride_width // self.spatial_compression_ratio + tile_latent_min_num_frames = self.tile_sample_min_tsize // self.temporal_compression_ratio - if self.use_framewise_decoding and num_frames > self.tile_latent_min_tsize: + if self.use_framewise_decoding and num_frames > tile_latent_min_num_frames: return self.temporal_tiled_decode(z, return_dict=return_dict) if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height): @@ -987,9 +988,10 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod def temporal_tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput: B, C, T, H, W = x.shape + tile_latent_min_tsize = self.tile_sample_min_tsize // self.temporal_compression_ratio overlap_size = int(self.tile_sample_min_tsize * (1 - self.tile_overlap_factor)) - blend_extent = int(self.tile_latent_min_tsize * self.tile_overlap_factor) - t_limit = self.tile_latent_min_tsize - blend_extent + blend_extent = int(tile_latent_min_tsize * self.tile_overlap_factor) + t_limit = tile_latent_min_tsize - blend_extent # Split the video into tiles and encode them separately. row = [] @@ -1020,13 +1022,14 @@ def temporal_tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Un # Split z into overlapping tiles and decode them separately. B, C, T, H, W = z.shape - overlap_size = int(self.tile_latent_min_tsize * (1 - self.tile_overlap_factor)) + tile_latent_min_tsize = self.tile_sample_min_tsize // self.temporal_compression_ratio + overlap_size = int(tile_latent_min_tsize * (1 - self.tile_overlap_factor)) blend_extent = int(self.tile_sample_min_tsize * self.tile_overlap_factor) t_limit = self.tile_sample_min_tsize - blend_extent row = [] for i in range(0, T, overlap_size): - tile = z[:, :, i : i + self.tile_latent_min_tsize + 1, :, :] + tile = z[:, :, i : i + tile_latent_min_tsize + 1, :, :] if self.use_tiling and ( tile.shape[-1] > self.tile_latent_min_size or tile.shape[-2] > self.tile_latent_min_size ): From 2a72d20004420047860aa950e15e7c6b32bc0191 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 12 Dec 2024 23:31:51 +0100 Subject: [PATCH 34/58] grad checkpointing --- .../autoencoder_kl_hunyuan_video.py | 180 +++++++++++++----- 1 file changed, 128 insertions(+), 52 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py index b6a2363a7de4..b79b4a66f7ac 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py @@ -12,12 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union import numpy as np import torch import torch.nn as nn import torch.nn.functional as F +import torch.utils.checkpoint from ...configuration_utils import ConfigMixin, register_to_config from ...utils import is_torch_version, logging @@ -240,18 +241,51 @@ def __init__( self.resnets = nn.ModuleList(resnets) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - hidden_states = self.resnets[0](hidden_states) - for attn, resnet in zip(self.attentions, self.resnets[1:]): - if attn is not None: - batch_size, num_channels, num_frames, height, width = hidden_states.shape - hidden_states = hidden_states.permute(0, 2, 3, 4, 1).flatten(1, 3) - attention_mask = prepare_causal_attention_mask( - num_frames, height * width, hidden_states.dtype, hidden_states.device, batch_size=batch_size + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.resnets[0]), hidden_states, **ckpt_kwargs + ) + + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if attn is not None: + batch_size, num_channels, num_frames, height, width = hidden_states.shape + hidden_states = hidden_states.permute(0, 2, 3, 4, 1).flatten(1, 3) + attention_mask = prepare_causal_attention_mask( + num_frames, height * width, hidden_states.dtype, hidden_states.device, batch_size=batch_size + ) + hidden_states = attn(hidden_states, attention_mask=attention_mask) + hidden_states = hidden_states.unflatten(1, (num_frames, height, width)).permute(0, 4, 1, 2, 3) + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, **ckpt_kwargs ) - hidden_states = attn(hidden_states, attention_mask=attention_mask) - hidden_states = hidden_states.unflatten(1, (num_frames, height, width)).permute(0, 4, 1, 2, 3) - hidden_states = resnet(hidden_states) + else: + hidden_states = self.resnets[0](hidden_states) + + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if attn is not None: + batch_size, num_channels, num_frames, height, width = hidden_states.shape + hidden_states = hidden_states.permute(0, 2, 3, 4, 1).flatten(1, 3) + attention_mask = prepare_causal_attention_mask( + num_frames, height * width, hidden_states.dtype, hidden_states.device, batch_size=batch_size + ) + hidden_states = attn(hidden_states, attention_mask=attention_mask) + hidden_states = hidden_states.unflatten(1, (num_frames, height, width)).permute(0, 4, 1, 2, 3) + + hidden_states = resnet(hidden_states) return hidden_states @@ -303,8 +337,26 @@ def __init__( self.downsamplers = None def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - for resnet in self.resnets: - hidden_states = resnet(hidden_states) + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + + for resnet in self.resnets: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, **ckpt_kwargs + ) + else: + for resnet in self.resnets: + hidden_states = resnet(hidden_states) if self.downsamplers is not None: for downsampler in self.downsamplers: @@ -359,8 +411,27 @@ def __init__( self.upsamplers = None def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - for resnet in self.resnets: - hidden_states = resnet(hidden_states) + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + + for resnet in self.resnets: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, **ckpt_kwargs + ) + + else: + for resnet in self.resnets: + hidden_states = resnet(hidden_states) if self.upsamplers is not None: for upsampler in self.upsamplers: @@ -401,6 +472,9 @@ def __init__( output_channel = block_out_channels[0] for i, down_block_type in enumerate(down_block_types): + if down_block_type != "HunyuanVideoDownBlock3D": + raise ValueError(f"Unsupported down_block_type: {down_block_type}") + input_channel = output_channel output_channel = block_out_channels[i] is_final_block = i == len(block_out_channels) - 1 @@ -454,27 +528,35 @@ def __init__( self.gradient_checkpointing = False def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - # 1. Input layer hidden_states = self.conv_in(hidden_states) - use_reentrant = is_torch_version("<=", "1.11.0") + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - def create_block_forward(block): - if torch.is_grad_enabled() and self.gradient_checkpointing: - return lambda *inputs: torch.utils.checkpoint.checkpoint( - lambda *x: block(*x), *inputs, use_reentrant=use_reentrant + for down_block in self.down_blocks: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(down_block), hidden_states, **ckpt_kwargs ) - else: - return block - # 2. Down blocks - for down_block in self.down_blocks: - hidden_states = create_block_forward(down_block)(hidden_states) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.mid_block), hidden_states, **ckpt_kwargs + ) + else: + for down_block in self.down_blocks: + hidden_states = down_block(hidden_states) - # 3. Mid block - hidden_states = self.mid_block(hidden_states) + hidden_states = self.mid_block(hidden_states) - # 4. Output layers hidden_states = self.conv_norm_out(hidden_states) hidden_states = self.conv_act(hidden_states) hidden_states = self.conv_out(hidden_states) @@ -501,7 +583,6 @@ def __init__( layers_per_block: int = 2, norm_num_groups: int = 32, act_fn: str = "silu", - norm_type: str = "group", mid_block_add_attention=True, time_compression_ratio: int = 4, spatial_compression_ratio: int = 8, @@ -527,6 +608,9 @@ def __init__( reversed_block_out_channels = list(reversed(block_out_channels)) output_channel = reversed_block_out_channels[0] for i, up_block_type in enumerate(up_block_types): + if up_block_type != "HunyuanVideoUpBlock3D": + raise ValueError(f"Unsupported up_block_type: {up_block_type}") + prev_output_channel = output_channel output_channel = reversed_block_out_channels[i] is_final_block = i == len(block_out_channels) - 1 @@ -569,36 +653,30 @@ def __init__( def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.conv_in(hidden_states) - upscale_dtype = next(iter(self.up_blocks.parameters())).dtype - if self.training and self.gradient_checkpointing: + if torch.is_grad_enabled() and self.gradient_checkpointing: - def create_custom_forward(module): + def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): - return module(*inputs) + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) return custom_forward - # up + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.mid_block), hidden_states, **ckpt_kwargs + ) + for up_block in self.up_blocks: hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(up_block), - hidden_states, - use_reentrant=False, + create_custom_forward(up_block), hidden_states, **ckpt_kwargs ) - else: - # middle - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), hidden_states) - hidden_states = hidden_states.to(upscale_dtype) - - # up - for up_block in self.up_blocks: - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), hidden_states) else: - # middle hidden_states = self.mid_block(hidden_states) - hidden_states = hidden_states.to(upscale_dtype) - # up for up_block in self.up_blocks: hidden_states = up_block(hidden_states) @@ -643,7 +721,6 @@ def __init__( layers_per_block: int = 2, act_fn: str = "silu", norm_num_groups: int = 32, - sample_tsize: int = 64, scaling_factor: float = 0.476986, spatial_compression_ratio: int = 8, temporal_compression_ratio: int = 4, @@ -700,11 +777,10 @@ def __init__( self.use_framewise_encoding = True self.use_framewise_decoding = True - # The minimal tile height and width for spatial tiling to be used self.tile_sample_min_height = 256 self.tile_sample_min_width = 256 - + # The minimal tile temporal batch size for temporal tiling to be used self.tile_sample_min_tsize = 64 From d0c61e074d119b71925dd3e54808cda4671bfa21 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 12 Dec 2024 23:49:23 +0100 Subject: [PATCH 35/58] autoencoder test --- .../autoencoder_kl_hunyuan_video.py | 20 ++- .../test_models_autoencoder_hunyuan_video.py | 159 ++++++++++++++++++ 2 files changed, 171 insertions(+), 8 deletions(-) create mode 100644 tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py index b79b4a66f7ac..3914d9740790 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py @@ -240,6 +240,8 @@ def __init__( self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) + self.gradient_checkpointing = False + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if torch.is_grad_enabled() and self.gradient_checkpointing: @@ -336,6 +338,8 @@ def __init__( else: self.downsamplers = None + self.gradient_checkpointing = False + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if torch.is_grad_enabled() and self.gradient_checkpointing: @@ -410,6 +414,8 @@ def __init__( else: self.upsamplers = None + self.gradient_checkpointing = False + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if torch.is_grad_enabled() and self.gradient_checkpointing: @@ -440,7 +446,7 @@ def custom_forward(*inputs): return hidden_states -class EncoderCausal3D(nn.Module): +class HunyuanVideoEncoder3D(nn.Module): r""" Causal encoder for 3D video-like data introduced in [Hunyuan Video](https://huggingface.co/papers/2412.03603). """ @@ -564,7 +570,7 @@ def custom_forward(*inputs): return hidden_states -class DecoderCausal3D(nn.Module): +class HunyuanVideoDecoder3D(nn.Module): r""" Causal decoder for 3D video-like data introduced in [Hunyuan Video](https://huggingface.co/papers/2412.03603). """ @@ -730,7 +736,7 @@ def __init__( self.time_compression_ratio = temporal_compression_ratio - self.encoder = EncoderCausal3D( + self.encoder = HunyuanVideoEncoder3D( in_channels=in_channels, out_channels=latent_channels, down_block_types=down_block_types, @@ -744,7 +750,7 @@ def __init__( spatial_compression_ratio=spatial_compression_ratio, ) - self.decoder = DecoderCausal3D( + self.decoder = HunyuanVideoDecoder3D( in_channels=latent_channels, out_channels=out_channels, up_block_types=up_block_types, @@ -789,7 +795,7 @@ def __init__( self.tile_sample_stride_width = 192 def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, (EncoderCausal3D, DecoderCausal3D)): + if isinstance(module, (HunyuanVideoEncoder3D, HunyuanVideoDecoder3D)): module.gradient_checkpointing = value def enable_tiling( @@ -1151,7 +1157,5 @@ def forward( z = posterior.sample(generator=generator) else: z = posterior.mode() - dec = self.decode(z) - if not return_dict: - return (dec,) + dec = self.decode(z, return_dict=return_dict) return dec diff --git a/tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py b/tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py new file mode 100644 index 000000000000..826ac30d5f2f --- /dev/null +++ b/tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py @@ -0,0 +1,159 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from diffusers import AutoencoderKLHunyuanVideo +from diffusers.utils.testing_utils import ( + enable_full_determinism, + floats_tensor, + torch_device, +) + +from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin + + +enable_full_determinism() + + +class AutoencoderKLHunyuanVideoTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): + model_class = AutoencoderKLHunyuanVideo + main_input_name = "sample" + base_precision = 1e-2 + + def get_autoencoder_kl_hunyuan_video_config(self): + return { + "in_channels": 3, + "out_channels": 3, + "latent_channels": 4, + "down_block_types": ( + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + ), + "up_block_types": ( + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", + ), + "block_out_channels": (8, 8, 8, 8), + "layers_per_block": 1, + "act_fn": "silu", + "norm_num_groups": 4, + "scaling_factor": 0.476986, + "spatial_compression_ratio": 8, + "temporal_compression_ratio": 4, + "mid_block_add_attention": True, + } + + @property + def dummy_input(self): + batch_size = 2 + num_frames = 9 + num_channels = 3 + sizes = (16, 16) + + image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device) + + return {"sample": image} + + @property + def input_shape(self): + return (3, 9, 16, 16) + + @property + def output_shape(self): + return (3, 9, 16, 16) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = self.get_autoencoder_kl_hunyuan_video_config() + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_enable_disable_tiling(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + torch.manual_seed(0) + model = self.model_class(**init_dict).to(torch_device) + + inputs_dict.update({"return_dict": False}) + + torch.manual_seed(0) + output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0] + + torch.manual_seed(0) + model.enable_tiling() + output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0] + + self.assertLess( + (output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(), + 0.5, + "VAE tiling should not affect the inference results", + ) + + torch.manual_seed(0) + model.disable_tiling() + output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0] + + self.assertEqual( + output_without_tiling.detach().cpu().numpy().all(), + output_without_tiling_2.detach().cpu().numpy().all(), + "Without tiling outputs should match with the outputs when tiling is manually disabled.", + ) + + def test_enable_disable_slicing(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + torch.manual_seed(0) + model = self.model_class(**init_dict).to(torch_device) + + inputs_dict.update({"return_dict": False}) + + torch.manual_seed(0) + output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0] + + torch.manual_seed(0) + model.enable_slicing() + output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0] + + self.assertLess( + (output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(), + 0.5, + "VAE slicing should not affect the inference results", + ) + + torch.manual_seed(0) + model.disable_slicing() + output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0] + + self.assertEqual( + output_without_slicing.detach().cpu().numpy().all(), + output_without_slicing_2.detach().cpu().numpy().all(), + "Without slicing outputs should match with the outputs when slicing is manually disabled.", + ) + + def test_gradient_checkpointing_is_applied(self): + expected_set = { + "HunyuanVideoDecoder3D", + "HunyuanVideoDownBlock3D", + "HunyuanVideoEncoder3D", + "HunyuanVideoMidBlock3D", + "HunyuanVideoUpBlock3D", + } + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + + @unittest.skip("Unsupported test.") + def test_outputs_equivalence(self): + pass From 845f30389a37a0d4a1b4cbbcb79a4cec0ece3070 Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 13 Dec 2024 10:32:41 +0100 Subject: [PATCH 36/58] fix scaling factor --- src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py index 4dbe14afb2dd..61bee59513db 100644 --- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py @@ -803,6 +803,7 @@ def __call__( latents = latents.to(vae_dtype) if not output_type == "latent": + latents = latents / self.vae.config.scaling_factor image = self.vae.decode(latents, return_dict=False)[0] torch.save(image, "diffusers_latents_decoded.pt") From 556c6e95cd3682caf7618432e803138166cbffaf Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 13 Dec 2024 12:31:20 +0100 Subject: [PATCH 37/58] refactor clip --- .../hunyuan_video/pipeline_hunyuan_video.py | 80 +++++++++++++++---- 1 file changed, 63 insertions(+), 17 deletions(-) diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py index 61bee59513db..836f0aed5510 100644 --- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py @@ -18,6 +18,7 @@ import numpy as np import torch +from transformers import CLIPTextModel, CLIPTokenizer from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import VaeImageProcessor @@ -140,7 +141,8 @@ def __init__( text_encoder: TextEncoder, transformer: HunyuanVideoTransformer3DModel, scheduler: KarrasDiffusionSchedulers, - text_encoder_2: Optional[TextEncoder] = None, + text_encoder_2: Optional[CLIPTextModel] = None, + tokenizer_2: Optional[CLIPTokenizer] = None, ): super().__init__() @@ -150,6 +152,7 @@ def __init__( transformer=transformer, scheduler=scheduler, text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) @@ -246,6 +249,45 @@ def encode_prompt( attention_mask, ) + def _get_clip_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_videos_per_prompt: int = 1, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + max_sequence_length: int = 77, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder_2.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False).pooler_output + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, -1) + + return prompt_embeds + def check_inputs( self, prompt, @@ -577,7 +619,6 @@ def __call__( prompt_attention_mask: Optional[torch.Tensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, - clip_skip: Optional[int] = None, callback_on_step_end: Optional[ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] ] = None, @@ -674,7 +715,6 @@ def __call__( ) self._guidance_scale = guidance_scale - self._clip_skip = clip_skip self._interrupt = False # 2. Define call parameters @@ -698,27 +738,33 @@ def __call__( num_videos_per_prompt, prompt_embeds=prompt_embeds, attention_mask=prompt_attention_mask, - clip_skip=self.clip_skip, data_type=data_type, ) + # if self.text_encoder_2 is not None: + # ( + # prompt_embeds_2, + # prompt_mask_2, + # ) = self.encode_prompt( + # prompt, + # device, + # num_videos_per_prompt, + # prompt_embeds=prompt_embeds_2, + # attention_mask=None, + # clip_skip=self.clip_skip, + # text_encoder=self.text_encoder_2, + # data_type=data_type, + # ) + # else: + # prompt_embeds_2 = None + # prompt_mask_2 = None + if self.text_encoder_2 is not None: - ( - prompt_embeds_2, - prompt_mask_2, - ) = self.encode_prompt( + prompt_embeds_2 = self._get_clip_prompt_embeds( prompt, - device, num_videos_per_prompt, - prompt_embeds=prompt_embeds_2, - attention_mask=None, - clip_skip=self.clip_skip, - text_encoder=self.text_encoder_2, - data_type=data_type, + device=device, ) - else: - prompt_embeds_2 = None - prompt_mask_2 = None # 4. Prepare timesteps timesteps, num_inference_steps = retrieve_timesteps( From 4c6cf2d89f9c49b0cd23f39309e8b8c3d573eac7 Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 13 Dec 2024 14:53:41 +0100 Subject: [PATCH 38/58] refactor llama text encoding --- .../hunyuan_video/pipeline_hunyuan_video.py | 285 ++++++++-------- .../pipelines/hunyuan_video/text_encoder.py | 322 ------------------ 2 files changed, 147 insertions(+), 460 deletions(-) delete mode 100644 src/diffusers/pipelines/hunyuan_video/text_encoder.py diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py index 836f0aed5510..47309d06c5a0 100644 --- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py @@ -14,11 +14,11 @@ import inspect from dataclasses import dataclass -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import torch -from transformers import CLIPTextModel, CLIPTokenizer +from transformers import CLIPTextModel, CLIPTokenizer, LlamaForCausalLM, LlamaTokenizerFast from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import VaeImageProcessor @@ -31,7 +31,6 @@ ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline -from .text_encoder import TextEncoder logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -39,6 +38,20 @@ EXAMPLE_DOC_STRING = """""" +DEFAULT_PROMPT_TEMPLATE = { + "template": ( + "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: " + "1. The main content and theme of the video." + "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects." + "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects." + "4. background environment, light, style and atmosphere." + "5. camera angles, movements, and transitions used in the video:<|eot_id|>" + "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>" + ), + "crop_start": 95, +} + + def retrieve_timesteps( scheduler, num_inference_steps: Optional[int] = None, @@ -138,7 +151,8 @@ class HunyuanVideoPipeline(DiffusionPipeline): def __init__( self, vae: AutoencoderKLHunyuanVideo, - text_encoder: TextEncoder, + text_encoder: LlamaForCausalLM, + tokenizer: LlamaTokenizerFast, transformer: HunyuanVideoTransformer3DModel, scheduler: KarrasDiffusionSchedulers, text_encoder_2: Optional[CLIPTextModel] = None, @@ -149,105 +163,86 @@ def __init__( self.register_modules( vae=vae, text_encoder=text_encoder, + tokenizer=tokenizer, transformer=transformer, scheduler=scheduler, text_encoder_2=text_encoder_2, tokenizer_2=tokenizer_2, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) - def encode_prompt( + self.vae_scale_factor_temporal = ( + self.vae.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4 + ) + self.vae_scale_factor_spatial = ( + self.vae.spatial_compression_ratio if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + def _get_llama_prompt_embeds( self, - prompt, - device, - num_videos_per_prompt, - prompt_embeds: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - clip_skip: Optional[int] = None, - text_encoder: Optional[TextEncoder] = None, - data_type: Optional[str] = "image", - ): - r""" - Encodes the prompt into text encoder hidden states. + prompt: Union[str, List[str]], + prompt_template: Dict[str, Any], + num_videos_per_prompt: int = 1, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + max_sequence_length: int = 256, + num_hidden_layers_to_skip: int = 2, + ) -> Tuple[torch.Tensor, torch.Tensor]: + device = device or self._execution_device + dtype = dtype or self.text_encoder_2.dtype - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - device: (`torch.device`): - torch device - num_videos_per_prompt (`int`): - number of videos that should be generated per prompt - prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - attention_mask (`torch.Tensor`, *optional*): - clip_skip (`int`, *optional*): - Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that - the output of the pre-final layer will be used for computing the prompt embeddings. - text_encoder (TextEncoder, *optional*): - data_type (`str`, *optional*): - """ - if text_encoder is None: - text_encoder = self.text_encoder + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) - if prompt_embeds is None: - text_inputs = text_encoder.text2tokens(prompt, data_type=data_type) + prompt = [prompt_template["template"].format(p) for p in prompt] + + crop_start = prompt_template.get("crop_start", None) + if crop_start is None: + prompt_template_input = self.tokenizer( + prompt_template["template"], + padding="max_length", + return_tensors="pt", + return_length=False, + return_overflowing_tokens=False, + return_attention_mask=False, + ) + crop_start = prompt_template_input["input_ids"].shape[-1] + # Remove <|eot_id|> token and placeholder {} + crop_start -= 2 - if clip_skip is None: - prompt_outputs = text_encoder.encode(text_inputs, data_type=data_type, device=device) - # TODO(aryan): Don't know why it doesn't work without this - torch.cuda.synchronize() + max_sequence_length += crop_start + text_inputs = self.tokenizer( + prompt, + max_length=max_sequence_length, + padding="max_length", + truncation=True, + return_tensors="pt", + return_length=False, + return_overflowing_tokens=False, + return_attention_mask=True, + ) + text_input_ids = text_inputs.input_ids.to(device) + prompt_attention_mask = text_inputs.attention_mask.to(device) - prompt_embeds = prompt_outputs.hidden_state - else: - prompt_outputs = text_encoder.encode( - text_inputs, - output_hidden_states=True, - data_type=data_type, - device=device, - ) - # Access the `hidden_states` first, that contains a tuple of - # all the hidden states from the encoder layers. Then index into - # the tuple to access the hidden states from the desired layer. - prompt_embeds = prompt_outputs.hidden_states_list[-(clip_skip + 1)] - # We also need to apply the final LayerNorm here to not mess with the - # representations. The `last_hidden_states` that we typically use for - # obtaining the final prompt representations passes through the LayerNorm - # layer. - prompt_embeds = text_encoder.model.text_model.final_layer_norm(prompt_embeds) - - attention_mask = prompt_outputs.attention_mask - if attention_mask is not None: - attention_mask = attention_mask.to(device) - bs_embed, seq_len = attention_mask.shape - attention_mask = attention_mask.repeat(1, num_videos_per_prompt) - attention_mask = attention_mask.view(bs_embed * num_videos_per_prompt, seq_len) - - if text_encoder is not None: - prompt_embeds_dtype = text_encoder.dtype - elif self.transformer is not None: - prompt_embeds_dtype = self.transformer.dtype - else: - prompt_embeds_dtype = prompt_embeds.dtype + prompt_embeds = self.text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_attention_mask, + output_hidden_states=True, + ).hidden_states[-(num_hidden_layers_to_skip + 1)] + prompt_embeds = prompt_embeds.to(dtype) - prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + if crop_start is not None and crop_start > 0: + prompt_embeds = prompt_embeds[:, crop_start:] + prompt_attention_mask = prompt_attention_mask[:, crop_start:] - if prompt_embeds.ndim == 2: - bs_embed, _ = prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt) - prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, -1) - else: - bs_embed, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1) + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.repeat(1, num_videos_per_prompt) + prompt_attention_mask = prompt_attention_mask.view(batch_size * num_videos_per_prompt, seq_len) - return ( - prompt_embeds, - attention_mask, - ) + return prompt_embeds, prompt_attention_mask def _get_clip_prompt_embeds( self, @@ -256,7 +251,7 @@ def _get_clip_prompt_embeds( device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, max_sequence_length: int = 77, - ): + ) -> torch.Tensor: device = device or self._execution_device dtype = dtype or self.text_encoder_2.dtype @@ -288,6 +283,35 @@ def _get_clip_prompt_embeds( return prompt_embeds + def encode_prompt( + self, + prompt: Union[str, List[str]], + prompt_2: Union[str, List[str]] = None, + prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE, + num_videos_per_prompt: int = 1, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + max_sequence_length: int = 256, + ): + prompt_embeds, prompt_attention_mask = self._get_llama_prompt_embeds( + prompt, + prompt_template, + num_videos_per_prompt, + device=device, + max_sequence_length=max_sequence_length, + ) + + if self.text_encoder_2 is not None: + prompt_2 = prompt_2 or prompt + prompt_embeds_2 = self._get_clip_prompt_embeds( + prompt, + num_videos_per_prompt, + device=device, + max_sequence_length=77, + ) + + return prompt_embeds, prompt_embeds_2, prompt_attention_mask + def check_inputs( self, prompt, @@ -297,6 +321,7 @@ def check_inputs( video_length, prompt_embeds=None, callback_on_step_end_tensor_inputs=None, + prompt_template=None, ): if height % 16 != 0 or width % 16 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") @@ -327,6 +352,14 @@ def check_inputs( elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + if prompt_template is not None: + if not isinstance(prompt_template, dict): + raise ValueError(f"`prompt_template` has to be of type `dict` but is {type(prompt_template)}") + if "template" not in prompt_template: + raise ValueError( + f"`prompt_template` has to contain a key `template` but only found {prompt_template.keys()}" + ) + def prepare_latents( self, batch_size, @@ -343,8 +376,8 @@ def prepare_latents( batch_size, num_channels_latents, video_length, - int(height) // self.vae_scale_factor, - int(width) // self.vae_scale_factor, + int(height) // self.vae_scale_factor_spatial, + int(width) // self.vae_scale_factor_spatial, ) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( @@ -605,7 +638,7 @@ def __call__( prompt_2: Union[str, List[str]] = None, height: int = 720, width: int = 1280, - video_length: int = 129, + num_frames: int = 129, data_type: str = "video", num_inference_steps: int = 50, sigmas: List[float] = None, @@ -623,6 +656,8 @@ def __call__( Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], + prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE, + max_sequence_length: int = 256, ): r""" The call function to the pipeline for generation. @@ -709,14 +744,17 @@ def __call__( prompt_2, height, width, - video_length, + num_frames, prompt_embeds, callback_on_step_end_tensor_inputs, + prompt_template, ) self._guidance_scale = guidance_scale self._interrupt = False + device = self._execution_device + # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 @@ -725,46 +763,23 @@ def __call__( else: batch_size = prompt_embeds.shape[0] - # TODO(aryan): No idea why it won't run without this - device = torch.device(self._execution_device) - # 3. Encode input prompt - ( - prompt_embeds, - prompt_mask, - ) = self.encode_prompt( + prompt_embeds, prompt_embeds_2, prompt_attention_mask = self.encode_prompt( prompt, - device, + prompt_2, + prompt_template, num_videos_per_prompt, - prompt_embeds=prompt_embeds, - attention_mask=prompt_attention_mask, - data_type=data_type, + device, + max_sequence_length, ) - # if self.text_encoder_2 is not None: - # ( - # prompt_embeds_2, - # prompt_mask_2, - # ) = self.encode_prompt( - # prompt, - # device, - # num_videos_per_prompt, - # prompt_embeds=prompt_embeds_2, - # attention_mask=None, - # clip_skip=self.clip_skip, - # text_encoder=self.text_encoder_2, - # data_type=data_type, - # ) - # else: - # prompt_embeds_2 = None - # prompt_mask_2 = None + target_dtype = torch.bfloat16 # Note(aryan): This has been hardcoded for now from the original repo + vae_dtype = torch.float16 # Note(aryan): This has been hardcoded for now from the original repo - if self.text_encoder_2 is not None: - prompt_embeds_2 = self._get_clip_prompt_embeds( - prompt, - num_videos_per_prompt, - device=device, - ) + prompt_embeds = prompt_embeds.to(target_dtype) + prompt_attention_mask = prompt_attention_mask.to(target_dtype) + if prompt_embeds_2 is not None: + prompt_embeds_2 = prompt_embeds_2.to(target_dtype) # 4. Prepare timesteps timesteps, num_inference_steps = retrieve_timesteps( @@ -775,31 +790,25 @@ def __call__( ) # 5. Prepare latent variables - target_dtype = torch.bfloat16 # Note(aryan): This has been hardcoded for now from the original repo num_channels_latents = self.transformer.config.in_channels + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 latents = self.prepare_latents( batch_size * num_videos_per_prompt, num_channels_latents, height, width, - (video_length - 1) // 4 + 1, + num_latent_frames, target_dtype, device, generator, latents, ) - prompt_embeds = prompt_embeds.to(target_dtype) - prompt_mask = prompt_mask.to(target_dtype) - if prompt_embeds_2 is not None: - prompt_embeds_2 = prompt_embeds_2.to(target_dtype) - vae_dtype = torch.float16 # Note(aryan): This has been hardcoded for now from the original repo - # 7. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(timesteps) - image_rotary_emb = self.get_rotary_pos_embed(video_length, height, width) + image_rotary_emb = self.get_rotary_pos_embed(num_frames, height, width) # if is_progress_bar: with self.progress_bar(total=num_inference_steps) as progress_bar: @@ -821,7 +830,7 @@ def __call__( hidden_states=latents, timestep=timestep, encoder_hidden_states=prompt_embeds, - encoder_attention_mask=prompt_mask, + encoder_attention_mask=prompt_attention_mask, encoder_hidden_states_2=prompt_embeds_2, freqs_cos=image_rotary_emb[0], freqs_sin=image_rotary_emb[1], diff --git a/src/diffusers/pipelines/hunyuan_video/text_encoder.py b/src/diffusers/pipelines/hunyuan_video/text_encoder.py deleted file mode 100644 index 9acfcab85767..000000000000 --- a/src/diffusers/pipelines/hunyuan_video/text_encoder.py +++ /dev/null @@ -1,322 +0,0 @@ -from dataclasses import dataclass -from typing import Optional, Tuple - -import torch -import torch.nn as nn -from transformers import AutoModel, AutoTokenizer, CLIPTextModel, CLIPTokenizer -from transformers.utils import ModelOutput - - -PRECISION_TO_TYPE = { - "fp32": torch.float32, - "fp16": torch.float16, - "bf16": torch.bfloat16, -} - - -def use_default(value, default): - return value if value is not None else default - - -def load_text_encoder( - text_encoder_type, - text_encoder_precision=None, - text_encoder_path=None, - device=None, -): - if text_encoder_path is None: - raise ValueError("text_encoder_path must be provided.") - - if text_encoder_type == "clipL": - text_encoder = CLIPTextModel.from_pretrained(text_encoder_path) - text_encoder.final_layer_norm = text_encoder.text_model.final_layer_norm - elif text_encoder_type == "llm": - text_encoder = AutoModel.from_pretrained(text_encoder_path, low_cpu_mem_usage=True) - text_encoder.final_layer_norm = text_encoder.norm - else: - raise ValueError(f"Unsupported text encoder type: {text_encoder_type}") - # from_pretrained will ensure that the model is in eval mode. - - if text_encoder_precision is not None: - text_encoder = text_encoder.to(dtype=PRECISION_TO_TYPE[text_encoder_precision]) - - text_encoder.requires_grad_(False) - - if device is not None: - text_encoder = text_encoder.to(device) - - return text_encoder, text_encoder_path - - -def load_tokenizer(tokenizer_type, tokenizer_path=None, padding_side="right"): - if tokenizer_path is None: - raise ValueError("tokenizer_path must be provided.") - - if tokenizer_type == "clipL": - tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path, max_length=77) - elif tokenizer_type == "llm": - tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, padding_side=padding_side) - else: - raise ValueError(f"Unsupported tokenizer type: {tokenizer_type}") - - return tokenizer, tokenizer_path - - -@dataclass -class TextEncoderModelOutput(ModelOutput): - """ - Base class for model's outputs that also contains a pooling of the last hidden states. - - Args: - hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: - hidden_states_list (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of - the model at the output of each layer plus the optional initial embedding outputs. - text_outputs (`list`, *optional*, returned when `return_texts=True` is passed): - List of decoded texts. - """ - - hidden_state: torch.FloatTensor = None - attention_mask: Optional[torch.LongTensor] = None - hidden_states_list: Optional[Tuple[torch.FloatTensor, ...]] = None - text_outputs: Optional[list] = None - - -class TextEncoder(nn.Module): - def __init__( - self, - text_encoder_type: str, - max_length: int, - text_encoder_precision: Optional[str] = None, - text_encoder_path: Optional[str] = None, - tokenizer_type: Optional[str] = None, - tokenizer_path: Optional[str] = None, - output_key: Optional[str] = None, - use_attention_mask: bool = True, - input_max_length: Optional[int] = None, - prompt_template: Optional[dict] = None, - prompt_template_video: Optional[dict] = None, - hidden_state_skip_layer: Optional[int] = None, - apply_final_norm: bool = False, - reproduce: bool = False, - ): - super().__init__() - self.text_encoder_type = text_encoder_type - self.max_length = max_length - self.precision = text_encoder_precision - self.model_path = text_encoder_path - self.tokenizer_type = tokenizer_type if tokenizer_type is not None else text_encoder_type - self.tokenizer_path = tokenizer_path if tokenizer_path is not None else text_encoder_path - self.use_attention_mask = use_attention_mask - if prompt_template_video is not None: - assert use_attention_mask is True, "Attention mask is True required when training videos." - self.input_max_length = input_max_length if input_max_length is not None else max_length - self.prompt_template = prompt_template - self.prompt_template_video = prompt_template_video - self.hidden_state_skip_layer = hidden_state_skip_layer - self.apply_final_norm = apply_final_norm - self.reproduce = reproduce - - self.use_template = self.prompt_template is not None - if self.use_template: - assert ( - isinstance(self.prompt_template, dict) and "template" in self.prompt_template - ), f"`prompt_template` must be a dictionary with a key 'template', got {self.prompt_template}" - assert "{}" in str(self.prompt_template["template"]), ( - "`prompt_template['template']` must contain a placeholder `{}` for the input text, " - f"got {self.prompt_template['template']}" - ) - - self.use_video_template = self.prompt_template_video is not None - if self.use_video_template: - if self.prompt_template_video is not None: - assert ( - isinstance(self.prompt_template_video, dict) and "template" in self.prompt_template_video - ), f"`prompt_template_video` must be a dictionary with a key 'template', got {self.prompt_template_video}" - assert "{}" in str(self.prompt_template_video["template"]), ( - "`prompt_template_video['template']` must contain a placeholder `{}` for the input text, " - f"got {self.prompt_template_video['template']}" - ) - - if "t5" in text_encoder_type: - self.output_key = output_key or "last_hidden_state" - elif "clip" in text_encoder_type: - self.output_key = output_key or "pooler_output" - elif "llm" in text_encoder_type or "glm" in text_encoder_type: - self.output_key = output_key or "last_hidden_state" - else: - raise ValueError(f"Unsupported text encoder type: {text_encoder_type}") - - self.model, self.model_path = load_text_encoder( - text_encoder_type=self.text_encoder_type, - text_encoder_precision=self.precision, - text_encoder_path=self.model_path, - device="cuda", - ) - self.dtype = self.model.dtype - self.device = "cuda" - - self.tokenizer, self.tokenizer_path = load_tokenizer( - tokenizer_type=self.tokenizer_type, - tokenizer_path=self.tokenizer_path, - padding_side="right", - ) - - def __repr__(self): - return f"{self.text_encoder_type} ({self.precision} - {self.model_path})" - - @staticmethod - def apply_text_to_template(text, template, prevent_empty_text=True): - """ - Apply text to template. - - Args: - text (str): Input text. - template (str or list): Template string or list of chat conversation. - prevent_empty_text (bool): If Ture, we will prevent the user text from being empty - by adding a space. Defaults to True. - """ - if isinstance(template, str): - # Will send string to tokenizer. Used for llm - return template.format(text) - else: - raise TypeError(f"Unsupported template type: {type(template)}") - - def text2tokens(self, text, data_type="image"): - """ - Tokenize the input text. - - Args: - text (str or list): Input text. - """ - tokenize_input_type = "str" - if self.use_template: - if data_type == "image": - prompt_template = self.prompt_template["template"] - elif data_type == "video": - prompt_template = self.prompt_template_video["template"] - else: - raise ValueError(f"Unsupported data type: {data_type}") - if isinstance(text, (list, tuple)): - text = [self.apply_text_to_template(one_text, prompt_template) for one_text in text] - if isinstance(text[0], list): - tokenize_input_type = "list" - elif isinstance(text, str): - text = self.apply_text_to_template(text, prompt_template) - if isinstance(text, list): - tokenize_input_type = "list" - else: - raise TypeError(f"Unsupported text type: {type(text)}") - - kwargs = { - "truncation": True, - "max_length": self.max_length, - "padding": "max_length", - "return_tensors": "pt", - } - if tokenize_input_type == "str": - return self.tokenizer( - text, - return_length=False, - return_overflowing_tokens=False, - return_attention_mask=True, - **kwargs, - ) - elif tokenize_input_type == "list": - return self.tokenizer.apply_chat_template( - text, - add_generation_prompt=True, - tokenize=True, - return_dict=True, - **kwargs, - ) - else: - raise ValueError(f"Unsupported tokenize_input_type: {tokenize_input_type}") - - def encode( - self, - batch_encoding, - use_attention_mask=None, - output_hidden_states=False, - do_sample=None, - hidden_state_skip_layer=None, - return_texts=False, - data_type="image", - device=None, - ): - """ - Args: - batch_encoding (dict): Batch encoding from tokenizer. - use_attention_mask (bool): Whether to use attention mask. If None, use self.use_attention_mask. - Defaults to None. - output_hidden_states (bool): Whether to output hidden states. If False, return the value of - self.output_key. If True, return the entire output. If set self.hidden_state_skip_layer, - output_hidden_states will be set True. Defaults to False. - do_sample (bool): Whether to sample from the model. Used for Decoder-Only LLMs. Defaults to None. - When self.produce is False, do_sample is set to True by default. - hidden_state_skip_layer (int): Number of hidden states to hidden_state_skip_layer. 0 means the last layer. - If None, self.output_key will be used. Defaults to None. - return_texts (bool): Whether to return the decoded texts. Defaults to False. - """ - device = self.model.device if device is None else device - use_attention_mask = use_default(use_attention_mask, self.use_attention_mask) - hidden_state_skip_layer = use_default(hidden_state_skip_layer, self.hidden_state_skip_layer) - do_sample = use_default(do_sample, not self.reproduce) - attention_mask = batch_encoding["attention_mask"].to(device) if use_attention_mask else None - input_ids = batch_encoding["input_ids"].to(device) - - # No idea why it doesn't work without this - torch.cuda.synchronize() - - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - output_hidden_states=output_hidden_states or hidden_state_skip_layer is not None, - ) - if hidden_state_skip_layer is not None: - last_hidden_state = outputs.hidden_states[-(hidden_state_skip_layer + 1)] - # Real last hidden state already has layer norm applied. So here we only apply it - # for intermediate layers. - if hidden_state_skip_layer > 0 and self.apply_final_norm: - last_hidden_state = self.model.final_layer_norm(last_hidden_state) - else: - last_hidden_state = outputs[self.output_key] - - # Remove hidden states of instruction tokens, only keep prompt tokens. - if self.use_template: - if data_type == "image": - crop_start = self.prompt_template.get("crop_start", -1) - elif data_type == "video": - crop_start = self.prompt_template_video.get("crop_start", -1) - else: - raise ValueError(f"Unsupported data type: {data_type}") - if crop_start > 0: - last_hidden_state = last_hidden_state[:, crop_start:] - attention_mask = attention_mask[:, crop_start:] if use_attention_mask else None - - if output_hidden_states: - return TextEncoderModelOutput(last_hidden_state, attention_mask, outputs.hidden_states) - return TextEncoderModelOutput(last_hidden_state, attention_mask) - - def forward( - self, - text, - use_attention_mask=None, - output_hidden_states=False, - do_sample=False, - hidden_state_skip_layer=None, - return_texts=False, - ): - batch_encoding = self.text2tokens(text) - return self.encode( - batch_encoding, - use_attention_mask=use_attention_mask, - output_hidden_states=output_hidden_states, - do_sample=do_sample, - hidden_state_skip_layer=hidden_state_skip_layer, - return_texts=return_texts, - ) From d9ae8defddc6af1a38c7c66419e8345b3e66264c Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 13 Dec 2024 14:55:25 +0100 Subject: [PATCH 39/58] add coauthor Co-Authored-By: "Gregory D. Hunkins" From e71366084ba1947f931a255df282b30cafab2a66 Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 13 Dec 2024 21:25:22 +0100 Subject: [PATCH 40/58] refactor rope; diff: 0.14990234375; reason and fix: create rope grid on cpu and move to device Note: The following line diverges from original behaviour. We create the grid on the device, whereas original implementation creates it on CPU and then moves it to device. This results in numerical differences in layerwise debugging outputs, but visually it is the same. --- .../transformers/transformer_hunyuan_video.py | 75 +++-- .../hunyuan_video/pipeline_hunyuan_video.py | 305 +++--------------- 2 files changed, 90 insertions(+), 290 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py index 50f8254648d1..840d78fcada8 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py @@ -24,6 +24,7 @@ from ...utils import is_torch_version from ..attention import FeedForward from ..attention_processor import Attention, AttentionProcessor +from ..embeddings import get_1d_rotary_pos_embed from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle @@ -138,26 +139,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class PatchEmbed(nn.Module): def __init__( self, - patch_size=16, - in_chans=3, - embed_dim=768, - norm_layer=None, - flatten=True, - bias=True, - ): + patch_size: Union[int, Tuple[int, int, int]] = 16, + in_chans: int = 3, + embed_dim: int = 768, + ) -> None: super().__init__() - patch_size = tuple(patch_size) - self.flatten = flatten - self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias) - self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + patch_size = (patch_size, patch_size, patch_size) if isinstance(patch_size, int) else patch_size + self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) - def forward(self, x): - x = self.proj(x) - if self.flatten: - x = x.flatten(2).transpose(1, 2) # BCHW -> BNC - x = self.norm(x) - return x + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.proj(hidden_states) + hidden_states = hidden_states.flatten(2).transpose(1, 2) # BCFHW -> BNC + return hidden_states class TextProjection(nn.Module): @@ -384,6 +378,39 @@ def forward( return hidden_states +class HunyuanVideoRotaryPosEmbed(nn.Module): + def __init__(self, patch_size: int, patch_size_t: int, rope_dim: List[int], theta: float = 256.0) -> None: + super().__init__() + + self.patch_size = patch_size + self.patch_size_t = patch_size_t + self.rope_dim = rope_dim + self.theta = theta + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = hidden_states.shape + rope_sizes = [num_frames // self.patch_size_t, height // self.patch_size, width // self.patch_size] + + axes_grids = [] + for i in range(3): + # Note: The following line diverges from original behaviour. We create the grid on the device, whereas + # original implementation creates it on CPU and then moves it to device. This results in numerical + # differences in layerwise debugging outputs, but visually it is the same. + grid = torch.arange(0, rope_sizes[i], device=hidden_states.device, dtype=torch.float32) + axes_grids.append(grid) + grid = torch.meshgrid(*axes_grids, indexing="ij") # [W, H, T] + grid = torch.stack(grid, dim=0) # [3, W, H, T] + + freqs = [] + for i in range(3): + freq = get_1d_rotary_pos_embed(self.rope_dim[i], grid[i].reshape(-1), self.theta, use_real=True) + freqs.append(freq) + + freqs_cos = torch.cat([f[0] for f in freqs], dim=1) # (W * H * T, D / 2) + freqs_sin = torch.cat([f[1] for f in freqs], dim=1) # (W * H * T, D / 2) + return freqs_cos, freqs_sin + + class HunyuanVideoSingleTransformerBlock(nn.Module): def __init__( self, @@ -546,12 +573,12 @@ def __init__( guidance_embeds: bool = True, text_embed_dim: int = 4096, text_embed_dim_2: int = 768, + rope_theta: float = 256.0, ) -> None: super().__init__() inner_dim = num_attention_heads * attention_head_dim out_channels = out_channels or in_channels - self.rope_dim_list = rope_dim_list # image projection self.img_in = PatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim) @@ -570,6 +597,9 @@ def __init__( # guidance modulation self.guidance_in = TimestepEmbedder(inner_dim, nn.SiLU) + # 3. RoPE + self.rope = HunyuanVideoRotaryPosEmbed(patch_size, patch_size_t, rope_dim_list, rope_theta) + self.transformer_blocks = nn.ModuleList( [ HunyuanVideoTransformerBlock( @@ -664,8 +694,6 @@ def forward( encoder_hidden_states: torch.Tensor, encoder_attention_mask: torch.Tensor, encoder_hidden_states_2: torch.Tensor, - freqs_cos: Optional[torch.Tensor] = None, - freqs_sin: Optional[torch.Tensor] = None, guidance: torch.Tensor = None, return_dict: bool = True, ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: @@ -676,6 +704,8 @@ def forward( post_patch_height = height // p post_patch_width = width // p + image_rotary_emb = self.rope(hidden_states) + temb = self.time_in(timestep) temb = temb + self.vector_in(encoder_hidden_states_2) temb = temb + self.guidance_in(guidance) @@ -691,15 +721,14 @@ def forward( else lambda x: x ) - freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None for _, block in enumerate(self.transformer_blocks): hidden_states, encoder_hidden_states = block_forward(block)( - hidden_states, encoder_hidden_states, temb, freqs_cis + hidden_states, encoder_hidden_states, temb, image_rotary_emb ) for block in self.single_transformer_blocks: hidden_states, encoder_hidden_states = block_forward(block)( - hidden_states, encoder_hidden_states, temb, freqs_cis + hidden_states, encoder_hidden_states, temb, image_rotary_emb ) hidden_states = self.norm_out(hidden_states, temb) diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py index 47309d06c5a0..1a8ee40770a1 100644 --- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py @@ -289,24 +289,31 @@ def encode_prompt( prompt_2: Union[str, List[str]] = None, prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE, num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_2: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, max_sequence_length: int = 256, ): - prompt_embeds, prompt_attention_mask = self._get_llama_prompt_embeds( - prompt, - prompt_template, - num_videos_per_prompt, - device=device, - max_sequence_length=max_sequence_length, - ) + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_llama_prompt_embeds( + prompt, + prompt_template, + num_videos_per_prompt, + device=device, + dtype=dtype, + max_sequence_length=max_sequence_length, + ) - if self.text_encoder_2 is not None: - prompt_2 = prompt_2 or prompt + if prompt_embeds_2 is None and self.text_encoder_2 is not None: + if prompt_2 is None and prompt_embeds_2 is None: + prompt_2 = prompt prompt_embeds_2 = self._get_clip_prompt_embeds( prompt, num_videos_per_prompt, device=device, + dtype=dtype, max_sequence_length=77, ) @@ -318,7 +325,6 @@ def check_inputs( prompt_2, height, width, - video_length, prompt_embeds=None, callback_on_step_end_tensor_inputs=None, prompt_template=None, @@ -362,20 +368,23 @@ def check_inputs( def prepare_latents( self, - batch_size, - num_channels_latents, - height, - width, - video_length, - dtype, - device, - generator, - latents=None, - ): + batch_size: int, + num_channels_latents: 32, + height: int = 720, + width: int = 1280, + num_frames: int = 129, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) + shape = ( batch_size, num_channels_latents, - video_length, + num_frames, int(height) // self.vae_scale_factor_spatial, int(width) // self.vae_scale_factor_spatial, ) @@ -385,235 +394,9 @@ def prepare_latents( f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) - if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - else: - latents = latents.to(device) - - # Check existence to make it compatible with FlowMatchEulerDiscreteScheduler - if hasattr(self.scheduler, "init_noise_sigma"): - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) return latents - # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding - def get_guidance_scale_embedding( - self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 - ) -> torch.Tensor: - """ - See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 - - Args: - w (`torch.Tensor`): - Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. - embedding_dim (`int`, *optional*, defaults to 512): - Dimension of the embeddings to generate. - dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): - Data type of the generated embeddings. - - Returns: - `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. - """ - assert len(w.shape) == 1 - w = w * 1000.0 - - half_dim = embedding_dim // 2 - emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) - emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) - emb = w.to(dtype)[:, None] * emb[None, :] - emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) - if embedding_dim % 2 == 1: # zero pad - emb = torch.nn.functional.pad(emb, (0, 1)) - assert emb.shape == (w.shape[0], embedding_dim) - return emb - - def get_rotary_pos_embed(self, video_length, height, width): - def _to_tuple(x, dim=2): - if isinstance(x, int): - return (x,) * dim - elif len(x) == dim: - return x - else: - raise ValueError(f"Expected length {dim} or int, but got {x}") - - def get_meshgrid_nd(start, *args, dim=2): - """ - Get n-D meshgrid with start, stop and num. - - Args: - start (int or tuple): - If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop, step is 1; If - len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num should - be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in - n-tuples. - *args: See above. - dim (int): Dimension of the meshgrid. Defaults to 2. - - Returns: - grid (np.ndarray): [dim, ...] - """ - if len(args) == 0: - # start is grid_size - num = _to_tuple(start, dim=dim) - start = (0,) * dim - stop = num - elif len(args) == 1: - # start is start, args[0] is stop, step is 1 - start = _to_tuple(start, dim=dim) - stop = _to_tuple(args[0], dim=dim) - num = [stop[i] - start[i] for i in range(dim)] - elif len(args) == 2: - # start is start, args[0] is stop, args[1] is num - start = _to_tuple(start, dim=dim) # Left-Top eg: 12,0 - stop = _to_tuple(args[0], dim=dim) # Right-Bottom eg: 20,32 - num = _to_tuple(args[1], dim=dim) # Target Size eg: 32,124 - else: - raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}") - - # PyTorch implement of np.linspace(start[i], stop[i], num[i], endpoint=False) - axis_grid = [] - for i in range(dim): - a, b, n = start[i], stop[i], num[i] - g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n] - axis_grid.append(g) - grid = torch.meshgrid(*axis_grid, indexing="ij") # dim x [W, H, D] - grid = torch.stack(grid, dim=0) # [dim, W, H, D] - - return grid - - def get_1d_rotary_pos_embed( - dim: int, - pos: Union[torch.FloatTensor, int], - theta: float = 10000.0, - use_real: bool = False, - theta_rescale_factor: float = 1.0, - interpolation_factor: float = 1.0, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - """ - Precompute the frequency tensor for complex exponential (cis) with given dimensions. (Note: `cis` means - `cos + i * sin`, where i is the imaginary unit.) - - This function calculates a frequency tensor with complex exponential using the given dimension 'dim' and - the end index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex - values in complex64 data type. - - Args: - dim (int): Dimension of the frequency tensor. - pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar - theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0. - use_real (bool, optional): If True, return real part and imaginary part separately. - Otherwise, return complex numbers. - theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0. - - Returns: - freqs_cis: Precomputed frequency tensor with complex exponential. [S, D/2] freqs_cos, freqs_sin: - Precomputed frequency tensor with real and imaginary parts separately. [S, D] - """ - if isinstance(pos, int): - pos = torch.arange(pos).float() - - # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning - # has some connection to NTK literature - if theta_rescale_factor != 1.0: - theta *= theta_rescale_factor ** (dim / (dim - 2)) - - freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [D/2] - # assert interpolation_factor == 1.0, f"interpolation_factor: {interpolation_factor}" - freqs = torch.outer(pos * interpolation_factor, freqs) # [S, D/2] - if use_real: - freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D] - freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D] - return freqs_cos, freqs_sin - else: - freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2] - return freqs_cis - - def get_nd_rotary_pos_embed( - rope_dim_list, - start, - *args, - theta=10000.0, - use_real=False, - theta_rescale_factor: Union[float, List[float]] = 1.0, - interpolation_factor: Union[float, List[float]] = 1.0, - ): - """ - This is a n-d version of precompute_freqs_cis, which is a RoPE for tokens with n-d structure. - - Args: - rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal to n. - sum(rope_dim_list) should equal to head_dim of attention layer. - start (int | tuple of int | list of int): - If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop, step is 1; If - len(args) == 2, start is start, args[0] is stop, args[1] is num. - *args: See above. - theta (float): Scaling factor for frequency computation. Defaults to 10000.0. - use_real (bool): - If True, return real part and imaginary part separately. Otherwise, return complex numbers. Some - libraries such as TensorRT does not support complex64 data type. So it is useful to provide a real - part and an imaginary part separately. - theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0. - - Returns: - pos_embed (torch.Tensor): [HW, D/2] - """ - - grid = get_meshgrid_nd(start, *args, dim=len(rope_dim_list)) # [3, W, H, D] / [2, W, H] - - if isinstance(theta_rescale_factor, int) or isinstance(theta_rescale_factor, float): - theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list) - elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1: - theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list) - assert len(theta_rescale_factor) == len( - rope_dim_list - ), "len(theta_rescale_factor) should equal to len(rope_dim_list)" - - if isinstance(interpolation_factor, int) or isinstance(interpolation_factor, float): - interpolation_factor = [interpolation_factor] * len(rope_dim_list) - elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1: - interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list) - assert len(interpolation_factor) == len( - rope_dim_list - ), "len(interpolation_factor) should equal to len(rope_dim_list)" - - # use 1/ndim of dimensions to encode grid_axis - embs = [] - for i in range(len(rope_dim_list)): - emb = get_1d_rotary_pos_embed( - rope_dim_list[i], - grid[i].reshape(-1), - theta, - use_real=use_real, - theta_rescale_factor=theta_rescale_factor[i], - interpolation_factor=interpolation_factor[i], - ) # 2 x [WHD, rope_dim_list[i]] - embs.append(emb) - - if use_real: - cos = torch.cat([emb[0] for emb in embs], dim=1) # (WHD, D/2) - sin = torch.cat([emb[1] for emb in embs], dim=1) # (WHD, D/2) - return cos, sin - else: - emb = torch.cat(embs, dim=1) # (WHD, D/2) - return emb - - latents_size = [(video_length - 1) // 4 + 1, height // 8, width // 8] - rope_sizes = [ - latents_size[0] // self.transformer.config.patch_size_t, - latents_size[1] // self.transformer.config.patch_size, - latents_size[2] // self.transformer.config.patch_size, - ] - - freqs_cos, freqs_sin = get_nd_rotary_pos_embed( - self.transformer.config.rope_dim_list, - rope_sizes, - theta=256, - use_real=True, - theta_rescale_factor=1, - ) - - return freqs_cos, freqs_sin - @property def guidance_scale(self): return self._guidance_scale @@ -639,12 +422,10 @@ def __call__( height: int = 720, width: int = 1280, num_frames: int = 129, - data_type: str = "video", num_inference_steps: int = 50, sigmas: List[float] = None, guidance_scale: float = 6.0, num_videos_per_prompt: Optional[int] = 1, - eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, prompt_embeds: Optional[torch.Tensor] = None, @@ -669,19 +450,15 @@ def __call__( prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is will be used instead. - height (`int`): + height (`int`, defaults to `720`): The height in pixels of the generated image. - width (`int`): + width (`int`, defaults to `1280`): The width in pixels of the generated image. - video_length (`int`): + num_frames (`int`, defaults to `129`): The number of frames in the generated video. - num_inference_steps (`int`, *optional*, defaults to 50): + num_inference_steps (`int`, defaults to `50`): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. - timesteps (`List[int]`, *optional*): - Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument - in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is - passed will be used. Must be in descending order. sigmas (`List[float]`, *optional*): Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed @@ -696,9 +473,6 @@ def __call__( not applied. num_videos_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies - to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. @@ -744,7 +518,6 @@ def __call__( prompt_2, height, width, - num_frames, prompt_embeds, callback_on_step_end_tensor_inputs, prompt_template, @@ -769,6 +542,9 @@ def __call__( prompt_2, prompt_template, num_videos_per_prompt, + prompt_embeds, + prompt_embeds_2, + prompt_attention_mask, device, max_sequence_length, ) @@ -808,9 +584,6 @@ def __call__( num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(timesteps) - image_rotary_emb = self.get_rotary_pos_embed(num_frames, height, width) - - # if is_progress_bar: with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: @@ -832,8 +605,6 @@ def __call__( encoder_hidden_states=prompt_embeds, encoder_attention_mask=prompt_attention_mask, encoder_hidden_states_2=prompt_embeds_2, - freqs_cos=image_rotary_emb[0], - freqs_sin=image_rotary_emb[1], guidance=guidance_expand, return_dict=False, )[0] From 9039db4de4606ce4279ae063b5918ec3f7a6eefb Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 13 Dec 2024 21:48:45 +0100 Subject: [PATCH 41/58] use diffusers timesteps embedding; diff: 0.10205078125 --- .../transformers/transformer_hunyuan_video.py | 29 ++++++++++--------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py index 840d78fcada8..86e232774c56 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py @@ -24,7 +24,7 @@ from ...utils import is_torch_version from ..attention import FeedForward from ..attention_processor import Attention, AttentionProcessor -from ..embeddings import get_1d_rotary_pos_embed +from ..embeddings import get_1d_rotary_pos_embed, get_timestep_embedding from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle @@ -219,7 +219,8 @@ def __init__( ) def forward(self, t): - t_freq = timestep_embedding(t, self.frequency_embedding_size, self.max_period).type(self.mlp[0].weight.dtype) + # t_freq = timestep_embedding(t, self.frequency_embedding_size, self.max_period).type(self.mlp[0].weight.dtype) + t_freq = get_timestep_embedding(t, self.frequency_embedding_size, flip_sin_to_cos=True, max_period=self.max_period, downscale_freq_shift=0).type(self.mlp[0].weight.dtype) t_emb = self.mlp(t_freq) return t_emb @@ -231,24 +232,22 @@ def __init__( attention_head_dim: int, mlp_width_ratio: str = 4.0, mlp_drop_rate: float = 0.0, - qkv_bias: bool = True, + attention_bias: bool = True, ) -> None: super().__init__() hidden_size = num_attention_heads * attention_head_dim self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6) - self.attn = Attention( query_dim=hidden_size, cross_attention_dim=None, heads=num_attention_heads, dim_head=attention_head_dim, - bias=True, + bias=attention_bias, ) self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6) - self.mlp = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="silu", dropout=mlp_drop_rate) self.adaLN_modulation = nn.Sequential( @@ -286,8 +285,8 @@ def __init__( num_layers: int, mlp_width_ratio: float = 4.0, mlp_drop_rate: float = 0.0, - qkv_bias: bool = True, - ): + attention_bias: bool = True, + ) -> None: super().__init__() self.refiner_blocks = nn.ModuleList( @@ -297,7 +296,7 @@ def __init__( attention_head_dim=attention_head_dim, mlp_width_ratio=mlp_width_ratio, mlp_drop_rate=mlp_drop_rate, - qkv_bias=qkv_bias, + attention_bias=attention_bias, ) for _ in range(num_layers) ] @@ -308,7 +307,7 @@ def forward( hidden_states: torch.Tensor, temb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, - ): + ) -> None: self_attn_mask = None if attention_mask is not None: batch_size = attention_mask.shape[0] @@ -334,13 +333,15 @@ def __init__( num_layers: int, mlp_ratio: float = 4.0, mlp_drop_rate: float = 0.0, - qkv_bias: bool = True, - ): + attention_bias: bool = True, + ) -> None: super().__init__() hidden_size = num_attention_heads * attention_head_dim self.input_embedder = nn.Linear(in_channels, hidden_size, bias=True) + # self.time_embed = TimestepEmbedder(hidden_size, nn.SiLU) + # self.context_embed = TextProjection(in_channels, hidden_size, nn.SiLU) self.t_embedder = TimestepEmbedder(hidden_size, nn.SiLU) self.c_embedder = TextProjection(in_channels, hidden_size, nn.SiLU) @@ -350,7 +351,7 @@ def __init__( num_layers=num_layers, mlp_width_ratio=mlp_ratio, mlp_drop_rate=mlp_drop_rate, - qkv_bias=qkv_bias, + attention_bias=attention_bias, ) def forward( @@ -360,6 +361,7 @@ def forward( attention_mask: Optional[torch.LongTensor] = None, ) -> torch.Tensor: original_dtype = hidden_states.dtype + # temb = self.time_embed(timestep) temb = self.t_embedder(timestep) if attention_mask is None: @@ -369,6 +371,7 @@ def forward( pooled_projections = (hidden_states * mask_float).sum(dim=1) / mask_float.sum(dim=1) pooled_projections = pooled_projections.to(original_dtype) + # pooled_projections = self.context_embed(pooled_projections) pooled_projections = self.c_embedder(pooled_projections) emb = temb + pooled_projections From 16778b15bb95d1be99d308b769cc89f43757ada0 Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 13 Dec 2024 21:57:38 +0100 Subject: [PATCH 42/58] rename --- scripts/convert_hunyuan_video_to_diffusers.py | 25 +++++++++++- .../transformers/transformer_hunyuan_video.py | 38 ++----------------- 2 files changed, 28 insertions(+), 35 deletions(-) diff --git a/scripts/convert_hunyuan_video_to_diffusers.py b/scripts/convert_hunyuan_video_to_diffusers.py index 61ffe0474087..9ad2465fda80 100644 --- a/scripts/convert_hunyuan_video_to_diffusers.py +++ b/scripts/convert_hunyuan_video_to_diffusers.py @@ -3,8 +3,9 @@ import torch from accelerate import init_empty_weights +from transformers import AutoModel, AutoTokenizer, CLIPTextModel, CLIPTokenizer -from diffusers import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel +from diffusers import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel, HunyuanVideoPipeline def remap_norm_scale_shift_(key, state_dict): @@ -76,6 +77,8 @@ def remap_single_transformer_blocks_(key, state_dict): # "guidance_in.mlp.2": "time_text_embed.guidance_embedder.linear_2", # "vector_in.in_layer": "time_text_embed.text_embedder.linear_1", # "vector_in.out_layer": "time_text_embed.text_embedder.linear_2", + "txt_in.t_embedder": "txt_in.time_embed", + "txt_in.c_embedder": "txt_in.context_embed", "double_blocks": "transformer_blocks", "individual_token_refiner.blocks": "token_refiner.refiner_blocks", "img_attn_q_norm": "attn.norm_q", @@ -179,6 +182,8 @@ def get_args(): "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint" ) parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original VAE checkpoint") + parser.add_argument("--text_encoder_path", type=str, default=None, help="Path to original llama checkpoint") + parser.add_argument("--text_encoder_2_path", type=str, default=None, help="Path to original clip checkpoint") parser.add_argument("--save_pipeline", action="store_true") parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved") parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the transformer in.") @@ -200,6 +205,8 @@ def get_args(): if args.save_pipeline: assert args.transformer_ckpt_path is not None and args.vae_ckpt_path is not None + assert args.text_encoder_path is not None + assert args.text_encoder_2_path is not None if args.transformer_ckpt_path is not None: transformer = convert_transformer(args.transformer_ckpt_path) @@ -211,3 +218,19 @@ def get_args(): vae = convert_vae(args.vae_ckpt_path) if not args.save_pipeline: vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") + + if args.save_pipeline: + text_encoder = AutoModel.from_pretrained(args.text_encoder_path, torch_dtype=torch.float16) + tokenizer = AutoTokenizer.from_pretrained(args.text_encoder_path, padding_side="right") + text_encoder_2 = CLIPTextModel.from_pretrained(args.text_encoder_2_path, torch_dtype=torch.float16) + tokenizer_2 = CLIPTokenizer.from_pretrained(args.text_encoder_2_path) + + pipe = HunyuanVideoPipeline( + transformer=transformer, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + ) + pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py index 86e232774c56..ec64559333f9 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py @@ -168,31 +168,6 @@ def forward(self, caption): return hidden_states -def timestep_embedding(t, dim, max_period=10000): - """ - Create sinusoidal timestep embeddings. - - Args: - t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional. - dim (int): the dimension of the output. - max_period (int): controls the minimum frequency of the embeddings. - - Returns: - embedding (torch.Tensor): An (N, D) Tensor of positional embeddings. - - .. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py - """ - half = dim // 2 - freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( - device=t.device - ) - args = t[:, None].float() * freqs[None] - embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) - if dim % 2: - embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) - return embedding - - class TimestepEmbedder(nn.Module): """ Embeds scalar timesteps into vector representations. @@ -219,7 +194,6 @@ def __init__( ) def forward(self, t): - # t_freq = timestep_embedding(t, self.frequency_embedding_size, self.max_period).type(self.mlp[0].weight.dtype) t_freq = get_timestep_embedding(t, self.frequency_embedding_size, flip_sin_to_cos=True, max_period=self.max_period, downscale_freq_shift=0).type(self.mlp[0].weight.dtype) t_emb = self.mlp(t_freq) return t_emb @@ -340,10 +314,8 @@ def __init__( hidden_size = num_attention_heads * attention_head_dim self.input_embedder = nn.Linear(in_channels, hidden_size, bias=True) - # self.time_embed = TimestepEmbedder(hidden_size, nn.SiLU) - # self.context_embed = TextProjection(in_channels, hidden_size, nn.SiLU) - self.t_embedder = TimestepEmbedder(hidden_size, nn.SiLU) - self.c_embedder = TextProjection(in_channels, hidden_size, nn.SiLU) + self.time_embed = TimestepEmbedder(hidden_size, nn.SiLU) + self.context_embed = TextProjection(in_channels, hidden_size, nn.SiLU) self.token_refiner = IndividualTokenRefiner( num_attention_heads=num_attention_heads, @@ -361,8 +333,7 @@ def forward( attention_mask: Optional[torch.LongTensor] = None, ) -> torch.Tensor: original_dtype = hidden_states.dtype - # temb = self.time_embed(timestep) - temb = self.t_embedder(timestep) + temb = self.time_embed(timestep) if attention_mask is None: pooled_projections = hidden_states.mean(dim=1) @@ -371,8 +342,7 @@ def forward( pooled_projections = (hidden_states * mask_float).sum(dim=1) / mask_float.sum(dim=1) pooled_projections = pooled_projections.to(original_dtype) - # pooled_projections = self.context_embed(pooled_projections) - pooled_projections = self.c_embedder(pooled_projections) + pooled_projections = self.context_embed(pooled_projections) emb = temb + pooled_projections hidden_states = self.input_embedder(hidden_states) From b6c7ae027792dc8aa03e1c0c7db8560cfbdf79cb Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 13 Dec 2024 23:20:10 +0100 Subject: [PATCH 43/58] convert --- scripts/convert_hunyuan_video_to_diffusers.py | 50 +++-- .../transformers/transformer_hunyuan_video.py | 177 ++++++++---------- 2 files changed, 105 insertions(+), 122 deletions(-) diff --git a/scripts/convert_hunyuan_video_to_diffusers.py b/scripts/convert_hunyuan_video_to_diffusers.py index 9ad2465fda80..03e302a75574 100644 --- a/scripts/convert_hunyuan_video_to_diffusers.py +++ b/scripts/convert_hunyuan_video_to_diffusers.py @@ -5,7 +5,7 @@ from accelerate import init_empty_weights from transformers import AutoModel, AutoTokenizer, CLIPTextModel, CLIPTokenizer -from diffusers import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel, HunyuanVideoPipeline +from diffusers import AutoencoderKLHunyuanVideo, HunyuanVideoPipeline, HunyuanVideoTransformer3DModel def remap_norm_scale_shift_(key, state_dict): @@ -15,6 +15,23 @@ def remap_norm_scale_shift_(key, state_dict): state_dict[key.replace("final_layer.adaLN_modulation.1", "norm_out.linear")] = new_weight +def remap_token_refiner_blocks_(key, state_dict): + def rename_key(key): + new_key = key.replace("individual_token_refiner.blocks", "token_refiner.refiner_blocks") + new_key = new_key.replace("adaLN_modulation.1", "norm_out.linear") + return new_key + + if "self_attn_qkv" in key: + weight = state_dict.pop(key) + to_q, to_k, to_v = weight.chunk(3, dim=0) + state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_q"))] = to_q + state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_k"))] = to_k + state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_v"))] = to_v + + else: + state_dict[rename_key(key)] = state_dict.pop(key) + + def remap_img_attn_qkv_(key, state_dict): weight = state_dict.pop(key) to_q, to_k, to_v = weight.chunk(3, dim=0) @@ -31,14 +48,6 @@ def remap_txt_attn_qkv_(key, state_dict): state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = to_v -def remap_self_attn_qkv_(key, state_dict): - weight = state_dict.pop(key) - to_q, to_k, to_v = weight.chunk(3, dim=0) - state_dict[key.replace("self_attn_qkv", "attn.to_q")] = to_q - state_dict[key.replace("self_attn_qkv", "attn.to_k")] = to_k - state_dict[key.replace("self_attn_qkv", "attn.to_v")] = to_v - - def remap_single_transformer_blocks_(key, state_dict): hidden_size = 3072 @@ -71,16 +80,16 @@ def remap_single_transformer_blocks_(key, state_dict): TRANSFORMER_KEYS_RENAME_DICT = { - # "time_in.mlp.0": "time_text_embed.timestep_embedder.linear_1", - # "time_in.mlp.2": "time_text_embed.timestep_embedder.linear_2", - # "guidance_in.mlp.0": "time_text_embed.guidance_embedder.linear_1", - # "guidance_in.mlp.2": "time_text_embed.guidance_embedder.linear_2", - # "vector_in.in_layer": "time_text_embed.text_embedder.linear_1", - # "vector_in.out_layer": "time_text_embed.text_embedder.linear_2", - "txt_in.t_embedder": "txt_in.time_embed", - "txt_in.c_embedder": "txt_in.context_embed", + "time_in.mlp.0": "time_text_embed.timestep_embedder.linear_1", + "time_in.mlp.2": "time_text_embed.timestep_embedder.linear_2", + "guidance_in.mlp.0": "time_text_embed.guidance_embedder.linear_1", + "guidance_in.mlp.2": "time_text_embed.guidance_embedder.linear_2", + "vector_in.in_layer": "time_text_embed.text_embedder.linear_1", + "vector_in.out_layer": "time_text_embed.text_embedder.linear_2", + "txt_in.t_embedder.mlp.0": "txt_in.time_text_embed.timestep_embedder.linear_1", + "txt_in.t_embedder.mlp.2": "txt_in.time_text_embed.timestep_embedder.linear_2", + "txt_in.c_embedder": "txt_in.time_text_embed.text_embedder", "double_blocks": "transformer_blocks", - "individual_token_refiner.blocks": "token_refiner.refiner_blocks", "img_attn_q_norm": "attn.norm_q", "img_attn_k_norm": "attn.norm_k", "img_attn_proj": "attn.to_out.0", @@ -102,14 +111,15 @@ def remap_single_transformer_blocks_(key, state_dict): "final_layer.linear": "proj_out", "fc1": "net.0.proj", "fc2": "net.2", + "input_embedder": "proj_in", } TRANSFORMER_SPECIAL_KEYS_REMAP = { - "final_layer.adaLN_modulation.1": remap_norm_scale_shift_, "img_attn_qkv": remap_img_attn_qkv_, "txt_attn_qkv": remap_txt_attn_qkv_, - "self_attn_qkv": remap_self_attn_qkv_, "single_blocks": remap_single_transformer_blocks_, + "final_layer.adaLN_modulation.1": remap_norm_scale_shift_, + "individual_token_refiner.blocks": remap_token_refiner_blocks_, } VAE_KEYS_RENAME_DICT = {} diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py index ec64559333f9..22939e205fb0 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py @@ -12,9 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import math -from functools import partial -from typing import Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -24,7 +22,11 @@ from ...utils import is_torch_version from ..attention import FeedForward from ..attention_processor import Attention, AttentionProcessor -from ..embeddings import get_1d_rotary_pos_embed, get_timestep_embedding +from ..embeddings import ( + CombinedTimestepGuidanceTextProjEmbeddings, + CombinedTimestepTextProjEmbeddings, + get_1d_rotary_pos_embed, +) from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle @@ -123,19 +125,6 @@ def __call__( return hidden_states, encoder_hidden_states -class MLPEmbedder(nn.Module): - """copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py""" - - def __init__(self, in_dim: int, hidden_dim: int): - super().__init__() - self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True) - self.silu = nn.SiLU() - self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.out_layer(self.silu(self.in_layer(x))) - - class PatchEmbed(nn.Module): def __init__( self, @@ -154,49 +143,21 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return hidden_states -class TextProjection(nn.Module): - def __init__(self, in_channels, hidden_size, act_layer): +class HunyuanVideoAdaNorm(nn.Module): + def __init__(self, in_features: int, out_features: Optional[int] = None) -> None: super().__init__() - self.linear_1 = nn.Linear(in_features=in_channels, out_features=hidden_size, bias=True) - self.act_1 = act_layer() - self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True) - - def forward(self, caption): - hidden_states = self.linear_1(caption) - hidden_states = self.act_1(hidden_states) - hidden_states = self.linear_2(hidden_states) - return hidden_states - -class TimestepEmbedder(nn.Module): - """ - Embeds scalar timesteps into vector representations. - """ - - def __init__( - self, - hidden_size, - act_layer, - frequency_embedding_size=256, - max_period=10000, - out_size=None, - ): - super().__init__() - self.frequency_embedding_size = frequency_embedding_size - self.max_period = max_period - if out_size is None: - out_size = hidden_size - - self.mlp = nn.Sequential( - nn.Linear(frequency_embedding_size, hidden_size, bias=True), - act_layer(), - nn.Linear(hidden_size, out_size, bias=True), - ) + out_features = out_features or 2 * in_features + self.linear = nn.Linear(in_features, out_features) + self.nonlinearity = nn.SiLU() - def forward(self, t): - t_freq = get_timestep_embedding(t, self.frequency_embedding_size, flip_sin_to_cos=True, max_period=self.max_period, downscale_freq_shift=0).type(self.mlp[0].weight.dtype) - t_emb = self.mlp(t_freq) - return t_emb + def forward( + self, temb: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + temb = self.linear(self.nonlinearity(temb)) + gate_msa, gate_mlp = temb.chunk(2, dim=1) + gate_msa, gate_mlp = gate_msa.unsqueeze(1), gate_mlp.unsqueeze(1) + return gate_msa, gate_mlp class IndividualTokenRefinerBlock(nn.Module): @@ -224,10 +185,7 @@ def __init__( self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6) self.mlp = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="silu", dropout=mlp_drop_rate) - self.adaLN_modulation = nn.Sequential( - nn.SiLU(), - nn.Linear(hidden_size, 2 * hidden_size, bias=True), - ) + self.norm_out = HunyuanVideoAdaNorm(hidden_size, 2 * hidden_size) def forward( self, @@ -235,8 +193,6 @@ def forward( temb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: - gate_msa, gate_mlp = self.adaLN_modulation(temb).chunk(2, dim=1) - norm_hidden_states = self.norm1(hidden_states) attn_output = self.attn( @@ -244,9 +200,12 @@ def forward( encoder_hidden_states=None, attention_mask=attention_mask, ) - hidden_states = hidden_states + attn_output * gate_msa.unsqueeze(1) - hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) * gate_mlp.unsqueeze(1) + gate_msa, gate_mlp = self.norm_out(temb) + hidden_states = hidden_states + attn_output * gate_msa + + ff_output = self.mlp(self.norm2(hidden_states)) + hidden_states = hidden_states + ff_output * gate_mlp return hidden_states @@ -313,10 +272,10 @@ def __init__( hidden_size = num_attention_heads * attention_head_dim - self.input_embedder = nn.Linear(in_channels, hidden_size, bias=True) - self.time_embed = TimestepEmbedder(hidden_size, nn.SiLU) - self.context_embed = TextProjection(in_channels, hidden_size, nn.SiLU) - + self.time_text_embed = CombinedTimestepTextProjEmbeddings( + embedding_dim=hidden_size, pooled_projection_dim=in_channels + ) + self.proj_in = nn.Linear(in_channels, hidden_size, bias=True) self.token_refiner = IndividualTokenRefiner( num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim, @@ -332,21 +291,17 @@ def forward( timestep: torch.LongTensor, attention_mask: Optional[torch.LongTensor] = None, ) -> torch.Tensor: - original_dtype = hidden_states.dtype - temb = self.time_embed(timestep) - if attention_mask is None: pooled_projections = hidden_states.mean(dim=1) else: + original_dtype = hidden_states.dtype mask_float = attention_mask.float().unsqueeze(-1) pooled_projections = (hidden_states * mask_float).sum(dim=1) / mask_float.sum(dim=1) pooled_projections = pooled_projections.to(original_dtype) - pooled_projections = self.context_embed(pooled_projections) - emb = temb + pooled_projections - - hidden_states = self.input_embedder(hidden_states) - hidden_states = self.token_refiner(hidden_states, emb, attention_mask) + temb = self.time_text_embed(timestep, pooled_projections) + hidden_states = self.proj_in(hidden_states) + hidden_states = self.token_refiner(hidden_states, temb, attention_mask) return hidden_states @@ -561,14 +516,7 @@ def __init__( text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers ) - # time modulation - self.time_in = TimestepEmbedder(inner_dim, nn.SiLU) - - # text modulation - self.vector_in = MLPEmbedder(text_embed_dim_2, inner_dim) - - # guidance modulation - self.guidance_in = TimestepEmbedder(inner_dim, nn.SiLU) + self.time_text_embed = CombinedTimestepGuidanceTextProjEmbeddings(inner_dim, text_embed_dim_2) # 3. RoPE self.rope = HunyuanVideoRotaryPosEmbed(patch_size, patch_size_t, rope_dim_list, rope_theta) @@ -679,30 +627,55 @@ def forward( image_rotary_emb = self.rope(hidden_states) - temb = self.time_in(timestep) - temb = temb + self.vector_in(encoder_hidden_states_2) - temb = temb + self.guidance_in(guidance) + temb = self.time_text_embed(timestep, guidance, encoder_hidden_states_2) # Embed image and text. hidden_states = self.img_in(hidden_states) encoder_hidden_states = self.txt_in(encoder_hidden_states, timestep, encoder_attention_mask) - use_reentrant = is_torch_version(">=", "1.11.0") - block_forward = ( - partial(torch.utils.checkpoint.checkpoint, use_reentrant=use_reentrant) - if torch.is_grad_enabled() and self.gradient_checkpointing - else lambda x: x - ) + if torch.is_grad_enabled() and self.gradient_checkpointing: - for _, block in enumerate(self.transformer_blocks): - hidden_states, encoder_hidden_states = block_forward(block)( - hidden_states, encoder_hidden_states, temb, image_rotary_emb - ) + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) - for block in self.single_transformer_blocks: - hidden_states, encoder_hidden_states = block_forward(block)( - hidden_states, encoder_hidden_states, temb, image_rotary_emb - ) + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + + for block in self.transformer_blocks: + hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + temb, + image_rotary_emb, + **ckpt_kwargs, + ) + + for block in self.single_transformer_blocks: + hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + temb, + image_rotary_emb, + **ckpt_kwargs, + ) + + else: + for block in self.transformer_blocks: + hidden_states, encoder_hidden_states = block( + hidden_states, encoder_hidden_states, temb, image_rotary_emb + ) + + for block in self.single_transformer_blocks: + hidden_states, encoder_hidden_states = block( + hidden_states, encoder_hidden_states, temb, image_rotary_emb + ) hidden_states = self.norm_out(hidden_states, temb) hidden_states = self.proj_out(hidden_states) From e7c382e7b40eac4233da5052c7425008db4964ea Mon Sep 17 00:00:00 2001 From: Aryan Date: Sat, 14 Dec 2024 00:25:37 +0100 Subject: [PATCH 44/58] update --- scripts/convert_hunyuan_video_to_diffusers.py | 14 +- .../hunyuan_video/pipeline_hunyuan_video.py | 140 ++++++++++-------- .../hunyuan_video/pipeline_output.py | 20 +++ 3 files changed, 107 insertions(+), 67 deletions(-) create mode 100644 src/diffusers/pipelines/hunyuan_video/pipeline_output.py diff --git a/scripts/convert_hunyuan_video_to_diffusers.py b/scripts/convert_hunyuan_video_to_diffusers.py index 03e302a75574..8f960b90183d 100644 --- a/scripts/convert_hunyuan_video_to_diffusers.py +++ b/scripts/convert_hunyuan_video_to_diffusers.py @@ -5,7 +5,12 @@ from accelerate import init_empty_weights from transformers import AutoModel, AutoTokenizer, CLIPTextModel, CLIPTokenizer -from diffusers import AutoencoderKLHunyuanVideo, HunyuanVideoPipeline, HunyuanVideoTransformer3DModel +from diffusers import ( + AutoencoderKLHunyuanVideo, + FlowMatchEulerDiscreteScheduler, + HunyuanVideoPipeline, + HunyuanVideoTransformer3DModel, +) def remap_norm_scale_shift_(key, state_dict): @@ -193,6 +198,7 @@ def get_args(): ) parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original VAE checkpoint") parser.add_argument("--text_encoder_path", type=str, default=None, help="Path to original llama checkpoint") + parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to original llama tokenizer") parser.add_argument("--text_encoder_2_path", type=str, default=None, help="Path to original clip checkpoint") parser.add_argument("--save_pipeline", action="store_true") parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved") @@ -216,6 +222,7 @@ def get_args(): if args.save_pipeline: assert args.transformer_ckpt_path is not None and args.vae_ckpt_path is not None assert args.text_encoder_path is not None + assert args.tokenizer_path is not None assert args.text_encoder_2_path is not None if args.transformer_ckpt_path is not None: @@ -231,9 +238,11 @@ def get_args(): if args.save_pipeline: text_encoder = AutoModel.from_pretrained(args.text_encoder_path, torch_dtype=torch.float16) - tokenizer = AutoTokenizer.from_pretrained(args.text_encoder_path, padding_side="right") + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path, padding_side="right") text_encoder_2 = CLIPTextModel.from_pretrained(args.text_encoder_2_path, torch_dtype=torch.float16) tokenizer_2 = CLIPTokenizer.from_pretrained(args.text_encoder_2_path) + scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0) + print(tokenizer_2) pipe = HunyuanVideoPipeline( transformer=transformer, @@ -242,5 +251,6 @@ def get_args(): tokenizer=tokenizer, text_encoder_2=text_encoder_2, tokenizer_2=tokenizer_2, + scheduler=scheduler, ) pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py index 1a8ee40770a1..337caee2fb40 100644 --- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py @@ -13,24 +13,19 @@ # limitations under the License. import inspect -from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional, Tuple, Union -import numpy as np import torch -from transformers import CLIPTextModel, CLIPTokenizer, LlamaForCausalLM, LlamaTokenizerFast +from transformers import CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizerFast from ...callbacks import MultiPipelineCallbacks, PipelineCallback -from ...image_processor import VaeImageProcessor from ...models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel from ...schedulers import KarrasDiffusionSchedulers -from ...utils import ( - BaseOutput, - logging, - replace_example_docstring, -) +from ...utils import logging, replace_example_docstring from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import HunyuanVideoPipelineOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -111,18 +106,6 @@ def retrieve_timesteps( return timesteps, num_inference_steps -PRECISION_TO_TYPE = { - "fp32": torch.float32, - "fp16": torch.float16, - "bf16": torch.bfloat16, -} - - -@dataclass -class HunyuanVideoPipelineOutput(BaseOutput): - videos: Union[torch.Tensor, np.ndarray] - - class HunyuanVideoPipeline(DiffusionPipeline): r""" Pipeline for text-to-video generation using HunyuanVideo. @@ -131,29 +114,36 @@ class HunyuanVideoPipeline(DiffusionPipeline): implemented for all pipelines (downloading, saving, running on a particular device, etc.). Args: - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. - text_encoder ([`TextEncoder`]): - Frozen text-encoder. - text_encoder_2 ([`TextEncoder`]): - Frozen text-encoder_2. - transformer ([`HYVideoDiffusionTransformer`]): - A `HYVideoDiffusionTransformer` to denoise the encoded video latents. - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image latents. + text_encoder ([`LlamaModel`]): + [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers). + tokenizer_2 (`LlamaTokenizer`): + Tokenizer from [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers). + transformer ([`HunyuanVideoTransformer3DModel`]): + Conditional Transformer to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLHunyuanVideo`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + text_encoder_2 ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + text_encoder_2 ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer_2 (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). """ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" - _optional_components = ["text_encoder_2"] - _exclude_from_cpu_offload = ["transformer"] - _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + _callback_tensor_inputs = ["latents", "prompt_embeds"] def __init__( self, - vae: AutoencoderKLHunyuanVideo, - text_encoder: LlamaForCausalLM, + text_encoder: LlamaModel, tokenizer: LlamaTokenizerFast, transformer: HunyuanVideoTransformer3DModel, + vae: AutoencoderKLHunyuanVideo, scheduler: KarrasDiffusionSchedulers, text_encoder_2: Optional[CLIPTextModel] = None, tokenizer_2: Optional[CLIPTokenizer] = None, @@ -176,7 +166,7 @@ def __init__( self.vae_scale_factor_spatial = ( self.vae.spatial_compression_ratio if hasattr(self, "vae") and self.vae is not None else 8 ) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) def _get_llama_prompt_embeds( self, @@ -397,14 +387,39 @@ def prepare_latents( latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) return latents + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + @property def guidance_scale(self): return self._guidance_scale - @property - def clip_skip(self): - return self._clip_skip - @property def num_timesteps(self): return self._num_timesteps @@ -549,13 +564,11 @@ def __call__( max_sequence_length, ) - target_dtype = torch.bfloat16 # Note(aryan): This has been hardcoded for now from the original repo - vae_dtype = torch.float16 # Note(aryan): This has been hardcoded for now from the original repo - - prompt_embeds = prompt_embeds.to(target_dtype) - prompt_attention_mask = prompt_attention_mask.to(target_dtype) + transformer_dtype = self.transformer.dtype + prompt_embeds = prompt_embeds.to(transformer_dtype) + prompt_attention_mask = prompt_attention_mask.to(transformer_dtype) if prompt_embeds_2 is not None: - prompt_embeds_2 = prompt_embeds_2.to(target_dtype) + prompt_embeds_2 = prompt_embeds_2.to(transformer_dtype) # 4. Prepare timesteps timesteps, num_inference_steps = retrieve_timesteps( @@ -574,12 +587,14 @@ def __call__( height, width, num_latent_frames, - target_dtype, + torch.float32, device, generator, latents, ) + guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0 + # 7. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(timesteps) @@ -589,23 +604,17 @@ def __call__( if self.interrupt: continue + latent_model_input = latents.to(transformer_dtype) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]).to(latents.dtype) - guidance_expand = ( - torch.tensor([guidance_scale] * latents.shape[0], dtype=torch.float32, device=device).to( - target_dtype - ) - * 1000.0 - ) - noise_pred = self.transformer( - hidden_states=latents, + hidden_states=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds, encoder_attention_mask=prompt_attention_mask, encoder_hidden_states_2=prompt_embeds_2, - guidance=guidance_expand, + guidance=guidance, return_dict=False, )[0] @@ -627,24 +636,25 @@ def __call__( if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() - latents = latents.to(vae_dtype) + latents = latents.to(self.vae.dtype) if not output_type == "latent": latents = latents / self.vae.config.scaling_factor - image = self.vae.decode(latents, return_dict=False)[0] + video = self.vae.decode(latents, return_dict=False)[0] - torch.save(image, "diffusers_latents_decoded.pt") + torch.save(video, "diffusers_latents_decoded.pt") + video = self.video_processor.postprocess_video(video, output_type=output_type) else: - image = latents + video = latents - image = (image / 2 + 0.5).clamp(0, 1) + video = (video / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 - image = image.cpu().float() + video = video.cpu().float() # Offload all models self.maybe_free_model_hooks() if not return_dict: - return image + return video - return HunyuanVideoPipelineOutput(videos=image) + return HunyuanVideoPipelineOutput(videos=video) diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_output.py b/src/diffusers/pipelines/hunyuan_video/pipeline_output.py new file mode 100644 index 000000000000..c5cb853e3932 --- /dev/null +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_output.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass + +import torch + +from diffusers.utils import BaseOutput + + +@dataclass +class HunyuanVideoPipelineOutput(BaseOutput): + r""" + Output class for HunyuanVideo pipelines. + + Args: + frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + frames: torch.Tensor From dbba9c7485093362dc254755b26e13113a5f7ed7 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sat, 14 Dec 2024 01:22:54 +0100 Subject: [PATCH 45/58] add tests for transformer --- .../models/hunyuan_video_transformer_3d.md | 2 +- src/diffusers/models/attention_processor.py | 6 +- .../transformers/transformer_hunyuan_video.py | 12 +-- .../hunyuan_video/pipeline_hunyuan_video.py | 30 +++---- .../test_models_transformer_hunyuan_video.py | 89 +++++++++++++++++++ 5 files changed, 114 insertions(+), 25 deletions(-) create mode 100644 tests/models/transformers/test_models_transformer_hunyuan_video.py diff --git a/docs/source/en/api/models/hunyuan_video_transformer_3d.md b/docs/source/en/api/models/hunyuan_video_transformer_3d.md index 18fcb80e55ca..0c1540d8cbe5 100644 --- a/docs/source/en/api/models/hunyuan_video_transformer_3d.md +++ b/docs/source/en/api/models/hunyuan_video_transformer_3d.md @@ -21,7 +21,7 @@ TODO ## HunyuanVideoTransformer3DModel -[[autodoc]] MochiTransformer3DModel +[[autodoc]] HunyuanVideoTransformer3DModel ## Transformer2DModelOutput diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 8d06498c1b59..d205bef384c6 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -790,7 +790,11 @@ def fuse_projections(self, fuse=True): self.to_kv.bias.copy_(concatenated_bias) # handle added projections for SD3 and others. - if hasattr(self, "add_q_proj") and hasattr(self, "add_k_proj") and hasattr(self, "add_v_proj"): + if ( + getattr(self, "add_q_proj", None) is not None + and getattr(self, "add_k_proj", None) is not None + and getattr(self, "add_v_proj", None) is not None + ): concatenated_weights = torch.cat( [self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data] ) diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py index 22939e205fb0..2f5e6781717f 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py @@ -496,12 +496,12 @@ def __init__( mlp_ratio: float = 4.0, patch_size: int = 2, patch_size_t: int = 1, - rope_dim_list: List[int] = [16, 56, 56], qk_norm: str = "rms_norm", guidance_embeds: bool = True, text_embed_dim: int = 4096, - text_embed_dim_2: int = 768, + pooled_projection_dim: int = 768, rope_theta: float = 256.0, + rope_axes_dim: Tuple[int] = (16, 56, 56), ) -> None: super().__init__() @@ -516,10 +516,10 @@ def __init__( text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers ) - self.time_text_embed = CombinedTimestepGuidanceTextProjEmbeddings(inner_dim, text_embed_dim_2) + self.time_text_embed = CombinedTimestepGuidanceTextProjEmbeddings(inner_dim, pooled_projection_dim) # 3. RoPE - self.rope = HunyuanVideoRotaryPosEmbed(patch_size, patch_size_t, rope_dim_list, rope_theta) + self.rope = HunyuanVideoRotaryPosEmbed(patch_size, patch_size_t, rope_axes_dim, rope_theta) self.transformer_blocks = nn.ModuleList( [ @@ -614,7 +614,7 @@ def forward( timestep: torch.LongTensor, encoder_hidden_states: torch.Tensor, encoder_attention_mask: torch.Tensor, - encoder_hidden_states_2: torch.Tensor, + pooled_projections: torch.Tensor, guidance: torch.Tensor = None, return_dict: bool = True, ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: @@ -627,7 +627,7 @@ def forward( image_rotary_emb = self.rope(hidden_states) - temb = self.time_text_embed(timestep, guidance, encoder_hidden_states_2) + temb = self.time_text_embed(timestep, guidance, pooled_projections) # Embed image and text. hidden_states = self.img_in(hidden_states) diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py index 337caee2fb40..4d22a5eb70c5 100644 --- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py @@ -280,7 +280,7 @@ def encode_prompt( prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE, num_videos_per_prompt: int = 1, prompt_embeds: Optional[torch.Tensor] = None, - prompt_embeds_2: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, prompt_attention_mask: Optional[torch.Tensor] = None, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, @@ -296,10 +296,10 @@ def encode_prompt( max_sequence_length=max_sequence_length, ) - if prompt_embeds_2 is None and self.text_encoder_2 is not None: - if prompt_2 is None and prompt_embeds_2 is None: + if pooled_prompt_embeds is None and self.text_encoder_2 is not None: + if prompt_2 is None and pooled_prompt_embeds is None: prompt_2 = prompt - prompt_embeds_2 = self._get_clip_prompt_embeds( + pooled_prompt_embeds = self._get_clip_prompt_embeds( prompt, num_videos_per_prompt, device=device, @@ -307,7 +307,7 @@ def encode_prompt( max_sequence_length=77, ) - return prompt_embeds, prompt_embeds_2, prompt_attention_mask + return prompt_embeds, pooled_prompt_embeds, prompt_attention_mask def check_inputs( self, @@ -444,7 +444,7 @@ def __call__( generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, prompt_embeds: Optional[torch.Tensor] = None, - prompt_embeds_2: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, prompt_attention_mask: Optional[torch.Tensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, @@ -552,13 +552,13 @@ def __call__( batch_size = prompt_embeds.shape[0] # 3. Encode input prompt - prompt_embeds, prompt_embeds_2, prompt_attention_mask = self.encode_prompt( + prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = self.encode_prompt( prompt, prompt_2, prompt_template, num_videos_per_prompt, prompt_embeds, - prompt_embeds_2, + pooled_prompt_embeds, prompt_attention_mask, device, max_sequence_length, @@ -567,8 +567,8 @@ def __call__( transformer_dtype = self.transformer.dtype prompt_embeds = prompt_embeds.to(transformer_dtype) prompt_attention_mask = prompt_attention_mask.to(transformer_dtype) - if prompt_embeds_2 is not None: - prompt_embeds_2 = prompt_embeds_2.to(transformer_dtype) + if pooled_prompt_embeds is not None: + pooled_prompt_embeds = pooled_prompt_embeds.to(transformer_dtype) # 4. Prepare timesteps timesteps, num_inference_steps = retrieve_timesteps( @@ -613,7 +613,7 @@ def __call__( timestep=timestep, encoder_hidden_states=prompt_embeds, encoder_attention_mask=prompt_attention_mask, - encoder_hidden_states_2=prompt_embeds_2, + pooled_projections=pooled_prompt_embeds, guidance=guidance, return_dict=False, )[0] @@ -647,14 +647,10 @@ def __call__( else: video = latents - video = (video / 2 + 0.5).clamp(0, 1) - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 - video = video.cpu().float() - # Offload all models self.maybe_free_model_hooks() if not return_dict: - return video + return (video,) - return HunyuanVideoPipelineOutput(videos=video) + return HunyuanVideoPipelineOutput(frames=video) diff --git a/tests/models/transformers/test_models_transformer_hunyuan_video.py b/tests/models/transformers/test_models_transformer_hunyuan_video.py new file mode 100644 index 000000000000..e8ea8cecbb9e --- /dev/null +++ b/tests/models/transformers/test_models_transformer_hunyuan_video.py @@ -0,0 +1,89 @@ +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from diffusers import HunyuanVideoTransformer3DModel +from diffusers.utils.testing_utils import enable_full_determinism, torch_device + +from ..test_modeling_common import ModelTesterMixin + + +enable_full_determinism() + + +class HunyuanVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase): + model_class = HunyuanVideoTransformer3DModel + main_input_name = "hidden_states" + uses_custom_attn_processor = True + + @property + def dummy_input(self): + batch_size = 1 + num_channels = 4 + num_frames = 1 + height = 16 + width = 16 + text_encoder_embedding_dim = 16 + pooled_projection_dim = 8 + sequence_length = 12 + + hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device) + timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device) + encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device) + pooled_projections = torch.randn((batch_size, pooled_projection_dim)).to(torch_device) + encoder_attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device) + guidance = torch.randint(0, 1000, size=(batch_size,)).to(torch_device, dtype=torch.float32) + + return { + "hidden_states": hidden_states, + "timestep": timestep, + "encoder_hidden_states": encoder_hidden_states, + "pooled_projections": pooled_projections, + "encoder_attention_mask": encoder_attention_mask, + "guidance": guidance, + } + + @property + def input_shape(self): + return (4, 1, 16, 16) + + @property + def output_shape(self): + return (4, 1, 16, 16) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "in_channels": 4, + "out_channels": 4, + "num_attention_heads": 2, + "attention_head_dim": 10, + "num_layers": 1, + "num_single_layers": 1, + "num_refiner_layers": 1, + "patch_size": 1, + "patch_size_t": 1, + "guidance_embeds": True, + "text_embed_dim": 16, + "pooled_projection_dim": 8, + "rope_axes_dim": (2, 4, 4), + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_gradient_checkpointing_is_applied(self): + expected_set = {"HunyuanVideoTransformer3DModel"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) From 36dea1099bbd67f98ee611ac7f862a671e88257d Mon Sep 17 00:00:00 2001 From: Aryan Date: Sat, 14 Dec 2024 01:50:26 +0100 Subject: [PATCH 46/58] add pipeline tests; text encoder 2 is not optional --- .../hunyuan_video/pipeline_hunyuan_video.py | 32 +- tests/pipelines/hunyuan_video/__init__.py | 0 .../hunyuan_video/test_hunyuan_video.py | 331 ++++++++++++++++++ 3 files changed, 347 insertions(+), 16 deletions(-) create mode 100644 tests/pipelines/hunyuan_video/__init__.py create mode 100644 tests/pipelines/hunyuan_video/test_hunyuan_video.py diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py index 4d22a5eb70c5..a9eab0ddbb8a 100644 --- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py @@ -145,8 +145,8 @@ def __init__( transformer: HunyuanVideoTransformer3DModel, vae: AutoencoderKLHunyuanVideo, scheduler: KarrasDiffusionSchedulers, - text_encoder_2: Optional[CLIPTextModel] = None, - tokenizer_2: Optional[CLIPTokenizer] = None, + text_encoder_2: CLIPTextModel, + tokenizer_2: CLIPTokenizer, ): super().__init__() @@ -179,7 +179,7 @@ def _get_llama_prompt_embeds( num_hidden_layers_to_skip: int = 2, ) -> Tuple[torch.Tensor, torch.Tensor]: device = device or self._execution_device - dtype = dtype or self.text_encoder_2.dtype + dtype = dtype or self.text_encoder.dtype prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) @@ -211,15 +211,15 @@ def _get_llama_prompt_embeds( return_overflowing_tokens=False, return_attention_mask=True, ) - text_input_ids = text_inputs.input_ids.to(device) - prompt_attention_mask = text_inputs.attention_mask.to(device) + text_input_ids = text_inputs.input_ids.to(device=device) + prompt_attention_mask = text_inputs.attention_mask.to(device=device) prompt_embeds = self.text_encoder( input_ids=text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True, ).hidden_states[-(num_hidden_layers_to_skip + 1)] - prompt_embeds = prompt_embeds.to(dtype) + prompt_embeds = prompt_embeds.to(dtype=dtype) if crop_start is not None and crop_start > 0: prompt_embeds = prompt_embeds[:, crop_start:] @@ -296,7 +296,7 @@ def encode_prompt( max_sequence_length=max_sequence_length, ) - if pooled_prompt_embeds is None and self.text_encoder_2 is not None: + if pooled_prompt_embeds is None: if prompt_2 is None and pooled_prompt_embeds is None: prompt_2 = prompt pooled_prompt_embeds = self._get_clip_prompt_embeds( @@ -553,15 +553,15 @@ def __call__( # 3. Encode input prompt prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = self.encode_prompt( - prompt, - prompt_2, - prompt_template, - num_videos_per_prompt, - prompt_embeds, - pooled_prompt_embeds, - prompt_attention_mask, - device, - max_sequence_length, + prompt=prompt, + prompt_2=prompt_2, + prompt_template=prompt_template, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + device=device, + max_sequence_length=max_sequence_length, ) transformer_dtype = self.transformer.dtype diff --git a/tests/pipelines/hunyuan_video/__init__.py b/tests/pipelines/hunyuan_video/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/hunyuan_video/test_hunyuan_video.py b/tests/pipelines/hunyuan_video/test_hunyuan_video.py new file mode 100644 index 000000000000..567002268106 --- /dev/null +++ b/tests/pipelines/hunyuan_video/test_hunyuan_video.py @@ -0,0 +1,331 @@ +# Copyright 2024 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import unittest + +import numpy as np +import torch +from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer, LlamaConfig, LlamaModel, LlamaTokenizer + +from diffusers import ( + AutoencoderKLHunyuanVideo, + FlowMatchEulerDiscreteScheduler, + HunyuanVideoPipeline, + HunyuanVideoTransformer3DModel, +) +from diffusers.utils.testing_utils import ( + enable_full_determinism, + torch_device, +) + +from ..test_pipelines_common import PipelineTesterMixin, to_np + + +enable_full_determinism() + + +class HunyuanVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = HunyuanVideoPipeline + params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]) + batch_params = frozenset(["prompt"]) + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + + # there is no xformers processor for Flux + test_xformers_attention = False + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = HunyuanVideoTransformer3DModel( + in_channels=4, + out_channels=4, + num_attention_heads=2, + attention_head_dim=10, + num_layers=1, + num_single_layers=1, + num_refiner_layers=1, + patch_size=1, + patch_size_t=1, + guidance_embeds=True, + text_embed_dim=16, + pooled_projection_dim=8, + rope_axes_dim=(2, 4, 4), + ) + + torch.manual_seed(0) + vae = AutoencoderKLHunyuanVideo( + in_channels=3, + out_channels=3, + latent_channels=4, + down_block_types=( + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + ), + up_block_types=( + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", + ), + block_out_channels=(8, 8, 8, 8), + layers_per_block=1, + act_fn="silu", + norm_num_groups=4, + scaling_factor=0.476986, + spatial_compression_ratio=8, + temporal_compression_ratio=4, + mid_block_add_attention=True, + ) + + torch.manual_seed(0) + scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0) + + llama_text_encoder_config = LlamaConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=16, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=2, + pad_token_id=1, + vocab_size=1000, + hidden_act="gelu", + projection_dim=32, + ) + clip_text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=8, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=2, + pad_token_id=1, + vocab_size=1000, + hidden_act="gelu", + projection_dim=32, + ) + + torch.manual_seed(0) + text_encoder = LlamaModel(llama_text_encoder_config) + tokenizer = LlamaTokenizer.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM") + + torch.manual_seed(0) + text_encoder_2 = CLIPTextModel(clip_text_encoder_config) + tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + components = { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "text_encoder_2": text_encoder_2, + "tokenizer": tokenizer, + "tokenizer_2": tokenizer_2, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + # Cannot test with dummy prompt because tokenizers are not configured correctly. + # TODO(aryan): create dummy tokenizers and using from hub + inputs = { + "prompt": "", + "prompt_template": { + "template": "{}", + "crop_start": 0, + }, + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 4.5, + "height": 16, + "width": 16, + # 4 * k + 1 is the recommendation + "num_frames": 9, + "max_sequence_length": 16, + "output_type": "pt", + } + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + video = pipe(**inputs).frames + generated_video = video[0] + + self.assertEqual(generated_video.shape, (9, 3, 16, 16)) + expected_video = torch.randn(9, 3, 16, 16) + max_diff = np.abs(generated_video - expected_video).max() + self.assertLessEqual(max_diff, 1e10) + + def test_callback_inputs(self): + sig = inspect.signature(self.pipeline_class.__call__) + has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters + has_callback_step_end = "callback_on_step_end" in sig.parameters + + if not (has_callback_tensor_inputs and has_callback_step_end): + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + self.assertTrue( + hasattr(pipe, "_callback_tensor_inputs"), + f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs", + ) + + def callback_inputs_subset(pipe, i, t, callback_kwargs): + # iterate over callback args + for tensor_name, tensor_value in callback_kwargs.items(): + # check that we're only passing in allowed tensor inputs + assert tensor_name in pipe._callback_tensor_inputs + + return callback_kwargs + + def callback_inputs_all(pipe, i, t, callback_kwargs): + for tensor_name in pipe._callback_tensor_inputs: + assert tensor_name in callback_kwargs + + # iterate over callback args + for tensor_name, tensor_value in callback_kwargs.items(): + # check that we're only passing in allowed tensor inputs + assert tensor_name in pipe._callback_tensor_inputs + + return callback_kwargs + + inputs = self.get_dummy_inputs(torch_device) + + # Test passing in a subset + inputs["callback_on_step_end"] = callback_inputs_subset + inputs["callback_on_step_end_tensor_inputs"] = ["latents"] + output = pipe(**inputs)[0] + + # Test passing in a everything + inputs["callback_on_step_end"] = callback_inputs_all + inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs + output = pipe(**inputs)[0] + + def callback_inputs_change_tensor(pipe, i, t, callback_kwargs): + is_last = i == (pipe.num_timesteps - 1) + if is_last: + callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"]) + return callback_kwargs + + inputs["callback_on_step_end"] = callback_inputs_change_tensor + inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs + output = pipe(**inputs)[0] + assert output.abs().sum() < 1e10 + + def test_attention_slicing_forward_pass( + self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3 + ): + if not self.test_attention_slicing: + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + generator_device = "cpu" + inputs = self.get_dummy_inputs(generator_device) + output_without_slicing = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=1) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing1 = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=2) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing2 = pipe(**inputs)[0] + + if test_max_difference: + max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max() + max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max() + self.assertLess( + max(max_diff1, max_diff2), + expected_max_diff, + "Attention slicing should not affect the inference results", + ) + + def test_vae_tiling(self, expected_diff_max: float = 0.2): + # Seems to require higher tolerance than the other tests + expected_diff_max = 0.6 + generator_device = "cpu" + components = self.get_dummy_components() + + pipe = self.pipeline_class(**components) + pipe.to("cpu") + pipe.set_progress_bar_config(disable=None) + + # Without tiling + inputs = self.get_dummy_inputs(generator_device) + inputs["height"] = inputs["width"] = 128 + output_without_tiling = pipe(**inputs)[0] + + # With tiling + pipe.vae.enable_tiling( + tile_sample_min_height=96, + tile_sample_min_width=96, + tile_sample_stride_height=64, + tile_sample_stride_width=64, + ) + inputs = self.get_dummy_inputs(generator_device) + inputs["height"] = inputs["width"] = 128 + output_with_tiling = pipe(**inputs)[0] + + self.assertLess( + (to_np(output_without_tiling) - to_np(output_with_tiling)).max(), + expected_diff_max, + "VAE tiling should not affect the inference results", + ) + + # TODO(aryan): Create a dummy gemma model with smol vocab size + @unittest.skip( + "A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. This test uses a long prompt that causes the error." + ) + def test_inference_batch_consistent(self): + pass + + @unittest.skip( + "A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. This test uses a long prompt that causes the error." + ) + def test_inference_batch_single_identical(self): + pass From 1c7b31751f9b7dd06eb66856d4ca035886b1a9d8 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sat, 14 Dec 2024 16:27:43 +0100 Subject: [PATCH 47/58] fix attention implementation for torch --- .../autoencoder_kl_hunyuan_video.py | 6 ++--- .../transformers/transformer_hunyuan_video.py | 26 ++++++++++++++++--- 2 files changed, 25 insertions(+), 7 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py index 3914d9740790..496677671c5d 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py @@ -789,6 +789,7 @@ def __init__( # The minimal tile temporal batch size for temporal tiling to be used self.tile_sample_min_tsize = 64 + self.tile_overlap_factor = 0.25 # The minimal distance between two spatial tiles self.tile_sample_stride_height = 192 @@ -1108,13 +1109,12 @@ def temporal_tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Un overlap_size = int(tile_latent_min_tsize * (1 - self.tile_overlap_factor)) blend_extent = int(self.tile_sample_min_tsize * self.tile_overlap_factor) t_limit = self.tile_sample_min_tsize - blend_extent + tile_latent_min_size = self.tile_sample_min_size // self.spatial_compression_ratio row = [] for i in range(0, T, overlap_size): tile = z[:, :, i : i + tile_latent_min_tsize + 1, :, :] - if self.use_tiling and ( - tile.shape[-1] > self.tile_latent_min_size or tile.shape[-2] > self.tile_latent_min_size - ): + if self.use_tiling and (tile.shape[-1] > tile_latent_min_size or tile.shape[-2] > tile_latent_min_size): decoded = self.tiled_decode(tile, return_dict=True).sample else: tile = self.post_quant_conv(tile) diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py index 2f5e6781717f..7036ebdf5212 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py @@ -379,6 +379,7 @@ def forward( hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> torch.Tensor: text_seq_length = encoder_hidden_states.shape[1] @@ -397,6 +398,7 @@ def forward( attn_output, context_attn_output = self.attn( hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, + attention_mask=attention_mask, image_rotary_emb=image_rotary_emb, ) attn_output = torch.cat([attn_output, context_attn_output], dim=1) @@ -452,6 +454,7 @@ def forward( hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) @@ -462,6 +465,7 @@ def forward( img_attn, txt_attn = self.attn( hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, + attention_mask=attention_mask, image_rotary_emb=freqs_cis, ) @@ -619,8 +623,7 @@ def forward( return_dict: bool = True, ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: batch_size, num_channels, num_frames, height, width = hidden_states.shape - p = self.config.patch_size - p_t = self.config.patch_size_t + p, p_t = self.config.patch_size, self.config.patch_size_t post_patch_num_frames = num_frames // p_t post_patch_height = height // p post_patch_width = width // p @@ -633,6 +636,19 @@ def forward( hidden_states = self.img_in(hidden_states) encoder_hidden_states = self.txt_in(encoder_hidden_states, timestep, encoder_attention_mask) + latent_sequence_length = hidden_states.shape[1] + condition_sequence_length = encoder_hidden_states.shape[1] + sequence_length = latent_sequence_length + condition_sequence_length + attention_mask = torch.zeros( + batch_size, sequence_length, sequence_length, device=hidden_states.device, dtype=torch.bool + ) # [B, N, N] + + effective_condition_sequence_length = encoder_attention_mask.sum(dim=1, dtype=torch.int) # [B,] + effective_sequence_length = latent_sequence_length + effective_condition_sequence_length + + for i in range(batch_size): + attention_mask[i, : effective_sequence_length[i], : effective_sequence_length[i]] = True + if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): @@ -652,6 +668,7 @@ def custom_forward(*inputs): hidden_states, encoder_hidden_states, temb, + attention_mask, image_rotary_emb, **ckpt_kwargs, ) @@ -662,6 +679,7 @@ def custom_forward(*inputs): hidden_states, encoder_hidden_states, temb, + attention_mask, image_rotary_emb, **ckpt_kwargs, ) @@ -669,12 +687,12 @@ def custom_forward(*inputs): else: for block in self.transformer_blocks: hidden_states, encoder_hidden_states = block( - hidden_states, encoder_hidden_states, temb, image_rotary_emb + hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb ) for block in self.single_transformer_blocks: hidden_states, encoder_hidden_states = block( - hidden_states, encoder_hidden_states, temb, image_rotary_emb + hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb ) hidden_states = self.norm_out(hidden_states, temb) From ca982278b04ca66a4751af42d19cf97b7fa28110 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sat, 14 Dec 2024 21:37:32 +0100 Subject: [PATCH 48/58] add example --- scripts/convert_hunyuan_video_to_diffusers.py | 14 ++-- .../transformers/transformer_hunyuan_video.py | 77 +++++++++++-------- .../hunyuan_video/pipeline_hunyuan_video.py | 34 ++++++-- 3 files changed, 78 insertions(+), 47 deletions(-) diff --git a/scripts/convert_hunyuan_video_to_diffusers.py b/scripts/convert_hunyuan_video_to_diffusers.py index 8f960b90183d..13ced51240d3 100644 --- a/scripts/convert_hunyuan_video_to_diffusers.py +++ b/scripts/convert_hunyuan_video_to_diffusers.py @@ -20,10 +20,14 @@ def remap_norm_scale_shift_(key, state_dict): state_dict[key.replace("final_layer.adaLN_modulation.1", "norm_out.linear")] = new_weight -def remap_token_refiner_blocks_(key, state_dict): +def remap_txt_in_(key, state_dict): def rename_key(key): new_key = key.replace("individual_token_refiner.blocks", "token_refiner.refiner_blocks") new_key = new_key.replace("adaLN_modulation.1", "norm_out.linear") + new_key = new_key.replace("txt_in", "context_embedder") + new_key = new_key.replace("t_embedder.mlp.0", "time_text_embed.timestep_embedder.linear_1") + new_key = new_key.replace("t_embedder.mlp.2", "time_text_embed.timestep_embedder.linear_2") + new_key = new_key.replace("c_embedder", "time_text_embed.text_embedder") return new_key if "self_attn_qkv" in key: @@ -32,7 +36,6 @@ def rename_key(key): state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_q"))] = to_q state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_k"))] = to_k state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_v"))] = to_v - else: state_dict[rename_key(key)] = state_dict.pop(key) @@ -85,15 +88,13 @@ def remap_single_transformer_blocks_(key, state_dict): TRANSFORMER_KEYS_RENAME_DICT = { + "img_in": "x_embedder", "time_in.mlp.0": "time_text_embed.timestep_embedder.linear_1", "time_in.mlp.2": "time_text_embed.timestep_embedder.linear_2", "guidance_in.mlp.0": "time_text_embed.guidance_embedder.linear_1", "guidance_in.mlp.2": "time_text_embed.guidance_embedder.linear_2", "vector_in.in_layer": "time_text_embed.text_embedder.linear_1", "vector_in.out_layer": "time_text_embed.text_embedder.linear_2", - "txt_in.t_embedder.mlp.0": "txt_in.time_text_embed.timestep_embedder.linear_1", - "txt_in.t_embedder.mlp.2": "txt_in.time_text_embed.timestep_embedder.linear_2", - "txt_in.c_embedder": "txt_in.time_text_embed.text_embedder", "double_blocks": "transformer_blocks", "img_attn_q_norm": "attn.norm_q", "img_attn_k_norm": "attn.norm_k", @@ -120,11 +121,11 @@ def remap_single_transformer_blocks_(key, state_dict): } TRANSFORMER_SPECIAL_KEYS_REMAP = { + "txt_in": remap_txt_in_, "img_attn_qkv": remap_img_attn_qkv_, "txt_attn_qkv": remap_txt_attn_qkv_, "single_blocks": remap_single_transformer_blocks_, "final_layer.adaLN_modulation.1": remap_norm_scale_shift_, - "individual_token_refiner.blocks": remap_token_refiner_blocks_, } VAE_KEYS_RENAME_DICT = {} @@ -242,7 +243,6 @@ def get_args(): text_encoder_2 = CLIPTextModel.from_pretrained(args.text_encoder_2_path, torch_dtype=torch.float16) tokenizer_2 = CLIPTokenizer.from_pretrained(args.text_encoder_2_path) scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0) - print(tokenizer_2) pipe = HunyuanVideoPipeline( transformer=transformer, diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py index 7036ebdf5212..51b51f739075 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py @@ -50,6 +50,7 @@ def __call__( if attn.add_q_proj is None and encoder_hidden_states is not None: hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1) + # 1. QKV projections query = attn.to_q(hidden_states) key = attn.to_k(hidden_states) value = attn.to_v(hidden_states) @@ -58,11 +59,13 @@ def __call__( key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2) value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2) + # 2. QK normalization if attn.norm_q is not None: query = attn.norm_q(query) if attn.norm_k is not None: key = attn.norm_k(key) + # 3. Rotational positional embeddings applied to latent stream if image_rotary_emb is not None: from ..embeddings import apply_rotary_emb @@ -85,6 +88,7 @@ def __call__( query = apply_rotary_emb(query, image_rotary_emb) key = apply_rotary_emb(key, image_rotary_emb) + # 4. Encoder condition QKV projection and normalization if attn.add_q_proj is not None and encoder_hidden_states is not None: encoder_query = attn.add_q_proj(encoder_hidden_states) encoder_key = attn.add_k_proj(encoder_hidden_states) @@ -103,12 +107,14 @@ def __call__( key = torch.cat([key, encoder_key], dim=2) value = torch.cat([value, encoder_value], dim=2) + # 5. Attention hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) hidden_states = hidden_states.to(query.dtype) + # 6. Output projection if encoder_hidden_states is not None: hidden_states, encoder_hidden_states = ( hidden_states[:, : -encoder_hidden_states.shape[1]], @@ -160,7 +166,7 @@ def forward( return gate_msa, gate_mlp -class IndividualTokenRefinerBlock(nn.Module): +class HunyuanVideoIndividualTokenRefinerBlock(nn.Module): def __init__( self, num_attention_heads: int, @@ -210,7 +216,7 @@ def forward( return hidden_states -class IndividualTokenRefiner(nn.Module): +class HunyuanVideoIndividualTokenRefiner(nn.Module): def __init__( self, num_attention_heads: int, @@ -224,7 +230,7 @@ def __init__( self.refiner_blocks = nn.ModuleList( [ - IndividualTokenRefinerBlock( + HunyuanVideoIndividualTokenRefinerBlock( num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim, mlp_width_ratio=mlp_width_ratio, @@ -257,7 +263,7 @@ def forward( return hidden_states -class SingleTokenRefiner(nn.Module): +class HunyuanVideoTokenRefiner(nn.Module): def __init__( self, in_channels: int, @@ -276,7 +282,7 @@ def __init__( embedding_dim=hidden_size, pooled_projection_dim=in_channels ) self.proj_in = nn.Linear(in_channels, hidden_size, bias=True) - self.token_refiner = IndividualTokenRefiner( + self.token_refiner = HunyuanVideoIndividualTokenRefiner( num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim, num_layers=num_layers, @@ -346,15 +352,11 @@ def __init__( attention_head_dim: int, mlp_ratio: float = 4.0, qk_norm: str = "rms_norm", - ): + ) -> None: super().__init__() hidden_size = num_attention_heads * attention_head_dim - mlp_hidden_dim = int(hidden_size * mlp_ratio) - - self.hidden_size = hidden_size - self.heads_num = num_attention_heads - self.mlp_hidden_dim = mlp_hidden_dim + mlp_dim = int(hidden_size * mlp_ratio) self.attn = Attention( query_dim=hidden_size, @@ -370,9 +372,9 @@ def __init__( ) self.norm = AdaLayerNormZeroSingle(hidden_size, norm_type="layer_norm") - self.proj_mlp = nn.Linear(hidden_size, self.mlp_hidden_dim) + self.proj_mlp = nn.Linear(hidden_size, mlp_dim) self.act_mlp = nn.GELU(approximate="tanh") - self.proj_out = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size) + self.proj_out = nn.Linear(hidden_size + mlp_dim, hidden_size) def forward( self, @@ -387,6 +389,7 @@ def forward( residual = hidden_states + # 1. Input normalization norm_hidden_states, gate = self.norm(hidden_states, emb=temb) mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) @@ -395,6 +398,7 @@ def forward( norm_hidden_states[:, -text_seq_length:, :], ) + # 2. Attention attn_output, context_attn_output = self.attn( hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, @@ -403,6 +407,7 @@ def forward( ) attn_output = torch.cat([attn_output, context_attn_output], dim=1) + # 3. Modulation and residual connection hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2) hidden_states = gate.unsqueeze(1) * self.proj_out(hidden_states) hidden_states = hidden_states + residual @@ -457,30 +462,35 @@ def forward( attention_mask: Optional[torch.Tensor] = None, freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: + # 1. Input normalization norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( encoder_hidden_states, emb=temb ) - img_attn, txt_attn = self.attn( + # 2. Joint attention + attn_output, context_attn_output = self.attn( hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, attention_mask=attention_mask, image_rotary_emb=freqs_cis, ) - hidden_states = hidden_states + img_attn * gate_msa.unsqueeze(1) - encoder_hidden_states = encoder_hidden_states + txt_attn * c_gate_msa.unsqueeze(1) + # 3. Modulation and residual connection + hidden_states = hidden_states + attn_output * gate_msa.unsqueeze(1) + encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa.unsqueeze(1) norm_hidden_states = self.norm2(hidden_states) - norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] - ff_output = self.ff(norm_hidden_states) - ff_output = gate_mlp.unsqueeze(1) * ff_output - hidden_states = hidden_states + ff_output - norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + + # 4. Feed-forward + ff_output = self.ff(norm_hidden_states) context_ff_output = self.ff_context(norm_encoder_hidden_states) + + hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output return hidden_states, encoder_hidden_states @@ -512,19 +522,17 @@ def __init__( inner_dim = num_attention_heads * attention_head_dim out_channels = out_channels or in_channels - # image projection - self.img_in = PatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim) - - # text projection - self.txt_in = SingleTokenRefiner( + # 1. Latent and condition embedders + self.x_embedder = PatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim) + self.context_embedder = HunyuanVideoTokenRefiner( text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers ) - self.time_text_embed = CombinedTimestepGuidanceTextProjEmbeddings(inner_dim, pooled_projection_dim) - # 3. RoPE + # 2. RoPE self.rope = HunyuanVideoRotaryPosEmbed(patch_size, patch_size_t, rope_axes_dim, rope_theta) + # 3. Dual stream transformer blocks self.transformer_blocks = nn.ModuleList( [ HunyuanVideoTransformerBlock( @@ -534,6 +542,7 @@ def __init__( ] ) + # 4. Single stream transformer blocks self.single_transformer_blocks = nn.ModuleList( [ HunyuanVideoSingleTransformerBlock( @@ -543,6 +552,7 @@ def __init__( ] ) + # 5. Output projection self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6) self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels) @@ -628,14 +638,15 @@ def forward( post_patch_height = height // p post_patch_width = width // p + # 1. RoPE image_rotary_emb = self.rope(hidden_states) + # 2. Conditional embeddings temb = self.time_text_embed(timestep, guidance, pooled_projections) + hidden_states = self.x_embedder(hidden_states) + encoder_hidden_states = self.context_embedder(encoder_hidden_states, timestep, encoder_attention_mask) - # Embed image and text. - hidden_states = self.img_in(hidden_states) - encoder_hidden_states = self.txt_in(encoder_hidden_states, timestep, encoder_attention_mask) - + # 3. Attention mask preparation latent_sequence_length = hidden_states.shape[1] condition_sequence_length = encoder_hidden_states.shape[1] sequence_length = latent_sequence_length + condition_sequence_length @@ -649,6 +660,7 @@ def forward( for i in range(batch_size): attention_mask[i, : effective_sequence_length[i], : effective_sequence_length[i]] = True + # 4. Transformer blocks if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): @@ -695,6 +707,7 @@ def custom_forward(*inputs): hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb ) + # 5. Output projection hidden_states = self.norm_out(hidden_states, temb) hidden_states = self.proj_out(hidden_states) diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py index a9eab0ddbb8a..a6611d32e643 100644 --- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py @@ -30,7 +30,31 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -EXAMPLE_DOC_STRING = """""" +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel + >>> from diffusers.utils import export_to_video + + >>> model_id = "tencent/HunyuanVideo" + >>> transformer = HunyuanVideoTransformer3DModel.from_pretrained( + ... model_id, subfolder="transformer", torch_dtype=torch.bfloat16 + ... ) + >>> pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.float16) + >>> pipe.vae.enable_tiling() + >>> pipe.to("cuda") + + >>> output = pipe( + ... prompt="A cat walks on the grass, realistic", + ... height=320, + ... width=512, + ... num_frames=61, + ... num_inference_steps=30, + ... ).frames[0] + >>> export_to_video(output, "output.mp4", fps=15) + ``` +""" DEFAULT_PROMPT_TEMPLATE = { @@ -621,8 +645,6 @@ def __call__( # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] - torch.save(latents, f"diffusers_refactor_latents_{i}.pt") - if callback_on_step_end is not None: callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: @@ -636,14 +658,10 @@ def __call__( if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() - latents = latents.to(self.vae.dtype) if not output_type == "latent": - latents = latents / self.vae.config.scaling_factor + latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor video = self.vae.decode(latents, return_dict=False)[0] - - torch.save(video, "diffusers_latents_decoded.pt") video = self.video_processor.postprocess_video(video, output_type=output_type) - else: video = latents From 154b31c984248a2ab776c0fa27d8cf3c7398122a Mon Sep 17 00:00:00 2001 From: Aryan Date: Sat, 14 Dec 2024 21:47:37 +0100 Subject: [PATCH 49/58] update docs --- docs/source/en/_toctree.yml | 2 + .../models/autoencoder_kl_hunyuan_video.md | 6 ++- .../models/hunyuan_video_transformer_3d.md | 4 +- docs/source/en/api/pipelines/hunyuan_video.md | 42 +++++++++++++++++++ 4 files changed, 51 insertions(+), 3 deletions(-) create mode 100644 docs/source/en/api/pipelines/hunyuan_video.md diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index d630bf891a25..074e213a79fa 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -396,6 +396,8 @@ title: Flux - local: api/pipelines/hunyuandit title: Hunyuan-DiT + - local: api/pipelines/hunyuan_video + title: HunyuanVideo - local: api/pipelines/i2vgenxl title: I2VGen-XL - local: api/pipelines/pix2pix diff --git a/docs/source/en/api/models/autoencoder_kl_hunyuan_video.md b/docs/source/en/api/models/autoencoder_kl_hunyuan_video.md index 89679422d664..f69c14814d3d 100644 --- a/docs/source/en/api/models/autoencoder_kl_hunyuan_video.md +++ b/docs/source/en/api/models/autoencoder_kl_hunyuan_video.md @@ -16,10 +16,12 @@ The 3D variational autoencoder (VAE) model with KL loss used in [HunyuanVideo](h The model can be loaded with the following code snippet. ```python -TODO +from diffusers import AutoencoderKLHunyuanVideo + +vae = AutoencoderKLHunyuanVideo.from_pretrained("tencent/HunyuanVideo", torch_dtype=torch.float16) ``` -## AutoencoderKLMochi +## AutoencoderKLHunyuanVideo [[autodoc]] AutoencoderKLHunyuanVideo - decode diff --git a/docs/source/en/api/models/hunyuan_video_transformer_3d.md b/docs/source/en/api/models/hunyuan_video_transformer_3d.md index 0c1540d8cbe5..73aea9832fc0 100644 --- a/docs/source/en/api/models/hunyuan_video_transformer_3d.md +++ b/docs/source/en/api/models/hunyuan_video_transformer_3d.md @@ -16,7 +16,9 @@ A Diffusion Transformer model for 3D video-like data was introduced in [HunyuanV The model can be loaded with the following code snippet. ```python -TODO +from diffusers import HunyuanVideoTransformer3DModel + +transformer = HunyuanVideoTransformer3DModel.from_pretrained("tencent/HunyuanVideo", torch_dtype=torch.bfloat16) ``` ## HunyuanVideoTransformer3DModel diff --git a/docs/source/en/api/pipelines/hunyuan_video.md b/docs/source/en/api/pipelines/hunyuan_video.md new file mode 100644 index 000000000000..8cb877e07a25 --- /dev/null +++ b/docs/source/en/api/pipelines/hunyuan_video.md @@ -0,0 +1,42 @@ + + +# HunyuanVideo + +[HunyuanVideo](https://www.arxiv.org/abs/2412.03603) by Tencent. + +*Recent advancements in video generation have significantly impacted daily life for both individuals and industries. However, the leading video generation models remain closed-source, resulting in a notable performance gap between industry capabilities and those available to the public. In this report, we introduce HunyuanVideo, an innovative open-source video foundation model that demonstrates performance in video generation comparable to, or even surpassing, that of leading closed-source models. HunyuanVideo encompasses a comprehensive framework that integrates several key elements, including data curation, advanced architectural design, progressive model scaling and training, and an efficient infrastructure tailored for large-scale model training and inference. As a result, we successfully trained a video generative model with over 13 billion parameters, making it the largest among all open-source models. We conducted extensive experiments and implemented a series of targeted designs to ensure high visual quality, motion dynamics, text-video alignment, and advanced filming techniques. According to evaluations by professionals, HunyuanVideo outperforms previous state-of-the-art models, including Runway Gen-3, Luma 1.6, and three top-performing Chinese video generative models. By releasing the code for the foundation model and its applications, we aim to bridge the gap between closed-source and open-source communities. This initiative will empower individuals within the community to experiment with their ideas, fostering a more dynamic and vibrant video generation ecosystem. The code is publicly available at [this https URL](https://github.com/Tencent/HunyuanVideo).* + + + +Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines. + + + +Recommendations for inference: +- Both text encoders should be in `torch.float16`. +- Transformer should be in `torch.bfloat16`. +- VAE should be in `torch.float16`. +- `num_frames` should be of the form `4 * k + 1`, for example `49` or `129`. +- For smaller resolution images, try lower values of `shift` (between `2.0` to `5.0`) in the [Scheduler](https://huggingface.co/docs/diffusers/main/en/api/schedulers/flow_match_euler_discrete#diffusers.FlowMatchEulerDiscreteScheduler.shift). For larger resolution images, try higher values (between `7.0` and `12.0`). The default value is `7.0` for HunyuanVideo. + +## HunyuanVideoPipeline + +[[autodoc]] HunyuanVideoPipeline + - all + - __call__ + +## HunyuanVideoPipelineOutput + +[[autodoc]] pipelines.hunyuan_video.pipeline_output.HunyuanVideoPipelineOutput From ae0b3590075fe0170d1a563896c5985677c6de8d Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 15 Dec 2024 05:46:33 +0100 Subject: [PATCH 50/58] update docs --- docs/source/en/api/pipelines/hunyuan_video.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/en/api/pipelines/hunyuan_video.md b/docs/source/en/api/pipelines/hunyuan_video.md index 8cb877e07a25..86ef816fcd4d 100644 --- a/docs/source/en/api/pipelines/hunyuan_video.md +++ b/docs/source/en/api/pipelines/hunyuan_video.md @@ -30,6 +30,7 @@ Recommendations for inference: - VAE should be in `torch.float16`. - `num_frames` should be of the form `4 * k + 1`, for example `49` or `129`. - For smaller resolution images, try lower values of `shift` (between `2.0` to `5.0`) in the [Scheduler](https://huggingface.co/docs/diffusers/main/en/api/schedulers/flow_match_euler_discrete#diffusers.FlowMatchEulerDiscreteScheduler.shift). For larger resolution images, try higher values (between `7.0` and `12.0`). The default value is `7.0` for HunyuanVideo. +- For more information about supported resolutions and other details, please refer to the original repository [here](https://github.com/Tencent/HunyuanVideo/). ## HunyuanVideoPipeline From eee00ab09eef7b4cba4507e03cd4976448a992d1 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 15 Dec 2024 05:46:53 +0100 Subject: [PATCH 51/58] apply suggestions from review --- src/diffusers/models/activations.py | 24 +++++------ src/diffusers/models/attention.py | 6 +-- .../autoencoder_kl_hunyuan_video.py | 40 +++++++++---------- .../transformers/transformer_hunyuan_video.py | 8 ++-- .../hunyuan_video/pipeline_hunyuan_video.py | 4 ++ 5 files changed, 43 insertions(+), 39 deletions(-) diff --git a/src/diffusers/models/activations.py b/src/diffusers/models/activations.py index e9a31f89e870..c61baefa08f4 100644 --- a/src/diffusers/models/activations.py +++ b/src/diffusers/models/activations.py @@ -146,18 +146,6 @@ def forward(self, hidden_states): return hidden_states * self.activation(gate) -class SiLU(nn.Module): - def __init__(self, dim_in: int, dim_out: int, bias: bool = True): - super().__init__() - - self.proj = nn.Linear(dim_in, dim_out, bias=bias) - self.activation = nn.SiLU() - - def forward(self, hidden_states): - hidden_states = self.proj(hidden_states) - return self.activation(hidden_states) - - class ApproximateGELU(nn.Module): r""" The approximate form of the Gaussian Error Linear Unit (GELU). For more details, see section 2 of this @@ -176,3 +164,15 @@ def __init__(self, dim_in: int, dim_out: int, bias: bool = True): def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.proj(x) return x * torch.sigmoid(1.702 * x) + + +class LinearActivation(nn.Module): + def __init__(self, dim_in: int, dim_out: int, bias: bool = True, activation: str = "silu"): + super().__init__() + + self.proj = nn.Linear(dim_in, dim_out, bias=bias) + self.activation = get_activation(activation) + + def forward(self, hidden_states): + hidden_states = self.proj(hidden_states) + return self.activation(hidden_states) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 2666ffe94528..6749c7f17254 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -19,7 +19,7 @@ from ..utils import deprecate, logging from ..utils.torch_utils import maybe_allow_in_graph -from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, SiLU, SwiGLU +from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, LinearActivation, SwiGLU from .attention_processor import Attention, JointAttnProcessor2_0 from .embeddings import SinusoidalPositionalEmbedding from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX @@ -1222,8 +1222,8 @@ def __init__( act_fn = ApproximateGELU(dim, inner_dim, bias=bias) elif activation_fn == "swiglu": act_fn = SwiGLU(dim, inner_dim, bias=bias) - elif activation_fn == "silu": - act_fn = SiLU(dim, inner_dim, bias=bias) + elif activation_fn == "linear-silu": + act_fn = LinearActivation(dim, inner_dim, bias=bias, activation="silu") self.net = nn.ModuleList([]) # project in diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py index 496677671c5d..3e9b3a5be48e 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py @@ -35,7 +35,7 @@ def prepare_causal_attention_mask( num_frames: int, height_width: int, dtype: torch.dtype, device: torch.device, batch_size: int = None -): +) -> torch.Tensor: seq_len = num_frames * height_width mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device) for i in range(seq_len): @@ -46,7 +46,7 @@ def prepare_causal_attention_mask( return mask -class CausalConv3d(nn.Module): +class HunyuanVideoCausalConv3d(nn.Module): def __init__( self, in_channels: int, @@ -79,7 +79,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return self.conv(hidden_states) -class UpsampleCausal3D(nn.Module): +class HunyuanVideoUpsampleCausal3D(nn.Module): def __init__( self, in_channels: int, @@ -94,7 +94,7 @@ def __init__( out_channels = out_channels or in_channels self.upsample_factor = upsample_factor - self.conv = CausalConv3d(in_channels, out_channels, kernel_size, stride, bias=bias) + self.conv = HunyuanVideoCausalConv3d(in_channels, out_channels, kernel_size, stride, bias=bias) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: num_frames = hidden_states.size(2) @@ -114,7 +114,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return hidden_states -class DownsampleCausal3D(nn.Module): +class HunyuanVideoDownsampleCausal3D(nn.Module): def __init__( self, channels: int, @@ -127,14 +127,14 @@ def __init__( super().__init__() out_channels = out_channels or channels - self.conv = CausalConv3d(channels, out_channels, kernel_size, stride, padding, bias=bias) + self.conv = HunyuanVideoCausalConv3d(channels, out_channels, kernel_size, stride, padding, bias=bias) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.conv(hidden_states) return hidden_states -class ResnetBlockCausal3D(nn.Module): +class HunyuanVideoResnetBlockCausal3D(nn.Module): def __init__( self, in_channels: int, @@ -150,15 +150,15 @@ def __init__( self.nonlinearity = get_activation(non_linearity) self.norm1 = nn.GroupNorm(groups, in_channels, eps=eps, affine=True) - self.conv1 = CausalConv3d(in_channels, out_channels, 3, 1, 0) + self.conv1 = HunyuanVideoCausalConv3d(in_channels, out_channels, 3, 1, 0) self.norm2 = nn.GroupNorm(groups, out_channels, eps=eps, affine=True) self.dropout = torch.nn.Dropout(dropout) - self.conv2 = CausalConv3d(out_channels, out_channels, 3, 1, 0) + self.conv2 = HunyuanVideoCausalConv3d(out_channels, out_channels, 3, 1, 0) self.conv_shortcut = None if in_channels != out_channels: - self.conv_shortcut = CausalConv3d(in_channels, out_channels, 1, 1, 0) + self.conv_shortcut = HunyuanVideoCausalConv3d(in_channels, out_channels, 1, 1, 0) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: residual = hidden_states @@ -197,7 +197,7 @@ def __init__( # There is always at least one resnet resnets = [ - ResnetBlockCausal3D( + HunyuanVideoResnetBlockCausal3D( in_channels=in_channels, out_channels=in_channels, eps=resnet_eps, @@ -227,7 +227,7 @@ def __init__( attentions.append(None) resnets.append( - ResnetBlockCausal3D( + HunyuanVideoResnetBlockCausal3D( in_channels=in_channels, out_channels=in_channels, eps=resnet_eps, @@ -312,7 +312,7 @@ def __init__( for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels resnets.append( - ResnetBlockCausal3D( + HunyuanVideoResnetBlockCausal3D( in_channels=in_channels, out_channels=out_channels, eps=resnet_eps, @@ -327,7 +327,7 @@ def __init__( if add_downsample: self.downsamplers = nn.ModuleList( [ - DownsampleCausal3D( + HunyuanVideoDownsampleCausal3D( out_channels, out_channels=out_channels, padding=downsample_padding, @@ -389,7 +389,7 @@ def __init__( input_channels = in_channels if i == 0 else out_channels resnets.append( - ResnetBlockCausal3D( + HunyuanVideoResnetBlockCausal3D( in_channels=input_channels, out_channels=out_channels, eps=resnet_eps, @@ -404,7 +404,7 @@ def __init__( if add_upsample: self.upsamplers = nn.ModuleList( [ - UpsampleCausal3D( + HunyuanVideoUpsampleCausal3D( out_channels, out_channels=out_channels, upsample_factor=upsample_scale_factor, @@ -472,7 +472,7 @@ def __init__( ) -> None: super().__init__() - self.conv_in = CausalConv3d(in_channels, block_out_channels[0], kernel_size=3, stride=1) + self.conv_in = HunyuanVideoCausalConv3d(in_channels, block_out_channels[0], kernel_size=3, stride=1) self.mid_block = None self.down_blocks = nn.ModuleList([]) @@ -529,7 +529,7 @@ def __init__( self.conv_act = nn.SiLU() conv_out_channels = 2 * out_channels if double_z else out_channels - self.conv_out = CausalConv3d(block_out_channels[-1], conv_out_channels, kernel_size=3) + self.conv_out = HunyuanVideoCausalConv3d(block_out_channels[-1], conv_out_channels, kernel_size=3) self.gradient_checkpointing = False @@ -596,7 +596,7 @@ def __init__( super().__init__() self.layers_per_block = layers_per_block - self.conv_in = CausalConv3d(in_channels, block_out_channels[-1], kernel_size=3, stride=1) + self.conv_in = HunyuanVideoCausalConv3d(in_channels, block_out_channels[-1], kernel_size=3, stride=1) self.mid_block = None self.up_blocks = nn.ModuleList([]) @@ -652,7 +652,7 @@ def __init__( # out self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6) self.conv_act = nn.SiLU() - self.conv_out = CausalConv3d(block_out_channels[0], out_channels, kernel_size=3) + self.conv_out = HunyuanVideoCausalConv3d(block_out_channels[0], out_channels, kernel_size=3) self.gradient_checkpointing = False diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py index 51b51f739075..9bb6ca3525b3 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py @@ -121,11 +121,11 @@ def __call__( hidden_states[:, -encoder_hidden_states.shape[1] :], ) - if not attn.pre_only: + if getattr(attn, "to_out", None) is not None: hidden_states = attn.to_out[0](hidden_states) hidden_states = attn.to_out[1](hidden_states) - if attn.context_pre_only is not None and not attn.context_pre_only: + if getattr(attn, "to_add_out", None) is not None: encoder_hidden_states = attn.to_add_out(encoder_hidden_states) return hidden_states, encoder_hidden_states @@ -189,7 +189,7 @@ def __init__( ) self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6) - self.mlp = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="silu", dropout=mlp_drop_rate) + self.ff = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="linear-silu", dropout=mlp_drop_rate) self.norm_out = HunyuanVideoAdaNorm(hidden_size, 2 * hidden_size) @@ -210,7 +210,7 @@ def forward( gate_msa, gate_mlp = self.norm_out(temb) hidden_states = hidden_states + attn_output * gate_msa - ff_output = self.mlp(self.norm2(hidden_states)) + ff_output = self.ff(self.norm2(hidden_states)) hidden_states = hidden_states + ff_output * gate_mlp return hidden_states diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py index a6611d32e643..577f622a8576 100644 --- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py @@ -15,6 +15,7 @@ import inspect from typing import Any, Callable, Dict, List, Optional, Tuple, Union +import numpy as np import torch from transformers import CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizerFast @@ -595,12 +596,14 @@ def __call__( pooled_prompt_embeds = pooled_prompt_embeds.to(transformer_dtype) # 4. Prepare timesteps + sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, device, sigmas=sigmas, ) + print(self.scheduler.sigmas) # 5. Prepare latent variables num_channels_latents = self.transformer.config.in_channels @@ -617,6 +620,7 @@ def __call__( latents, ) + # 6. Prepare guidance condition guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0 # 7. Denoising loop From 04614753d0da1edf7f811ae6b4dae4ed4fa9919e Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 15 Dec 2024 07:17:57 +0100 Subject: [PATCH 52/58] refactor vae --- scripts/convert_hunyuan_video_to_diffusers.py | 1 + .../autoencoder_kl_hunyuan_video.py | 98 +++++++++++-------- .../hunyuan_video/pipeline_hunyuan_video.py | 1 - 3 files changed, 58 insertions(+), 42 deletions(-) diff --git a/scripts/convert_hunyuan_video_to_diffusers.py b/scripts/convert_hunyuan_video_to_diffusers.py index 13ced51240d3..464c9e0fb954 100644 --- a/scripts/convert_hunyuan_video_to_diffusers.py +++ b/scripts/convert_hunyuan_video_to_diffusers.py @@ -28,6 +28,7 @@ def rename_key(key): new_key = new_key.replace("t_embedder.mlp.0", "time_text_embed.timestep_embedder.linear_1") new_key = new_key.replace("t_embedder.mlp.2", "time_text_embed.timestep_embedder.linear_2") new_key = new_key.replace("c_embedder", "time_text_embed.text_embedder") + new_key = new_key.replace("mlp", "ff") return new_key if "self_attn_qkv" in key: diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py index 3e9b3a5be48e..e23c15e00422 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py @@ -105,6 +105,13 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: ).unsqueeze(2) if num_frames > 1: + # See: https://github.com/pytorch/pytorch/issues/81665 + # Unless you have a version of pytorch where non-contiguous implementation of F.interpolate + # is fixed, this will raise either a runtime error, or fail silently with bad outputs. + # If you are encountering an error here, make sure to try running encoding/decoding with + # `vae.enable_tiling()` first. If that doesn't work, open an issue at: + # https://github.com/huggingface/diffusers/issues + other_frames = other_frames.contiguous() other_frames = F.interpolate(other_frames, scale_factor=self.upsample_factor, mode="nearest") hidden_states = torch.cat((first_frame, other_frames), dim=2) else: @@ -786,14 +793,12 @@ def __init__( # The minimal tile height and width for spatial tiling to be used self.tile_sample_min_height = 256 self.tile_sample_min_width = 256 - - # The minimal tile temporal batch size for temporal tiling to be used - self.tile_sample_min_tsize = 64 - self.tile_overlap_factor = 0.25 + self.tile_sample_min_num_frames = 64 # The minimal distance between two spatial tiles self.tile_sample_stride_height = 192 self.tile_sample_stride_width = 192 + self.tile_sample_stride_num_frames = 48 def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, (HunyuanVideoEncoder3D, HunyuanVideoDecoder3D)): @@ -803,8 +808,10 @@ def enable_tiling( self, tile_sample_min_height: Optional[int] = None, tile_sample_min_width: Optional[int] = None, + tile_sample_min_num_frames: Optional[int] = None, tile_sample_stride_height: Optional[float] = None, tile_sample_stride_width: Optional[float] = None, + tile_sample_stride_num_frames: Optional[float] = None, ) -> None: r""" Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to @@ -816,18 +823,26 @@ def enable_tiling( The minimum height required for a sample to be separated into tiles across the height dimension. tile_sample_min_width (`int`, *optional*): The minimum width required for a sample to be separated into tiles across the width dimension. + tile_sample_min_num_frames (`int`, *optional*): + The minimum number of frames required for a sample to be separated into tiles across the frame + dimension. tile_sample_stride_height (`int`, *optional*): The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are no tiling artifacts produced across the height dimension. tile_sample_stride_width (`int`, *optional*): The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling artifacts produced across the width dimension. + tile_sample_stride_num_frames (`int`, *optional*): + The stride between two consecutive frame tiles. This is to ensure that there are no tiling artifacts + produced across the frame dimension. """ self.use_tiling = True self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width + self.tile_sample_min_num_frames = tile_sample_min_num_frames or self.tile_sample_min_num_frames self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width + self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames def disable_tiling(self) -> None: r""" @@ -853,8 +868,8 @@ def disable_slicing(self) -> None: def _encode(self, x: torch.Tensor) -> torch.Tensor: batch_size, num_channels, num_frames, height, width = x.shape - if self.use_framewise_decoding and num_frames > self.tile_sample_min_tsize: - return self.temporal_tiled_encode(x) + if self.use_framewise_decoding and num_frames > self.tile_sample_min_num_frames: + return self._temporal_tiled_encode(x) if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height): return self.tiled_encode(x) @@ -895,10 +910,10 @@ def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOut batch_size, num_channels, num_frames, height, width = z.shape tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio tile_latent_min_width = self.tile_sample_stride_width // self.spatial_compression_ratio - tile_latent_min_num_frames = self.tile_sample_min_tsize // self.temporal_compression_ratio + tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio if self.use_framewise_decoding and num_frames > tile_latent_min_num_frames: - return self.temporal_tiled_decode(z, return_dict=return_dict) + return self._temporal_tiled_decode(z, return_dict=return_dict) if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height): return self.tiled_decode(z, return_dict=return_dict) @@ -1069,52 +1084,51 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod return (dec,) return DecoderOutput(sample=dec) - def temporal_tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput: - B, C, T, H, W = x.shape - tile_latent_min_tsize = self.tile_sample_min_tsize // self.temporal_compression_ratio - overlap_size = int(self.tile_sample_min_tsize * (1 - self.tile_overlap_factor)) - blend_extent = int(tile_latent_min_tsize * self.tile_overlap_factor) - t_limit = tile_latent_min_tsize - blend_extent + def _temporal_tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput: + batch_size, num_channels, num_frames, height, width = x.shape + latent_num_frames = (num_frames - 1) // self.temporal_compression_ratio + 1 + + tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio + tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio + blend_num_frames = tile_latent_min_num_frames - tile_latent_stride_num_frames - # Split the video into tiles and encode them separately. row = [] - for i in range(0, T, overlap_size): - tile = x[:, :, i : i + self.tile_sample_min_tsize + 1, :, :] - if self.use_tiling and ( - tile.shape[-1] > self.tile_sample_min_size or tile.shape[-2] > self.tile_sample_min_size - ): - tile = self.tiled_encode(tile, return_moments=True) + for i in range(0, num_frames, self.tile_sample_stride_num_frames): + tile = x[:, :, i : i + self.tile_sample_min_num_frames + 1, :, :] + if self.use_tiling and (height > self.tile_sample_min_height or width > self.tile_sample_min_width): + tile = self.tiled_encode(tile) else: tile = self.encoder(tile) tile = self.quant_conv(tile) if i > 0: tile = tile[:, :, 1:, :, :] row.append(tile) + result_row = [] for i, tile in enumerate(row): if i > 0: - tile = self.blend_t(row[i - 1], tile, blend_extent) - result_row.append(tile[:, :, :t_limit, :, :]) + tile = self.blend_t(row[i - 1], tile, blend_num_frames) + result_row.append(tile[:, :, :tile_latent_stride_num_frames, :, :]) else: - result_row.append(tile[:, :, : t_limit + 1, :, :]) + result_row.append(tile[:, :, : tile_latent_stride_num_frames + 1, :, :]) - enc = torch.cat(result_row, dim=2) + enc = torch.cat(result_row, dim=2)[:, :, :latent_num_frames] return enc - def temporal_tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: - # Split z into overlapping tiles and decode them separately. + def _temporal_tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + batch_size, num_channels, num_frames, height, width = z.shape + num_sample_frames = (num_frames - 1) * self.temporal_compression_ratio + 1 - B, C, T, H, W = z.shape - tile_latent_min_tsize = self.tile_sample_min_tsize // self.temporal_compression_ratio - overlap_size = int(tile_latent_min_tsize * (1 - self.tile_overlap_factor)) - blend_extent = int(self.tile_sample_min_tsize * self.tile_overlap_factor) - t_limit = self.tile_sample_min_tsize - blend_extent - tile_latent_min_size = self.tile_sample_min_size // self.spatial_compression_ratio + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio + tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio + blend_num_frames = self.tile_sample_min_num_frames - self.tile_sample_stride_num_frames row = [] - for i in range(0, T, overlap_size): - tile = z[:, :, i : i + tile_latent_min_tsize + 1, :, :] - if self.use_tiling and (tile.shape[-1] > tile_latent_min_size or tile.shape[-2] > tile_latent_min_size): + for i in range(0, num_frames, tile_latent_stride_num_frames): + tile = z[:, :, i : i + tile_latent_min_num_frames + 1, :, :] + if self.use_tiling and (tile.shape[-1] > tile_latent_min_width or tile.shape[-2] > tile_latent_min_height): decoded = self.tiled_decode(tile, return_dict=True).sample else: tile = self.post_quant_conv(tile) @@ -1122,18 +1136,20 @@ def temporal_tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Un if i > 0: decoded = decoded[:, :, 1:, :, :] row.append(decoded) + result_row = [] for i, tile in enumerate(row): if i > 0: - tile = self.blend_t(row[i - 1], tile, blend_extent) - result_row.append(tile[:, :, :t_limit, :, :]) + tile = self.blend_t(row[i - 1], tile, blend_num_frames) + result_row.append(tile[:, :, : self.tile_sample_stride_num_frames, :, :]) else: - result_row.append(tile[:, :, : t_limit + 1, :, :]) + result_row.append(tile[:, :, : self.tile_sample_stride_num_frames + 1, :, :]) + + print("this:", torch.cat(result_row, dim=2).shape) + dec = torch.cat(result_row, dim=2)[:, :, :num_sample_frames] - dec = torch.cat(result_row, dim=2) if not return_dict: return (dec,) - return DecoderOutput(sample=dec) def forward( diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py index 577f622a8576..79bd6e58fb03 100644 --- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py @@ -603,7 +603,6 @@ def __call__( device, sigmas=sigmas, ) - print(self.scheduler.sigmas) # 5. Prepare latent variables num_channels_latents = self.transformer.config.in_channels From edfc64bc24afb88c8bf62d0b006908379007a0b1 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 15 Dec 2024 07:42:41 +0100 Subject: [PATCH 53/58] update --- src/diffusers/models/attention_processor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index d205bef384c6..e5903d09ce54 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -3950,7 +3950,7 @@ def apply_rotary_emb(x, freqs_cos, freqs_sin): # dropout hidden_states = attn.to_out[1](hidden_states) - if attn.context_pre_only is not None and not attn.context_pre_only: + if getattr(attn, "to_add_out", None) is not None: encoder_hidden_states = attn.to_add_out(encoder_hidden_states) return hidden_states, encoder_hidden_states From bfe9c4628bbba2822d01d346206c4cf60ae4e173 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 15 Dec 2024 20:03:46 +0530 Subject: [PATCH 54/58] Apply suggestions from code review Co-authored-by: hlky --- .../models/autoencoders/autoencoder_kl_hunyuan_video.py | 4 +--- .../pipelines/hunyuan_video/pipeline_hunyuan_video.py | 5 +---- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py index e23c15e00422..bded90a8bcff 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py @@ -160,7 +160,7 @@ def __init__( self.conv1 = HunyuanVideoCausalConv3d(in_channels, out_channels, 3, 1, 0) self.norm2 = nn.GroupNorm(groups, out_channels, eps=eps, affine=True) - self.dropout = torch.nn.Dropout(dropout) + self.dropout = nn.Dropout(dropout) self.conv2 = HunyuanVideoCausalConv3d(out_channels, out_channels, 3, 1, 0) self.conv_shortcut = None @@ -604,7 +604,6 @@ def __init__( self.layers_per_block = layers_per_block self.conv_in = HunyuanVideoCausalConv3d(in_channels, block_out_channels[-1], kernel_size=3, stride=1) - self.mid_block = None self.up_blocks = nn.ModuleList([]) # mid @@ -1145,7 +1144,6 @@ def _temporal_tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> U else: result_row.append(tile[:, :, : self.tile_sample_stride_num_frames + 1, :, :]) - print("this:", torch.cat(result_row, dim=2).shape) dec = torch.cat(result_row, dim=2)[:, :, :num_sample_frames] if not return_dict: diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py index 79bd6e58fb03..a37ec34efc53 100644 --- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py @@ -21,7 +21,7 @@ from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel -from ...schedulers import KarrasDiffusionSchedulers +from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ...video_processor import VideoProcessor @@ -149,9 +149,6 @@ class HunyuanVideoPipeline(DiffusionPipeline): A scheduler to be used in combination with `transformer` to denoise the encoded image latents. vae ([`AutoencoderKLHunyuanVideo`]): Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. - text_encoder_2 ([`T5EncoderModel`]): - [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically - the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. text_encoder_2 ([`CLIPTextModel`]): [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. From f906aa86ebaec633cf413dae039f7b8c8c5e716f Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 15 Dec 2024 20:04:38 +0530 Subject: [PATCH 55/58] Update src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py Co-authored-by: hlky --- src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py index a37ec34efc53..ca111b6921b3 100644 --- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py @@ -166,7 +166,7 @@ def __init__( tokenizer: LlamaTokenizerFast, transformer: HunyuanVideoTransformer3DModel, vae: AutoencoderKLHunyuanVideo, - scheduler: KarrasDiffusionSchedulers, + scheduler: FlowMatchEulerDiscreteScheduler, text_encoder_2: CLIPTextModel, tokenizer_2: CLIPTokenizer, ): From 9795469e1fbac7f57a2ec673af28813ee8f11f97 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 15 Dec 2024 20:04:47 +0530 Subject: [PATCH 56/58] Update src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py Co-authored-by: hlky --- src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py index ca111b6921b3..331b2e5cf471 100644 --- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py @@ -72,6 +72,7 @@ } +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps def retrieve_timesteps( scheduler, num_inference_steps: Optional[int] = None, From 6867c529d08d3ea573048e60d7facfead597ffb3 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 15 Dec 2024 16:11:40 +0100 Subject: [PATCH 57/58] make fix-copies --- src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py index 331b2e5cf471..bd3d3c1e8485 100644 --- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py @@ -81,7 +81,7 @@ def retrieve_timesteps( sigmas: Optional[List[float]] = None, **kwargs, ): - """ + r""" Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. From ce7b0b9983d4039e42c2060da04248e45e54ff33 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 15 Dec 2024 16:12:10 +0100 Subject: [PATCH 58/58] update --- .../models/transformers/transformer_hunyuan_video.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py index 9bb6ca3525b3..d8f9834ea61c 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py @@ -131,7 +131,7 @@ def __call__( return hidden_states, encoder_hidden_states -class PatchEmbed(nn.Module): +class HunyuanVideoPatchEmbed(nn.Module): def __init__( self, patch_size: Union[int, Tuple[int, int, int]] = 16, @@ -523,7 +523,7 @@ def __init__( out_channels = out_channels or in_channels # 1. Latent and condition embedders - self.x_embedder = PatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim) + self.x_embedder = HunyuanVideoPatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim) self.context_embedder = HunyuanVideoTokenRefiner( text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers )