From b28f89db3c02871cd8deb49e84d39ff6127b5953 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 26 Nov 2024 00:02:55 +0100 Subject: [PATCH 01/51] transformer --- src/diffusers/__init__.py | 2 + src/diffusers/models/__init__.py | 2 + src/diffusers/models/transformers/__init__.py | 1 + .../models/transformers/transformer_ltx.py | 288 ++++++++++++++++++ 4 files changed, 293 insertions(+) create mode 100644 src/diffusers/models/transformers/transformer_ltx.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index a4749af5f61b..1ea7e391a182 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -101,6 +101,7 @@ "HunyuanDiT2DMultiControlNetModel", "I2VGenXLUNet", "Kandinsky3UNet", + "LTXTransformerModel3D", "LatteTransformer3DModel", "LuminaNextDiT2DModel", "MochiTransformer3DModel", @@ -592,6 +593,7 @@ HunyuanDiT2DMultiControlNetModel, I2VGenXLUNet, Kandinsky3UNet, + LTXTransformer3DModel, LatteTransformer3DModel, LuminaNextDiT2DModel, MochiTransformer3DModel, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 65e2418ac794..be279d7882cf 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -63,6 +63,7 @@ _import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"] _import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"] _import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"] + _import_structure["transformers.transformer_ltx"] = ["LTXTransformerModel3D"] _import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"] _import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"] _import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"] @@ -122,6 +123,7 @@ DualTransformer2DModel, FluxTransformer2DModel, HunyuanDiT2DModel, + LTXTransformer3DModel, LatteTransformer3DModel, LuminaNextDiT2DModel, MochiTransformer3DModel, diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index a2c087d708a4..895c3aac4dcc 100644 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -17,6 +17,7 @@ from .transformer_allegro import AllegroTransformer3DModel from .transformer_cogview3plus import CogView3PlusTransformer2DModel from .transformer_flux import FluxTransformer2DModel + from .transformer_ltx import LTXTransformer3DModel from .transformer_mochi import MochiTransformer3DModel from .transformer_sd3 import SD3Transformer2DModel from .transformer_temporal import TransformerTemporalModel diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py new file mode 100644 index 000000000000..130ca70a61e9 --- /dev/null +++ b/src/diffusers/models/transformers/transformer_ltx.py @@ -0,0 +1,288 @@ +# Copyright 2024 The Genmo team and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, Optional, Tuple + +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import PeftAdapterMixin +from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from ...utils.torch_utils import maybe_allow_in_graph +from ..attention import FeedForward +from ..attention_processor import Attention, MochiAttnProcessor2_0 +from ..embeddings import PixArtAlphaTextProjection +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import AdaLayerNormSingle, RMSNorm + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@maybe_allow_in_graph +class LTXTransformerBlock(nn.Module): + r""" + Transformer block used in [LTX](https://huggingface.co/Lightricks/LTX-Video). + + Args: + TODO(aryan) + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + cross_attention_dim: int, + qk_norm: str = "rms_norm", + activation_fn: str = "gelu-approximate", + attention_bias: bool = True, + attention_out_bias: bool = True, + eps: float = 1e-6, + elementwise_affine: bool = False, + ): + super().__init__() + + self.norm1 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine) + + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + bias=attention_bias, + cross_attention_dim=None, + out_bias=attention_out_bias, + qk_norm=qk_norm, + ) + + self.norm2 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine) + + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + bias=attention_bias, + out_bias=attention_out_bias, + qk_norm=qk_norm, + ) + + self.ff = FeedForward(dim, activation_fn=activation_fn) + + self.scale_shift_table = nn.Parameter( + torch.randn(6, dim) / dim**0.5 + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + batch_size = hidden_states.size(0) + norm_hidden_states = self.norm1(hidden_states) + + num_ada_params = self.scale_shift_table.shape[0] + ada_values = self.scale_shift_table[None, None] + temb.reshape( + batch_size, temb.size(1), num_ada_params, -1 + ) + + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + ada_values.unbind(dim=2) + ) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + + attn_hidden_states = self.attn1( + hidden_states=norm_hidden_states, + encoder_hidden_states=None, + image_rotary_emb=image_rotary_emb, + ) + hidden_states = hidden_states + attn_hidden_states * gate_msa + + attn_hidden_states = self.attn2( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + attention_mask=encoder_attention_mask, + ) + hidden_states = hidden_states + attn_hidden_states + norm_hidden_states = self.norm2(hidden_states) * (1 + scale_mlp) + shift_mlp + + ff_output = self.ff(norm_hidden_states) + hidden_states = hidden_states + ff_output * gate_mlp + + return hidden_states + + +@maybe_allow_in_graph +class LTXTransformer3DModel(ModelMixin, ConfigMixin): + r""" + A 
Transformer model for video-like data used in [LTX](https://huggingface.co/Lightricks/LTX-Video). + + Args: + TODO(aryan) + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = None, + out_channels: Optional[int] = None, + patch_size: int = 1, + patch_size_t: int = 1, + num_attention_heads: int = 32, + attention_head_dim: int = 64, + cross_attention_dim: int = 2048, + num_layers: int = 28, + activation_fn: str = "gelu-approximate", + qk_norm: str = "rms_norm", + norm_elementwise_affine: bool = True, + norm_eps: float = 1e-6, + caption_channels: int = 4096, + attention_bias: bool = True, + attention_out_bias: bool = True, + ) -> None: + out_channels = out_channels or in_channels + inner_dim = num_attention_heads * attention_head_dim + + self.patchify_proj = nn.Linear(in_channels, inner_dim) + + self.transformer_blocks = nn.Modulelist([ + LTXTransformerBlock( + dim=inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + cross_attention_dim=cross_attention_dim, + qk_norm=qk_norm, + activation_fn=activation_fn, + attention_bias=attention_bias, + attention_out_bias=attention_out_bias, + eps=norm_eps, + elementwise_affine=norm_elementwise_affine, + ) + for _ in range(num_layers) + ]) + + self.norm_out = nn.LayerNorm(inner_dim, eps=1e-6, elementwise_affine=False) + self.proj_out = nn.Linear(inner_dim, self.out_channels) + + self.scale_shift_table = nn.Parameter( + torch.randn(2, inner_dim) / inner_dim**0.5 + ) + self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=False) + + self.caption_projection = PixArtAlphaTextProjection( + in_features=caption_channels, hidden_size=inner_dim + ) + + self.gradient_checkpointing = False + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + timestep: torch.LongTensor, + encoder_attention_mask: torch.Tensor, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]], + return_dict: bool = True, + ) -> torch.Tensor: + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: + encoder_attention_mask = ( + 1 - encoder_attention_mask.to(hidden_states.dtype) + ) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + batch_size, num_channels, num_frames, height, width = hidden_states.shape + p = self.config.patch_size + p_t = self.config.patch_size_t + + post_patch_height = height // p + post_patch_width = width // p + post_patch_num_frames = num_frames // p_t + + hidden_states = hidden_states.reshape(batch_size, -1, post_patch_num_frames, p_t, post_patch_height, p, post_patch_width, p) + hidden_states = hidden_states.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3) + hidden_states = self.patchify_proj(hidden_states) + + temb, embedded_timestep = self.adaln_single( + timestep.flatten(), + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=hidden_states.dtype, + ) + + temb = temb.view(batch_size, -1, temb.size(-1)) + embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.size(-1)) + + encoder_hidden_states = self.caption_projection(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1)) + + for block 
in self.transformer_blocks: + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + temb, + image_rotary_emb, + encoder_attention_mask, + **ckpt_kwargs, + ) + else: + hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + encoder_attention_mask=encoder_attention_mask, + ) + + scale_shift_values = ( + self.scale_shift_table[None, None] + embedded_timestep[:, :, None] + ) + shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] + + hidden_states = self.norm_out(hidden_states) + hidden_states = hidden_states * (1 + scale) + shift + hidden_states = self.proj_out(hidden_states) + + hidden_states = hidden_states.reshape(batch_size, post_patch_num_frames, post_patch_height, post_patch_width, -1, p_t, p, p) + output = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3) + + if not return_dict: + return (output,) + return Transformer2DModelOutput(sample=output) From f082cc85880ef2f7104147ddac41b1bb2879e323 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 26 Nov 2024 00:03:16 +0100 Subject: [PATCH 02/51] make style & make fix-copies --- src/diffusers/__init__.py | 4 +- src/diffusers/models/__init__.py | 2 +- .../models/transformers/transformer_ltx.py | 89 +++++++++---------- src/diffusers/utils/dummy_pt_objects.py | 15 ++++ 4 files changed, 58 insertions(+), 52 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 1ea7e391a182..b509ef7cfcd4 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -101,8 +101,8 @@ "HunyuanDiT2DMultiControlNetModel", "I2VGenXLUNet", "Kandinsky3UNet", - "LTXTransformerModel3D", "LatteTransformer3DModel", + "LTXTransformerModel3D", "LuminaNextDiT2DModel", "MochiTransformer3DModel", "ModelMixin", @@ -593,8 +593,8 @@ HunyuanDiT2DMultiControlNetModel, I2VGenXLUNet, Kandinsky3UNet, - LTXTransformer3DModel, LatteTransformer3DModel, + LTXTransformer3DModel, LuminaNextDiT2DModel, MochiTransformer3DModel, ModelMixin, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index be279d7882cf..22eb7d8485a1 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -123,8 +123,8 @@ DualTransformer2DModel, FluxTransformer2DModel, HunyuanDiT2DModel, - LTXTransformer3DModel, LatteTransformer3DModel, + LTXTransformer3DModel, LuminaNextDiT2DModel, MochiTransformer3DModel, PixArtTransformer2DModel, diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py index 130ca70a61e9..7b4cab107908 100644 --- a/src/diffusers/models/transformers/transformer_ltx.py +++ b/src/diffusers/models/transformers/transformer_ltx.py @@ -19,11 +19,10 @@ import torch.nn as nn from ...configuration_utils import ConfigMixin, register_to_config -from ...loaders import PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from ...utils import is_torch_version, logging from 
...utils.torch_utils import maybe_allow_in_graph from ..attention import FeedForward -from ..attention_processor import Attention, MochiAttnProcessor2_0 +from ..attention_processor import Attention from ..embeddings import PixArtAlphaTextProjection from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin @@ -44,7 +43,7 @@ class LTXTransformerBlock(nn.Module): def __init__( self, - dim: int, + dim: int, num_attention_heads: int, attention_head_dim: int, cross_attention_dim: int, @@ -58,7 +57,7 @@ def __init__( super().__init__() self.norm1 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine) - + self.attn1 = Attention( query_dim=dim, heads=num_attention_heads, @@ -83,10 +82,8 @@ def __init__( self.ff = FeedForward(dim, activation_fn=activation_fn) - self.scale_shift_table = nn.Parameter( - torch.randn(6, dim) / dim**0.5 - ) - + self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) + def forward( self, hidden_states: torch.Tensor, @@ -97,15 +94,11 @@ def forward( ) -> torch.Tensor: batch_size = hidden_states.size(0) norm_hidden_states = self.norm1(hidden_states) - + num_ada_params = self.scale_shift_table.shape[0] - ada_values = self.scale_shift_table[None, None] + temb.reshape( - batch_size, temb.size(1), num_ada_params, -1 - ) + ada_values = self.scale_shift_table[None, None] + temb.reshape(batch_size, temb.size(1), num_ada_params, -1) - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( - ada_values.unbind(dim=2) - ) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2) norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa attn_hidden_states = self.attn1( @@ -114,7 +107,7 @@ def forward( image_rotary_emb=image_rotary_emb, ) hidden_states = hidden_states + attn_hidden_states * gate_msa - + attn_hidden_states = self.attn2( hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -165,40 +158,38 @@ def __init__( self.patchify_proj = nn.Linear(in_channels, inner_dim) - self.transformer_blocks = nn.Modulelist([ - LTXTransformerBlock( - dim=inner_dim, - num_attention_heads=num_attention_heads, - attention_head_dim=attention_head_dim, - cross_attention_dim=cross_attention_dim, - qk_norm=qk_norm, - activation_fn=activation_fn, - attention_bias=attention_bias, - attention_out_bias=attention_out_bias, - eps=norm_eps, - elementwise_affine=norm_elementwise_affine, - ) - for _ in range(num_layers) - ]) + self.transformer_blocks = nn.Modulelist( + [ + LTXTransformerBlock( + dim=inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + cross_attention_dim=cross_attention_dim, + qk_norm=qk_norm, + activation_fn=activation_fn, + attention_bias=attention_bias, + attention_out_bias=attention_out_bias, + eps=norm_eps, + elementwise_affine=norm_elementwise_affine, + ) + for _ in range(num_layers) + ] + ) self.norm_out = nn.LayerNorm(inner_dim, eps=1e-6, elementwise_affine=False) self.proj_out = nn.Linear(inner_dim, self.out_channels) - self.scale_shift_table = nn.Parameter( - torch.randn(2, inner_dim) / inner_dim**0.5 - ) + self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5) self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=False) - self.caption_projection = PixArtAlphaTextProjection( - in_features=caption_channels, hidden_size=inner_dim - ) + self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim) self.gradient_checkpointing = 
False - + def _set_gradient_checkpointing(self, module, value=False): if hasattr(module, "gradient_checkpointing"): module.gradient_checkpointing = value - + def forward( self, hidden_states: torch.Tensor, @@ -210,9 +201,7 @@ def forward( ) -> torch.Tensor: # convert encoder_attention_mask to a bias the same way we do for attention_mask if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: - encoder_attention_mask = ( - 1 - encoder_attention_mask.to(hidden_states.dtype) - ) * -10000.0 + encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 encoder_attention_mask = encoder_attention_mask.unsqueeze(1) batch_size, num_channels, num_frames, height, width = hidden_states.shape @@ -223,7 +212,9 @@ def forward( post_patch_width = width // p post_patch_num_frames = num_frames // p_t - hidden_states = hidden_states.reshape(batch_size, -1, post_patch_num_frames, p_t, post_patch_height, p, post_patch_width, p) + hidden_states = hidden_states.reshape( + batch_size, -1, post_patch_num_frames, p_t, post_patch_height, p, post_patch_width, p + ) hidden_states = hidden_states.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3) hidden_states = self.patchify_proj(hidden_states) @@ -270,17 +261,17 @@ def custom_forward(*inputs): image_rotary_emb=image_rotary_emb, encoder_attention_mask=encoder_attention_mask, ) - - scale_shift_values = ( - self.scale_shift_table[None, None] + embedded_timestep[:, :, None] - ) + + scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None] shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] hidden_states = self.norm_out(hidden_states) hidden_states = hidden_states * (1 + scale) + shift hidden_states = self.proj_out(hidden_states) - hidden_states = hidden_states.reshape(batch_size, post_patch_num_frames, post_patch_height, post_patch_width, -1, p_t, p, p) + hidden_states = hidden_states.reshape( + batch_size, post_patch_num_frames, post_patch_height, post_patch_width, -1, p_t, p, p + ) output = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3) if not return_dict: diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 5091ff318f1b..69ae5f746782 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -377,6 +377,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class LTXTransformer3DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class LuminaNextDiT2DModel(metaclass=DummyObject): _backends = ["torch"] From a25504522654cba988fc9fedd84452d22c06706d Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 26 Nov 2024 01:52:43 +0100 Subject: [PATCH 03/51] transformer --- scripts/convert_ltx_to_diffusers.py | 91 +++++++++++++++++++ src/diffusers/__init__.py | 2 +- src/diffusers/models/__init__.py | 2 +- src/diffusers/models/attention_processor.py | 6 +- .../models/transformers/transformer_ltx.py | 78 ++++++++++++++-- 5 files changed, 168 insertions(+), 11 deletions(-) create mode 100644 scripts/convert_ltx_to_diffusers.py diff --git a/scripts/convert_ltx_to_diffusers.py b/scripts/convert_ltx_to_diffusers.py new file mode 100644 index 
000000000000..cf2356f7b5df --- /dev/null +++ b/scripts/convert_ltx_to_diffusers.py @@ -0,0 +1,91 @@ +import argparse +from typing import Any, Dict + +import torch +from safetensors.torch import load_file + +from diffusers import LTXTransformer3DModel + + +TRANSFORMER_KEYS_RENAME_DICT = { + "q_norm": "norm_q", + "k_norm": "norm_k", +} + +TRANSFORMER_SPECIAL_KEYS_REMAP = {} + + +def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]: + state_dict = saved_dict + if "model" in saved_dict.keys(): + state_dict = state_dict["model"] + if "module" in saved_dict.keys(): + state_dict = state_dict["module"] + if "state_dict" in saved_dict.keys(): + state_dict = state_dict["state_dict"] + return state_dict + + +def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]: + state_dict[new_key] = state_dict.pop(old_key) + + +def convert_transformer( + ckpt_path: str, + dtype: torch.dtype, +): + PREFIX_KEY = "" + + original_state_dict = get_state_dict(load_file(ckpt_path)) + transformer = LTXTransformer3DModel().to(dtype=dtype) + + for key in list(original_state_dict.keys()): + new_key = key[len(PREFIX_KEY) :] + for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items(): + new_key = new_key.replace(replace_key, rename_key) + update_state_dict_inplace(original_state_dict, key, new_key) + + for key in list(original_state_dict.keys()): + for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items(): + if special_key not in key: + continue + handler_fn_inplace(key, original_state_dict) + + transformer.load_state_dict(original_state_dict, strict=True) + return transformer + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint" + ) + parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved") + parser.add_argument("--dtype", default="fp32", help="Torch dtype to save the model in.") + return parser.parse_args() + + +DTYPE_MAPPING = { + "fp32": torch.float32, + "fp16": torch.float16, + "bf16": torch.bfloat16, +} + +VARIANT_MAPPING = { + "fp32": None, + "fp16": "fp16", + "bf16": "bf16", +} + + +if __name__ == "__main__": + args = get_args() + + transformer = None + dtype = DTYPE_MAPPING[args.dtype] + + if args.transformer_ckpt_path is not None: + transformer: LTXTransformer3DModel = convert_transformer(args.transformer_ckpt_path, dtype) + + variant = VARIANT_MAPPING[args.dtype] + transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index b509ef7cfcd4..78b56dc75456 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -102,7 +102,7 @@ "I2VGenXLUNet", "Kandinsky3UNet", "LatteTransformer3DModel", - "LTXTransformerModel3D", + "LTXTransformer3DModel", "LuminaNextDiT2DModel", "MochiTransformer3DModel", "ModelMixin", diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 22eb7d8485a1..f2cb434f59f5 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -63,7 +63,7 @@ _import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"] _import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"] _import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"] - _import_structure["transformers.transformer_ltx"] = 
["LTXTransformerModel3D"] + _import_structure["transformers.transformer_ltx"] = ["LTXTransformer3DModel"] _import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"] _import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"] _import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"] diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index ffbf4a0056c6..ddf73a85f253 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -190,12 +190,16 @@ def __init__( self.norm_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) self.norm_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) elif qk_norm == "layer_norm_across_heads": - # Lumina applys qk norm across all heads + # Lumina applies qk norm across all heads self.norm_q = nn.LayerNorm(dim_head * heads, eps=eps) self.norm_k = nn.LayerNorm(dim_head * kv_heads, eps=eps) elif qk_norm == "rms_norm": self.norm_q = RMSNorm(dim_head, eps=eps) self.norm_k = RMSNorm(dim_head, eps=eps) + elif qk_norm == "rms_norm_across_heads": + # LTX applies qk norm across all heads + self.norm_q = RMSNorm(dim_head * heads, eps=eps) + self.norm_k = RMSNorm(dim_head * kv_heads, eps=eps) elif qk_norm == "l2": self.norm_q = LpNorm(p=2, dim=-1, eps=eps) self.norm_k = LpNorm(p=2, dim=-1, eps=eps) diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py index 7b4cab107908..68f22a07456e 100644 --- a/src/diffusers/models/transformers/transformer_ltx.py +++ b/src/diffusers/models/transformers/transformer_ltx.py @@ -17,6 +17,7 @@ import torch import torch.nn as nn +import torch.nn.functional as F from ...configuration_utils import ConfigMixin, register_to_config from ...utils import is_torch_version, logging @@ -32,6 +33,63 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name +class LTXAttentionProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is + used in the LTX model. It applies a normalization layer and rotary embedding on the query and key vector. + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "AttnAddedKVProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." 
+ ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + query = attn.to_q(hidden_states) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.norm_q(query) + key = attn.norm_k(key) + + query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2) + key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2) + value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2) + + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + + @maybe_allow_in_graph class LTXTransformerBlock(nn.Module): r""" @@ -57,27 +115,29 @@ def __init__( super().__init__() self.norm1 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine) - self.attn1 = Attention( query_dim=dim, heads=num_attention_heads, + kv_heads=num_attention_heads, dim_head=attention_head_dim, bias=attention_bias, cross_attention_dim=None, out_bias=attention_out_bias, qk_norm=qk_norm, + processor=LTXAttentionProcessor2_0(), ) self.norm2 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine) - self.attn2 = Attention( query_dim=dim, cross_attention_dim=cross_attention_dim, heads=num_attention_heads, + kv_heads=num_attention_heads, dim_head=attention_head_dim, bias=attention_bias, out_bias=attention_out_bias, qk_norm=qk_norm, + processor=LTXAttentionProcessor2_0(), ) self.ff = FeedForward(dim, activation_fn=activation_fn) @@ -137,8 +197,8 @@ class LTXTransformer3DModel(ModelMixin, ConfigMixin): @register_to_config def __init__( self, - in_channels: int = None, - out_channels: Optional[int] = None, + in_channels: int = 128, + out_channels: int = 128, patch_size: int = 1, patch_size_t: int = 1, num_attention_heads: int = 32, @@ -146,19 +206,21 @@ def __init__( cross_attention_dim: int = 2048, num_layers: int = 28, activation_fn: str = "gelu-approximate", - qk_norm: str = "rms_norm", - norm_elementwise_affine: bool = True, + qk_norm: str = "rms_norm_across_heads", + norm_elementwise_affine: bool = False, norm_eps: float = 1e-6, caption_channels: int = 4096, attention_bias: bool = True, attention_out_bias: bool = True, ) -> None: + super().__init__() + out_channels = out_channels or in_channels inner_dim = num_attention_heads * attention_head_dim self.patchify_proj = nn.Linear(in_channels, inner_dim) - self.transformer_blocks = nn.Modulelist( + self.transformer_blocks = nn.ModuleList( [ LTXTransformerBlock( dim=inner_dim, @@ -177,7 +239,7 @@ def __init__( ) self.norm_out = nn.LayerNorm(inner_dim, eps=1e-6, 
elementwise_affine=False) - self.proj_out = nn.Linear(inner_dim, self.out_channels) + self.proj_out = nn.Linear(inner_dim, out_channels) self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5) self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=False) From 36c9b4018e6f74d467d0451a8185017ce2874e25 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 26 Nov 2024 02:18:15 +0100 Subject: [PATCH 04/51] add transformer tests --- .../test_models_transformer_ltx.py | 80 +++++++++++++++++++ 1 file changed, 80 insertions(+) create mode 100644 tests/models/transformers/test_models_transformer_ltx.py diff --git a/tests/models/transformers/test_models_transformer_ltx.py b/tests/models/transformers/test_models_transformer_ltx.py new file mode 100644 index 000000000000..72a77a23f03c --- /dev/null +++ b/tests/models/transformers/test_models_transformer_ltx.py @@ -0,0 +1,80 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from diffusers import LTXTransformer3DModel +from diffusers.utils.testing_utils import enable_full_determinism, torch_device + +from ..test_modeling_common import ModelTesterMixin + + +enable_full_determinism() + + +class MochiTransformerTests(ModelTesterMixin, unittest.TestCase): + model_class = LTXTransformer3DModel + main_input_name = "hidden_states" + uses_custom_attn_processor = True + + @property + def dummy_input(self): + batch_size = 2 + num_channels = 4 + num_frames = 2 + height = 16 + width = 16 + embedding_dim = 16 + sequence_length = 16 + + hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device) + encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device) + encoder_attention_mask = torch.ones((batch_size, sequence_length)).bool().to(torch_device) + timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device) + + return { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "timestep": timestep, + "encoder_attention_mask": encoder_attention_mask, + } + + @property + def input_shape(self): + return (4, 2, 16, 16) + + @property + def output_shape(self): + return (4, 2, 16, 16) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "in_channels": 4, + "out_channels": 4, + "num_attention_heads": 2, + "attention_head_dim": 8, + "cross_attention_dim": 16, + "num_layers": 1, + "qk_norm": "rms_norm_across_heads", + "caption_channels": 16, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_gradient_checkpointing_is_applied(self): + expected_set = {"LTXTransformer3DModel"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) From c3bd2e417a8fcaa092597204efe28757a389768d Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 26 Nov 2024 14:12:49 +0100 Subject: [PATCH 05/51] 80% vae --- .../autoencoders/autoencoder_kl_cogvideox.py | 165 ++- 
.../models/autoencoders/autoencoder_kl_ltx.py | 958 ++++++++++++++++++ src/diffusers/models/normalization.py | 4 + .../models/transformers/transformer_ltx.py | 5 +- 4 files changed, 1109 insertions(+), 23 deletions(-) create mode 100644 src/diffusers/models/autoencoders/autoencoder_kl_ltx.py diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py index fbcb964392f9..0620ec22379f 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import numpy as np import torch @@ -28,6 +28,7 @@ from ..downsampling import CogVideoXDownsample3D from ..modeling_outputs import AutoencoderKLOutput from ..modeling_utils import ModelMixin +from ..normalization import RMSNorm from ..upsampling import CogVideoXUpsample3D from .vae import DecoderOutput, DiagonalGaussianDistribution @@ -35,6 +36,99 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name +class LayerNormNd(nn.LayerNorm): + def __init__( + self, + normalized_shape: Union[int, List[int], Tuple[int], torch.Size], + eps: float = 1e-5, + elementwise_affine: bool = True, + bias: bool = True, + device=None, + dtype=None, + channel_dim: int = -1, + ) -> None: + super().__init__( + normalized_shape=normalized_shape, + eps=eps, + elementwise_affine=elementwise_affine, + bias=bias, + device=device, + dtype=dtype, + ) + + self.channel_dim = channel_dim + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + if self.channel_dim != -1: + hidden_states = hidden_states.movedim(self.channel_dim, -1) + hidden_states = super().forward(hidden_states) + hidden_states = hidden_states.movedim(-1, self.channel_dim) + else: + hidden_states = super().forward(hidden_states) + + return hidden_states + + def extra_repr(self) -> str: + return f"{super().extra_repr()}, channel_dim={self.channel_dim}" + + +class RMSNormNd(RMSNorm): + def __init__( + self, + dim: int, + eps: float, + elementwise_affine: bool = True, + channel_dim: int = -1, + ) -> None: + super().__init__( + dim=dim, + eps=eps, + elementwise_affine=elementwise_affine, + ) + + self.channel_dim = channel_dim + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + if self.channel_dim != -1: + hidden_states = hidden_states.movedim(self.channel_dim, -1) + hidden_states = super().forward(hidden_states) + hidden_states = hidden_states.movedim(-1, self.channel_dim) + else: + hidden_states = super().forward(hidden_states) + + return hidden_states + + def extra_repr(self): + return f"{super().extra_repr()}, channel_dim={self.channel_dim}" + + +def _get_norm( + norm_type: str, + num_channels: int, + groups: int = 32, + eps: float = 1e-6, + elementwise_affine: bool = False, + bias: bool = True, + spatial_norm_dim: Optional[int] = None, + channel_dim: int = -1, +) -> nn.Module: + if norm_type == "group_norm": + norm = nn.GroupNorm(num_channels=num_channels, num_groups=groups, eps=eps) + elif norm_type == "layer_norm": + norm = LayerNormNd(num_channels, eps=eps, elementwise_affine=elementwise_affine, bias=bias, channel_dim=channel_dim) + elif norm_type == "rms_norm": + norm = RMSNormNd(dim=num_channels, eps=eps, elementwise_affine=elementwise_affine, channel_dim=channel_dim) + elif norm_type == 
"spatial_norm": + norm = CogVideoXSpatialNorm3D( + f_channels=num_channels, + zq_channels=spatial_norm_dim, + groups=groups, + ) + else: + raise ValueError("Invalid `norm_type` specified.") + return norm + + class CogVideoXSafeConv3d(nn.Conv3d): r""" A 3D convolution layer that splits the input tensor into smaller parts to avoid OOM in CogVideoX Model. @@ -83,7 +177,7 @@ def __init__( in_channels: int, out_channels: int, kernel_size: Union[int, Tuple[int, int, int]], - stride: int = 1, + stride: Union[int, Tuple[int, int, int]] = 1, dilation: int = 1, pad_mode: str = "constant", ): @@ -127,8 +221,8 @@ def fake_context_parallel_forward( else: kernel_size = self.time_kernel_size if kernel_size > 1: - cached_inputs = [conv_cache] if conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1) - inputs = torch.cat(cached_inputs + [inputs], dim=2) + pad_left = conv_cache if conv_cache is not None else inputs[:, :, :1].repeat(1, 1, kernel_size - 1, 1, 1) + inputs = torch.cat([pad_left, inputs], dim=2) return inputs def forward(self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None) -> torch.Tensor: @@ -217,6 +311,8 @@ class CogVideoXResnetBlock3D(nn.Module): Activation function to use. conv_shortcut (bool, defaults to `False`): Whether or not to use a convolution shortcut. + norm_type (`str`, defaults to `"group_norm:`): + The type of normalization layer to use. spatial_norm_dim (`int`, *optional*): The dimension to use for spatial norm if it is to be used instead of group norm. pad_mode (str, defaults to `"first"`): @@ -229,8 +325,12 @@ def __init__( out_channels: Optional[int] = None, dropout: float = 0.0, temb_channels: int = 512, + norm_type: str = "group_norm", + final_norm_type: Optional[str] = None, groups: int = 32, eps: float = 1e-6, + elementwise_affine: bool = False, + norm_bias: bool = True, non_linearity: str = "swish", conv_shortcut: bool = False, spatial_norm_dim: Optional[int] = None, @@ -246,26 +346,32 @@ def __init__( self.use_conv_shortcut = conv_shortcut self.spatial_norm_dim = spatial_norm_dim - if spatial_norm_dim is None: - self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps) - self.norm2 = nn.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps) - else: - self.norm1 = CogVideoXSpatialNorm3D( - f_channels=in_channels, - zq_channels=spatial_norm_dim, - groups=groups, - ) - self.norm2 = CogVideoXSpatialNorm3D( - f_channels=out_channels, - zq_channels=spatial_norm_dim, - groups=groups, - ) + if spatial_norm_dim is not None and norm_type != "spatial_norm": + logger.info("`spatial_norm_dim` is specified but the `norm_type` is not \"spatial_norm\". 
The norm type will be overwritten.") + norm_type = "spatial_norm" + + if norm_type == "group_norm": + self.norm1 = _get_norm(norm_type, in_channels, groups, eps, channel_dim=1) + self.norm2 = _get_norm(norm_type, out_channels, groups, eps, channel_dim=1) + elif norm_type == "rms_norm": + self.norm1 = _get_norm(norm_type, out_channels, elementwise_affine=elementwise_affine, channel_dim=1) + self.norm2 = _get_norm(norm_type, out_channels, elementwise_affine=elementwise_affine, channel_dim=1) + elif norm_type == "layer_norm": + # num_channels, eps=eps, elementwise_affine=elementwise_affine, bias=bias, channel_dim=channel_dim + self.norm1 = _get_norm(norm_type, in_channels, eps=eps, elementwise_affine=elementwise_affine, bias=norm_bias, channel_dim=1) + self.norm2 = _get_norm(norm_type, in_channels, eps=eps, elementwise_affine=elementwise_affine, bias=norm_bias, channel_dim=1) + elif norm_type == "spatial_norm": + assert spatial_norm_dim is not None + self.norm1 = _get_norm(norm_type, in_channels, groups, spatial_norm_dim=spatial_norm_dim) + self.norm2 = _get_norm(norm_type, out_channels, groups, spatial_norm_dim=spatial_norm_dim) + elif norm_type is not None: + raise ValueError("Invalid `norm_type` specified.") self.conv1 = CogVideoXCausalConv3d( in_channels=in_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode ) - if temb_channels > 0: + if temb_channels is not None and temb_channels > 0: self.temb_proj = nn.Linear(in_features=temb_channels, out_features=out_channels) self.dropout = nn.Dropout(dropout) @@ -282,6 +388,12 @@ def __init__( self.conv_shortcut = CogVideoXSafeConv3d( in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0 ) + else: + self.conv_shortcut = None + + self.norm3 = None + if final_norm_type is not None: + self.norm3 = _get_norm(final_norm_type, in_channels, eps=1e-6, elementwise_affine=True, bias=True, channel_dim=1) def forward( self, @@ -315,7 +427,10 @@ def forward( hidden_states = self.dropout(hidden_states) hidden_states, new_conv_cache["conv2"] = self.conv2(hidden_states, conv_cache=conv_cache.get("conv2")) - if self.in_channels != self.out_channels: + if self.norm3 is not None: + inputs = self.norm3(inputs) + + if self.conv_shortcut is not None: if self.use_conv_shortcut: inputs, new_conv_cache["conv_shortcut"] = self.conv_shortcut( inputs, conv_cache=conv_cache.get("conv_shortcut") @@ -466,6 +581,8 @@ class CogVideoXMidBlock3D(nn.Module): Activation function to use. resnet_groups (`int`, defaults to `32`): Number of groups to separate the channels into for group normalization. + norm_type (`str`, defaults to `"group_norm:`): + The type of normalization layer to use. spatial_norm_dim (`int`, *optional*): The dimension to use for spatial norm if it is to be used instead of group norm. pad_mode (str, defaults to `"first"`): @@ -480,6 +597,7 @@ def __init__( temb_channels: int, dropout: float = 0.0, num_layers: int = 1, + norm_type: str = "group_norm", resnet_eps: float = 1e-6, resnet_act_fn: str = "swish", resnet_groups: int = 32, @@ -496,6 +614,7 @@ def __init__( out_channels=in_channels, dropout=dropout, temb_channels=temb_channels, + norm_type=norm_type, groups=resnet_groups, eps=resnet_eps, spatial_norm_dim=spatial_norm_dim, @@ -562,6 +681,8 @@ class CogVideoXUpBlock3D(nn.Module): Activation function to use. resnet_groups (`int`, defaults to `32`): Number of groups to separate the channels into for group normalization. + norm_type (`str`, defaults to `"group_norm:`): + The type of normalization layer to use. 
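+ One of `"group_norm"`, `"layer_norm"`, `"rms_norm"` or `"spatial_norm"`, as handled by `_get_norm`.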
spatial_norm_dim (`int`, defaults to `16`): The dimension to use for spatial norm if it is to be used instead of group norm. add_upsample (`bool`, defaults to `True`): @@ -582,6 +703,7 @@ def __init__( resnet_eps: float = 1e-6, resnet_act_fn: str = "swish", resnet_groups: int = 32, + norm_type: str = "group_norm", spatial_norm_dim: int = 16, add_upsample: bool = True, upsample_padding: int = 1, @@ -602,6 +724,7 @@ def __init__( groups=resnet_groups, eps=resnet_eps, non_linearity=resnet_act_fn, + norm_type=norm_type, spatial_norm_dim=spatial_norm_dim, pad_mode=pad_mode, ) @@ -881,6 +1004,7 @@ def __init__( resnet_eps=norm_eps, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, + norm_type="spatial_norm", spatial_norm_dim=in_channels, pad_mode=pad_mode, ) @@ -907,6 +1031,7 @@ def __init__( resnet_eps=norm_eps, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, + norm_type="spatial_norm", spatial_norm_dim=in_channels, add_upsample=not is_final_block, compress_time=compress_time, diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py new file mode 100644 index 000000000000..bf5b99a96047 --- /dev/null +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py @@ -0,0 +1,958 @@ +# Copyright 2024 The Mochi team and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import functools +from typing import Dict, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import logging +from ...utils.accelerate_utils import apply_forward_hook +from ..activations import get_activation +from ..attention_processor import Attention, MochiVaeAttnProcessor2_0 +from ..modeling_outputs import AutoencoderKLOutput +from ..modeling_utils import ModelMixin +from .autoencoder_kl_cogvideox import _get_norm, CogVideoXCausalConv3d, CogVideoXResnetBlock3D, LayerNormNd, CogVideoXMidBlock3D +from .vae import DecoderOutput, DiagonalGaussianDistribution + + +# {'_class_name': 'CausalVideoAutoencoder', 'dims': 3, 'in_channels': 3, 'out_channels': 3, 'latent_channels': 128, +# 'blocks': [ +# ['res_x', 4], ['compress_all', 1], ['res_x_y', 1], +# ['res_x', 3], ['compress_all', 1], ['res_x_y', 1], +# ['res_x', 3], ['compress_all', 1], +# ['res_x', 3], +# ['res_x', 4]], +# 'scaling_factor': 1.0, 'norm_layer': 'pixel_norm', 'patch_size': 4, 'latent_log_var': 'uniform', 'use_quant_conv': False, 'causal_decoder': False} + + +class LTXDownsampler3D(CogVideoXCausalConv3d): + pass + + +class LTXUpsampler3D(nn.Module): + def __init__( + self, + in_channels: int, + stride: Union[int, Tuple[int, int, int]] = 1, + ) -> None: + super().__init__() + + self.stride = stride if isinstance(stride, tuple) else (stride, stride, stride) + + out_channels = in_channels * stride[0] * stride[1] * stride[2] + + self.conv = CogVideoXCausalConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=1, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = hidden_states.shape + + hidden_states, _ = self.conv(hidden_states) + hidden_states = hidden_states.reshape(batch_size, -1, self.stride[0], self.stride[1], self.stride[2], num_frames, height, width) + hidden_states = hidden_states.permute(0, 1, 5, 2, 6, 3, 7, 4).flatten(6, 7).flatten(4, 5).flatten(2, 3) + + hidden_states = hidden_states[:, :, self.stride[0] - 1:] + return hidden_states + + +class LTXDownBlock3D(nn.Module): + r""" + Down block used in the LTX model. + + Args: + in_channels (`int`): + Number of input channels. + out_channels (`int`, *optional*): + Number of output channels. If None, defaults to `in_channels`. + temb_channels (`int`, defaults to `512`): + Number of time embedding channels. + num_layers (`int`, defaults to `1`): + Number of resnet layers. + dropout (`float`, defaults to `0.0`): + Dropout rate. + resnet_eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + resnet_act_fn (`str`, defaults to `"swish"`): + Activation function to use. + resnet_groups (`int`, defaults to `32`): + Number of groups to separate the channels into for group normalization. + add_downsample (`bool`, defaults to `True`): + Whether or not to use a downsampling layer. If not used, output dimension would be same as input dimension. + compress_time (`bool`, defaults to `False`): + Whether or not to downsample across temporal dimension. + pad_mode (str, defaults to `"first"`): + Padding mode. 
+ """ + + _supports_gradient_checkpointing = True + + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + temb_channels: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_norm_type: str = "rms_norm", + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + spatio_temporal_scale: bool = True, + pad_mode: str = "first", + ): + super().__init__() + + out_channels = out_channels or in_channels + + resnets = [] + for _ in range(num_layers): + resnets.append( + CogVideoXResnetBlock3D( + in_channels=in_channels, + out_channels=in_channels, + dropout=dropout, + temb_channels=temb_channels, + norm_type=resnet_norm_type, + groups=resnet_groups, + eps=resnet_eps, + non_linearity=resnet_act_fn, + pad_mode=pad_mode, + ) + ) + self.resnets = nn.ModuleList(resnets) + + self.downsamplers = None + if spatio_temporal_scale: + self.downsamplers = nn.ModuleList( + [ + LTXDownsampler3D(in_channels=in_channels, out_channels=in_channels, kernel_size=3, stride=(2, 2, 2)) + ] + ) + + self.conv_out = None + if in_channels != out_channels: + self.conv_out = CogVideoXResnetBlock3D( + in_channels=in_channels, + out_channels=out_channels, + dropout=dropout, + temb_channels=temb_channels, + norm_type=resnet_norm_type, + final_norm_type="layer_norm", + groups=resnet_groups, + eps=resnet_eps, + non_linearity=resnet_act_fn, + ) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + r"""Forward method of the `LTXDownBlock3D` class.""" + + for i, resnet in enumerate(self.resnets): + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module): + def create_forward(*inputs): + return module(*inputs) + + return create_forward + + hidden_states, _ = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states) + else: + hidden_states, _ = resnet(hidden_states) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states, _ = downsampler(hidden_states) + + if self.conv_out is not None: + hidden_states, _ = self.conv_out(hidden_states) + + return hidden_states + + +class LTXUpBlock3D(nn.Module): + r""" + Up block used in the LTX model. + + Args: + in_channels (`int`): + Number of input channels. + out_channels (`int`, *optional*): + Number of output channels. If None, defaults to `in_channels`. + temb_channels (`int`, defaults to `512`): + Number of time embedding channels. + num_layers (`int`, defaults to `1`): + Number of resnet layers. + dropout (`float`, defaults to `0.0`): + Dropout rate. + resnet_eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + resnet_act_fn (`str`, defaults to `"swish"`): + Activation function to use. + resnet_groups (`int`, defaults to `32`): + Number of groups to separate the channels into for group normalization. + add_downsample (`bool`, defaults to `True`): + Whether or not to use a downsampling layer. If not used, output dimension would be same as input dimension. + compress_time (`bool`, defaults to `False`): + Whether or not to downsample across temporal dimension. + pad_mode (str, defaults to `"first"`): + Padding mode. 
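+ spatio_temporal_scale (`bool`, defaults to `True`):
+ Whether or not to upsample across the spatial and temporal dimensions.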
+ """ + + _supports_gradient_checkpointing = True + + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + temb_channels: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_norm_type: str = "rms_norm", + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + spatio_temporal_scale: bool = True, + pad_mode: str = "first", + ): + super().__init__() + + out_channels = out_channels or in_channels + + self.conv_in = None + if in_channels != out_channels: + self.conv_in = CogVideoXResnetBlock3D( + in_channels=in_channels, + out_channels=out_channels, + dropout=dropout, + temb_channels=temb_channels, + norm_type=resnet_norm_type, + final_norm_type="layer_norm", + groups=resnet_groups, + eps=resnet_eps, + non_linearity=resnet_act_fn, + ) + + self.upsamplers = None + if spatio_temporal_scale: + self.upsamplers = nn.ModuleList( + [ + LTXUpsampler3D(out_channels, stride=(2, 2, 2)) + ] + ) + + resnets = [] + for _ in range(num_layers): + resnets.append( + CogVideoXResnetBlock3D( + in_channels=out_channels, + out_channels=out_channels, + dropout=dropout, + temb_channels=temb_channels, + norm_type=resnet_norm_type, + groups=resnet_groups, + eps=resnet_eps, + non_linearity=resnet_act_fn, + pad_mode=pad_mode, + ) + ) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + r"""Forward method of the `LTXDownBlock3D` class.""" + + if self.conv_in is not None: + hidden_states, _ = self.conv_in(hidden_states) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + for i, resnet in enumerate(self.resnets): + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module): + def create_forward(*inputs): + return module(*inputs) + + return create_forward + + hidden_states, _ = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states) + else: + hidden_states, _ = resnet(hidden_states) + + return hidden_states + + +class LTXEncoder3D(nn.Module): + r""" + The `LTXEncoder3D` layer of a variational autoencoder that encodes input video samples to its latent + representation. + + Args: + TODO(aryan) + """ + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 128, + block_out_channels: Tuple[int, ...] = (128, 256, 512, 512), + spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False), + layers_per_block: Tuple[int, ...] 
= (4, 3, 3, 3, 4), + patch_size: int = 4, + patch_size_t: int = 1, + resnet_norm_type: str = "rms_norm", + resnet_groups: int = 32, + resnet_norm_eps: float = 1e-6, + ): + super().__init__() + + self.patch_size = patch_size + self.patch_size_t = patch_size_t + self.in_channels = in_channels * patch_size**2 + + output_channel = block_out_channels[0] + + self.conv_in = CogVideoXCausalConv3d( + in_channels=self.in_channels, + out_channels=output_channel, + kernel_size=3, + stride=1, + ) + + # down blocks + num_block_out_channels = len(block_out_channels) + self.down_blocks = nn.ModuleList([]) + for i in range(num_block_out_channels): + input_channel = output_channel + output_channel = block_out_channels[i + 1] if i + 1 < num_block_out_channels else block_out_channels[i] + + down_block = LTXDownBlock3D( + in_channels=input_channel, + out_channels=output_channel, + temb_channels=None, + num_layers=layers_per_block[i], + resnet_norm_type=resnet_norm_type, + resnet_eps=resnet_norm_eps, + resnet_groups=resnet_groups, + spatio_temporal_scale=spatio_temporal_scaling[i], + ) + + self.down_blocks.append(down_block) + + # mid block + self.mid_block = CogVideoXMidBlock3D( + in_channels=output_channel, + temb_channels=None, + num_layers=layers_per_block[-1], + norm_type=resnet_norm_type, + resnet_eps=resnet_norm_eps, + resnet_groups=resnet_groups, + ) + + # out + self.norm_out = _get_norm(resnet_norm_type, output_channel, eps=1e-6, elementwise_affine=False, channel_dim=1) + self.conv_act = nn.SiLU() + self.conv_out = CogVideoXCausalConv3d( + in_channels=output_channel, + out_channels=out_channels + 1, + kernel_size=3, + stride=1, + ) + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + r"""The forward method of the `LTXEncoder3D` class.""" + + p = self.patch_size + p_t = self.patch_size_t + + batch_size, num_channels, num_frames, height, width = hidden_states.shape + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p + post_patch_width = width // p + + hidden_states = hidden_states.reshape(batch_size, -1, post_patch_num_frames, p_t, post_patch_height, p, post_patch_width, p) + hidden_states, _ = self.conv_in(hidden_states) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module): + def create_forward(*inputs): + return module(*inputs) + + return create_forward + + for down_block in self.down_blocks: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(down_block), + hidden_states, + ) + else: + for down_block in self.down_blocks: + hidden_states = down_block(hidden_states) + + hidden_states = self.norm_out(hidden_states) + hidden_states = self.conv_act(hidden_states) + hidden_states, _ = self.conv_out(hidden_states) + + last_channel = hidden_states[:, -1:] + last_channel = last_channel.repeat(1, hidden_states.size(1) - 2, 1, 1, 1) + hidden_states = torch.cat([hidden_states, last_channel], dim=1) + + return hidden_states + + +class LTXDecoder3D(nn.Module): + r""" + The `LTXDecoder3D` layer of a variational autoencoder that decodes its latent representation into an output + sample. + + Args: + TODO(aryan) + """ + + def __init__( + self, + in_channels: int = 128, + out_channels: int = 3, + block_out_channels: Tuple[int, ...] = (128, 256, 512, 512), + spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False), + layers_per_block: Tuple[int, ...] 
= (4, 3, 3, 3, 4), + patch_size: int = 4, + patch_size_t: int = 1, + resnet_norm_type: str = "rms_norm", + resnet_groups: int = 32, + resnet_norm_eps: float = 1e-6, + ) -> None: + super().__init__() + + self.patch_size = patch_size + self.patch_size_t = patch_size_t + self.out_channels = out_channels * patch_size**2 + + block_out_channels = tuple(reversed(block_out_channels)) + spatio_temporal_scaling = tuple(reversed(spatio_temporal_scaling)) + layers_per_block = tuple(reversed(layers_per_block)) + + output_channel = block_out_channels[0] + + self.conv_in = CogVideoXCausalConv3d( + in_channels=in_channels, + out_channels=output_channel, + kernel_size=3, + stride=1, + ) + + self.mid_block = CogVideoXMidBlock3D( + in_channels=output_channel, + temb_channels=None, + num_layers=layers_per_block[-1], + norm_type=resnet_norm_type, + resnet_eps=resnet_norm_eps, + resnet_groups=resnet_groups, + ) + + # up blocks + num_block_out_channels = len(block_out_channels) + self.up_blocks = nn.ModuleList([]) + for i in range(num_block_out_channels): + input_channel = output_channel + output_channel = block_out_channels[i + 1] if i + 1 < num_block_out_channels else block_out_channels[i] + + up_block = LTXUpBlock3D( + in_channels=input_channel, + out_channels=output_channel, + temb_channels=None, + num_layers=layers_per_block[i], + resnet_norm_type=resnet_norm_type, + resnet_eps=resnet_norm_eps, + resnet_groups=resnet_groups, + spatio_temporal_scale=spatio_temporal_scaling[i], + ) + + self.up_blocks.append(up_block) + + # out + self.norm_out = _get_norm(resnet_norm_type, output_channel, eps=1e-6, elementwise_affine=False, channel_dim=1) + self.conv_act = nn.SiLU() + self.conv_out = CogVideoXCausalConv3d( + in_channels=output_channel, + out_channels=out_channels, + kernel_size=3, + stride=1, + ) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor + ) -> torch.Tensor: + + hidden_states, _ = self.conv_in(hidden_states) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + def create_custom_forward(module): + def create_forward(*inputs): + return module(*inputs) + + return create_forward + + for up_block in self.up_blocks: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(up_block), + hidden_states, + ) + else: + for up_block in self.up_blocks: + hidden_states = up_block(hidden_states) + + hidden_states = self.norm_out(hidden_states) + hidden_states = self.conv_act(hidden_states) + hidden_states, _ = self.conv_out(hidden_states, causal=self.causal) + + p = self.patch_size + p_t = self.patch_size_t + + batch_size, num_channels, num_frames, height, width = hidden_states.shape + hidden_states = hidden_states.reshape(batch_size, -1, p_t, p, p, num_frames, height, width) + hidden_states = hidden_states.permute(0, 1, 5, 2, 6, 3, 7, 4).flatten(6, 7).flatten(4, 5).flatten(2, 2) + + return hidden_states + + +class AutoencoderKLLTX(ModelMixin, ConfigMixin): + r""" + A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in + [LTX](https://huggingface.co/Lightricks/LTX-Video). + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Args: + TODO(aryan) + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + latent_channels: int = 128, + block_out_channels: Tuple[int, ...] 
= (128, 256, 512, 512), + spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False), + layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4), + patch_size: int = 4, + patch_size_t: int = 1, + resnet_norm_type: str = "rms_norm", + resnet_groups: int = 32, + resnet_norm_eps: float = 1e-6, + scaling_factor: float = 1.0, + ) -> None: + super().__init__() + + self.encoder = LTXEncoder3D( + in_channels=in_channels, + out_channels=latent_channels, + block_out_channels=block_out_channels, + spatio_temporal_scaling=spatio_temporal_scaling, + layers_per_block=layers_per_block, + patch_size=patch_size, + patch_size_t=patch_size_t, + resnet_norm_type=resnet_norm_type, + resnet_groups=resnet_groups, + resnet_norm_eps=resnet_norm_eps, + ) + self.decoder = LTXDecoder3D( + in_channels=latent_channels, + out_channels=out_channels, + block_out_channels=block_out_channels, + spatio_temporal_scaling=spatio_temporal_scaling, + layers_per_block=layers_per_block, + patch_size=patch_size, + patch_size_t=patch_size_t, + resnet_norm_type=resnet_norm_type, + resnet_groups=resnet_groups, + resnet_norm_eps=resnet_norm_eps, + ) + + self.spatial_compression_ratio = patch_size * patch_size + self.temporal_compression_ratio = patch_size_t + + # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension + # to perform decoding of a single video latent at a time. + self.use_slicing = False + + # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent + # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the + # intermediate tiles together, the memory requirement can be lowered. + self.use_tiling = False + + # When decoding temporally long video latents, the memory requirement is very high. By decoding latent frames + # at a fixed frame batch size (based on `self.num_latent_frames_batch_sizes`), the memory requirement can be lowered. + self.use_framewise_encoding = True + self.use_framewise_decoding = True + + # This can be configured based on the amount of GPU memory available. + # `12` for sample frames and `2` for latent frames are sensible defaults for consumer GPUs. + # Setting it to higher values results in higher memory usage. + self.num_sample_frames_batch_size = 8 + self.num_latent_frames_batch_size = 2 + + # The minimal tile height and width for spatial tiling to be used + self.tile_sample_min_height = 256 + self.tile_sample_min_width = 256 + + # The minimal distance between two spatial tiles + self.tile_sample_stride_height = 192 + self.tile_sample_stride_width = 192 + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (LTXEncoder3D, LTXDecoder3D)): + module.gradient_checkpointing = value + + def enable_tiling( + self, + tile_sample_min_height: Optional[int] = None, + tile_sample_min_width: Optional[int] = None, + tile_sample_stride_height: Optional[float] = None, + tile_sample_stride_width: Optional[float] = None, + ) -> None: + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + + Args: + tile_sample_min_height (`int`, *optional*): + The minimum height required for a sample to be separated into tiles across the height dimension. 
+ tile_sample_min_width (`int`, *optional*): + The minimum width required for a sample to be separated into tiles across the width dimension. + tile_sample_stride_height (`int`, *optional*): + The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are + no tiling artifacts produced across the height dimension. + tile_sample_stride_width (`int`, *optional*): + The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling + artifacts produced across the width dimension. + """ + self.use_tiling = True + self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height + self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width + self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height + self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width + + def disable_tiling(self) -> None: + r""" + Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_tiling = False + + def enable_slicing(self) -> None: + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.use_slicing = True + + def disable_slicing(self) -> None: + r""" + Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_slicing = False + + def _encode(self, x: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = x.shape + + if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height): + return self.tiled_encode(x) + + if self.use_framewise_encoding: + enc = [] + for i in range(0, num_frames, self.num_sample_frames_batch_size): + x_intermediate = x[:, :, i : i + self.num_sample_frames_batch_size] + x_intermediate = self.encoder(x_intermediate) + enc.append(x_intermediate) + enc = torch.cat(enc, dim=2) + else: + enc, _ = self.encoder(x) + + return enc + + @apply_forward_hook + def encode( + self, x: torch.Tensor, return_dict: bool = True + ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: + """ + Encode a batch of images into latents. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded videos. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. 
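+
+ When slicing is enabled (see `enable_slicing`), inputs with a batch size greater than one are encoded one
+ sample at a time and the results are concatenated, trading speed for lower peak memory usage.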
+ """ + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self._encode(x) + + posterior = DiagonalGaussianDistribution(h) + + if not return_dict: + return (posterior,) + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + batch_size, num_channels, num_frames, height, width = z.shape + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_stride_width // self.spatial_compression_ratio + + if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height): + return self.tiled_decode(z, return_dict=return_dict) + + if self.use_framewise_decoding: + dec = [] + for i in range(0, num_frames, self.num_latent_frames_batch_size): + z_intermediate = z[:, :, i : i + self.num_latent_frames_batch_size] + z_intermediate = self.decoder(z_intermediate) + dec.append(z_intermediate) + dec = torch.cat(dec, dim=2) + else: + dec = self.decoder(z) + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + @apply_forward_hook + def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + """ + Decode a batch of images. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z).sample + + if not return_dict: + return (decoded,) + + return DecoderOutput(sample=decoded) + + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[3], b.shape[3], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * ( + y / blend_extent + ) + return b + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[4], b.shape[4], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * ( + x / blend_extent + ) + return b + + def tiled_encode(self, x: torch.Tensor) -> torch.Tensor: + r"""Encode a batch of images using a tiled encoder. + + Args: + x (`torch.Tensor`): Input batch of videos. + + Returns: + `torch.Tensor`: + The latent representation of the encoded videos. 
+ """ + batch_size, num_channels, num_frames, height, width = x.shape + latent_height = height // self.spatial_compression_ratio + latent_width = width // self.spatial_compression_ratio + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + + blend_height = tile_latent_min_height - tile_latent_stride_height + blend_width = tile_latent_min_width - tile_latent_stride_width + + # Split x into overlapping tiles and encode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, height, self.tile_sample_stride_height): + row = [] + for j in range(0, width, self.tile_sample_stride_width): + if self.use_framewise_encoding: + time = [] + for k in range(0, num_frames, self.num_sample_frames_batch_size): + tile = x[:, :, k : k + self.num_sample_frames_batch_size, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width] + tile = self.encoder(tile) + time.append(tile) + else: + time = self.encoder( + x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width] + ) + + row.append(time) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width]) + result_rows.append(torch.cat(result_row, dim=4)) + + enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width] + return enc + + def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + r""" + Decode a batch of images using a tiled decoder. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + + batch_size, num_channels, num_frames, height, width = z.shape + sample_height = height * self.spatial_compression_ratio + sample_width = width * self.spatial_compression_ratio + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + + blend_height = self.tile_sample_min_height - self.tile_sample_stride_height + blend_width = self.tile_sample_min_width - self.tile_sample_stride_width + + # Split z into overlapping tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. 
+ rows = [] + for i in range(0, height, tile_latent_stride_height): + row = [] + for j in range(0, width, tile_latent_stride_width): + if self.use_framewise_decoding: + time = [] + for k in range(0, num_frames, self.num_latent_frames_batch_size): + tile = z[ + :, + :, + k : k + self.num_latent_frames_batch_size, + i : i + tile_latent_min_height, + j : j + tile_latent_min_width, + ] + tile = self.decoder(tile) + time.append(tile) + time = torch.cat(time, dim=2) + else: + time = self.decoder(z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width]) + + if self.drop_last_temporal_frames and time.size(2) >= self.temporal_compression_ratio: + time = time[:, :, self.temporal_compression_ratio - 1 :] + + row.append(time) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width]) + result_rows.append(torch.cat(result_row, dim=4)) + + dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width] + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + ) -> Union[torch.Tensor, torch.Tensor]: + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z) + if not return_dict: + return (dec,) + return dec diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py index 817b3fff2ea6..cb83aee365e0 100644 --- a/src/diffusers/models/normalization.py +++ b/src/diffusers/models/normalization.py @@ -516,6 +516,7 @@ def __init__(self, dim, eps: float, elementwise_affine: bool = True): super().__init__() self.eps = eps + self.elementwise_affine = elementwise_affine if isinstance(dim, numbers.Integral): dim = (dim,) @@ -541,6 +542,9 @@ def forward(self, hidden_states): hidden_states = hidden_states.to(input_dtype) return hidden_states + + def extra_repr(self) -> str: + return f"features={self.dim}, eps={self.eps}, elementwise_affine={self.elementwise_affine}" class GlobalResponseNorm(nn.Module): diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py index 68f22a07456e..d0aefda47076 100644 --- a/src/diffusers/models/transformers/transformer_ltx.py +++ b/src/diffusers/models/transformers/transformer_ltx.py @@ -42,7 +42,7 @@ class LTXAttentionProcessor2_0: def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError( - "AttnAddedKVProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + "LTXAttentionProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." 
) def __call__( @@ -51,7 +51,6 @@ def __call__( hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, - temb: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, ) -> torch.Tensor: batch_size, sequence_length, _ = ( @@ -258,7 +257,7 @@ def forward( encoder_hidden_states: torch.Tensor, timestep: torch.LongTensor, encoder_attention_mask: torch.Tensor, - image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]], + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, return_dict: bool = True, ) -> torch.Tensor: # convert encoder_attention_mask to a bias the same way we do for attention_mask From 43f79070f6389dbc8eba40fc2019907cea3246f8 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 26 Nov 2024 14:13:33 +0100 Subject: [PATCH 06/51] make style --- .../autoencoders/autoencoder_kl_cogvideox.py | 34 +++++--- .../models/autoencoders/autoencoder_kl_ltx.py | 80 +++++++++---------- src/diffusers/models/normalization.py | 2 +- 3 files changed, 62 insertions(+), 54 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py index 0620ec22379f..15318d0a0311 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py @@ -57,7 +57,7 @@ def __init__( ) self.channel_dim = channel_dim - + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if self.channel_dim != -1: hidden_states = hidden_states.movedim(self.channel_dim, -1) @@ -65,7 +65,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = hidden_states.movedim(-1, self.channel_dim) else: hidden_states = super().forward(hidden_states) - + return hidden_states def extra_repr(self) -> str: @@ -87,7 +87,7 @@ def __init__( ) self.channel_dim = channel_dim - + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if self.channel_dim != -1: hidden_states = hidden_states.movedim(self.channel_dim, -1) @@ -95,7 +95,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = hidden_states.movedim(-1, self.channel_dim) else: hidden_states = super().forward(hidden_states) - + return hidden_states def extra_repr(self): @@ -115,7 +115,9 @@ def _get_norm( if norm_type == "group_norm": norm = nn.GroupNorm(num_channels=num_channels, num_groups=groups, eps=eps) elif norm_type == "layer_norm": - norm = LayerNormNd(num_channels, eps=eps, elementwise_affine=elementwise_affine, bias=bias, channel_dim=channel_dim) + norm = LayerNormNd( + num_channels, eps=eps, elementwise_affine=elementwise_affine, bias=bias, channel_dim=channel_dim + ) elif norm_type == "rms_norm": norm = RMSNormNd(dim=num_channels, eps=eps, elementwise_affine=elementwise_affine, channel_dim=channel_dim) elif norm_type == "spatial_norm": @@ -221,7 +223,9 @@ def fake_context_parallel_forward( else: kernel_size = self.time_kernel_size if kernel_size > 1: - pad_left = conv_cache if conv_cache is not None else inputs[:, :, :1].repeat(1, 1, kernel_size - 1, 1, 1) + pad_left = ( + conv_cache if conv_cache is not None else inputs[:, :, :1].repeat(1, 1, kernel_size - 1, 1, 1) + ) inputs = torch.cat([pad_left, inputs], dim=2) return inputs @@ -347,7 +351,9 @@ def __init__( self.spatial_norm_dim = spatial_norm_dim if spatial_norm_dim is not None and norm_type != "spatial_norm": - logger.info("`spatial_norm_dim` is specified but the `norm_type` is not 
\"spatial_norm\". The norm type will be overwritten.") + logger.info( + '`spatial_norm_dim` is specified but the `norm_type` is not "spatial_norm". The norm type will be overwritten.' + ) norm_type = "spatial_norm" if norm_type == "group_norm": @@ -358,8 +364,12 @@ def __init__( self.norm2 = _get_norm(norm_type, out_channels, elementwise_affine=elementwise_affine, channel_dim=1) elif norm_type == "layer_norm": # num_channels, eps=eps, elementwise_affine=elementwise_affine, bias=bias, channel_dim=channel_dim - self.norm1 = _get_norm(norm_type, in_channels, eps=eps, elementwise_affine=elementwise_affine, bias=norm_bias, channel_dim=1) - self.norm2 = _get_norm(norm_type, in_channels, eps=eps, elementwise_affine=elementwise_affine, bias=norm_bias, channel_dim=1) + self.norm1 = _get_norm( + norm_type, in_channels, eps=eps, elementwise_affine=elementwise_affine, bias=norm_bias, channel_dim=1 + ) + self.norm2 = _get_norm( + norm_type, in_channels, eps=eps, elementwise_affine=elementwise_affine, bias=norm_bias, channel_dim=1 + ) elif norm_type == "spatial_norm": assert spatial_norm_dim is not None self.norm1 = _get_norm(norm_type, in_channels, groups, spatial_norm_dim=spatial_norm_dim) @@ -390,10 +400,12 @@ def __init__( ) else: self.conv_shortcut = None - + self.norm3 = None if final_norm_type is not None: - self.norm3 = _get_norm(final_norm_type, in_channels, eps=1e-6, elementwise_affine=True, bias=True, channel_dim=1) + self.norm3 = _get_norm( + final_norm_type, in_channels, eps=1e-6, elementwise_affine=True, bias=True, channel_dim=1 + ) def forward( self, diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py index bf5b99a96047..355f045fe266 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py @@ -13,20 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import functools -from typing import Dict, Optional, Tuple, Union +from typing import Optional, Tuple, Union import torch import torch.nn as nn from ...configuration_utils import ConfigMixin, register_to_config -from ...utils import logging from ...utils.accelerate_utils import apply_forward_hook -from ..activations import get_activation -from ..attention_processor import Attention, MochiVaeAttnProcessor2_0 from ..modeling_outputs import AutoencoderKLOutput from ..modeling_utils import ModelMixin -from .autoencoder_kl_cogvideox import _get_norm, CogVideoXCausalConv3d, CogVideoXResnetBlock3D, LayerNormNd, CogVideoXMidBlock3D +from .autoencoder_kl_cogvideox import CogVideoXCausalConv3d, CogVideoXMidBlock3D, CogVideoXResnetBlock3D, _get_norm from .vae import DecoderOutput, DiagonalGaussianDistribution @@ -55,22 +51,24 @@ def __init__( self.stride = stride if isinstance(stride, tuple) else (stride, stride, stride) out_channels = in_channels * stride[0] * stride[1] * stride[2] - + self.conv = CogVideoXCausalConv3d( in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, ) - + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: batch_size, num_channels, num_frames, height, width = hidden_states.shape - + hidden_states, _ = self.conv(hidden_states) - hidden_states = hidden_states.reshape(batch_size, -1, self.stride[0], self.stride[1], self.stride[2], num_frames, height, width) + hidden_states = hidden_states.reshape( + batch_size, -1, self.stride[0], self.stride[1], self.stride[2], num_frames, height, width + ) hidden_states = hidden_states.permute(0, 1, 5, 2, 6, 3, 7, 4).flatten(6, 7).flatten(4, 5).flatten(2, 3) - hidden_states = hidden_states[:, :, self.stride[0] - 1:] + hidden_states = hidden_states[:, :, self.stride[0] - 1 :] return hidden_states @@ -139,15 +137,13 @@ def __init__( ) ) self.resnets = nn.ModuleList(resnets) - + self.downsamplers = None if spatio_temporal_scale: self.downsamplers = nn.ModuleList( - [ - LTXDownsampler3D(in_channels=in_channels, out_channels=in_channels, kernel_size=3, stride=(2, 2, 2)) - ] + [LTXDownsampler3D(in_channels=in_channels, out_channels=in_channels, kernel_size=3, stride=(2, 2, 2))] ) - + self.conv_out = None if in_channels != out_channels: self.conv_out = CogVideoXResnetBlock3D( @@ -186,7 +182,7 @@ def create_forward(*inputs): if self.downsamplers is not None: for downsampler in self.downsamplers: hidden_states, _ = downsampler(hidden_states) - + if self.conv_out is not None: hidden_states, _ = self.conv_out(hidden_states) @@ -258,12 +254,8 @@ def __init__( self.upsamplers = None if spatio_temporal_scale: - self.upsamplers = nn.ModuleList( - [ - LTXUpsampler3D(out_channels, stride=(2, 2, 2)) - ] - ) - + self.upsamplers = nn.ModuleList([LTXUpsampler3D(out_channels, stride=(2, 2, 2))]) + resnets = [] for _ in range(num_layers): resnets.append( @@ -291,7 +283,7 @@ def forward( if self.conv_in is not None: hidden_states, _ = self.conv_in(hidden_states) - + if self.upsamplers is not None: for upsampler in self.upsamplers: hidden_states = upsampler(hidden_states) @@ -339,7 +331,7 @@ def __init__( self.patch_size = patch_size self.patch_size_t = patch_size_t self.in_channels = in_channels * patch_size**2 - + output_channel = block_out_channels[0] self.conv_in = CogVideoXCausalConv3d( @@ -355,7 +347,7 @@ def __init__( for i in range(num_block_out_channels): input_channel = output_channel output_channel = block_out_channels[i + 1] if i + 1 < num_block_out_channels else block_out_channels[i] - + down_block = LTXDownBlock3D( 
in_channels=input_channel, out_channels=output_channel, @@ -366,9 +358,9 @@ def __init__( resnet_groups=resnet_groups, spatio_temporal_scale=spatio_temporal_scaling[i], ) - + self.down_blocks.append(down_block) - + # mid block self.mid_block = CogVideoXMidBlock3D( in_channels=output_channel, @@ -402,7 +394,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: post_patch_height = height // p post_patch_width = width // p - hidden_states = hidden_states.reshape(batch_size, -1, post_patch_num_frames, p_t, post_patch_height, p, post_patch_width, p) + hidden_states = hidden_states.reshape( + batch_size, -1, post_patch_num_frames, p_t, post_patch_height, p, post_patch_width, p + ) hidden_states, _ = self.conv_in(hidden_states) if torch.is_grad_enabled() and self.gradient_checkpointing: @@ -435,8 +429,7 @@ def create_forward(*inputs): class LTXDecoder3D(nn.Module): r""" - The `LTXDecoder3D` layer of a variational autoencoder that decodes its latent representation into an output - sample. + The `LTXDecoder3D` layer of a variational autoencoder that decodes its latent representation into an output sample. Args: TODO(aryan) @@ -489,7 +482,7 @@ def __init__( for i in range(num_block_out_channels): input_channel = output_channel output_channel = block_out_channels[i + 1] if i + 1 < num_block_out_channels else block_out_channels[i] - + up_block = LTXUpBlock3D( in_channels=input_channel, out_channels=output_channel, @@ -515,14 +508,11 @@ def __init__( self.gradient_checkpointing = False - def forward( - self, - hidden_states: torch.Tensor - ) -> torch.Tensor: - + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states, _ = self.conv_in(hidden_states) if torch.is_grad_enabled() and self.gradient_checkpointing: + def create_custom_forward(module): def create_forward(*inputs): return module(*inputs) @@ -544,7 +534,7 @@ def create_forward(*inputs): p = self.patch_size p_t = self.patch_size_t - + batch_size, num_channels, num_frames, height, width = hidden_states.shape hidden_states = hidden_states.reshape(batch_size, -1, p_t, p, p, num_frames, height, width) hidden_states = hidden_states.permute(0, 1, 5, 2, 6, 3, 7, 4).flatten(6, 7).flatten(4, 5).flatten(2, 2) @@ -639,11 +629,11 @@ def __init__( # The minimal distance between two spatial tiles self.tile_sample_stride_height = 192 self.tile_sample_stride_width = 192 - + def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, (LTXEncoder3D, LTXDecoder3D)): module.gradient_checkpointing = value - + def enable_tiling( self, tile_sample_min_height: Optional[int] = None, @@ -673,7 +663,7 @@ def enable_tiling( self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width - + def disable_tiling(self) -> None: r""" Disable tiled VAE decoding. 
If `enable_tiling` was previously enabled, this method will go back to computing @@ -837,7 +827,13 @@ def tiled_encode(self, x: torch.Tensor) -> torch.Tensor: if self.use_framewise_encoding: time = [] for k in range(0, num_frames, self.num_sample_frames_batch_size): - tile = x[:, :, k : k + self.num_sample_frames_batch_size, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width] + tile = x[ + :, + :, + k : k + self.num_sample_frames_batch_size, + i : i + self.tile_sample_min_height, + j : j + self.tile_sample_min_width, + ] tile = self.encoder(tile) time.append(tile) else: diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py index cb83aee365e0..4e96086ada4e 100644 --- a/src/diffusers/models/normalization.py +++ b/src/diffusers/models/normalization.py @@ -542,7 +542,7 @@ def forward(self, hidden_states): hidden_states = hidden_states.to(input_dtype) return hidden_states - + def extra_repr(self) -> str: return f"features={self.dim}, eps={self.eps}, elementwise_affine={self.elementwise_affine}" From 02a2b6b9b4b1c3f85ce6a39998da41a5f0568ec3 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 26 Nov 2024 14:15:03 +0100 Subject: [PATCH 07/51] make fix-copies --- src/diffusers/__init__.py | 2 ++ src/diffusers/models/__init__.py | 2 ++ src/diffusers/models/autoencoders/__init__.py | 1 + src/diffusers/utils/dummy_pt_objects.py | 15 +++++++++++++++ 4 files changed, 20 insertions(+) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 78b56dc75456..b4ae36d9c5fa 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -83,6 +83,7 @@ "AutoencoderKL", "AutoencoderKLAllegro", "AutoencoderKLCogVideoX", + "AutoencoderKLLTX", "AutoencoderKLMochi", "AutoencoderKLTemporalDecoder", "AutoencoderOobleck", @@ -575,6 +576,7 @@ AutoencoderKL, AutoencoderKLAllegro, AutoencoderKLCogVideoX, + AutoencoderKLLTX, AutoencoderKLMochi, AutoencoderKLTemporalDecoder, AutoencoderOobleck, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index f2cb434f59f5..5e5279cdcad6 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -30,6 +30,7 @@ _import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"] _import_structure["autoencoders.autoencoder_kl_allegro"] = ["AutoencoderKLAllegro"] _import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"] + _import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTX"] _import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"] _import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"] _import_structure["autoencoders.autoencoder_oobleck"] = ["AutoencoderOobleck"] @@ -92,6 +93,7 @@ AutoencoderKL, AutoencoderKLAllegro, AutoencoderKLCogVideoX, + AutoencoderKLLTX, AutoencoderKLMochi, AutoencoderKLTemporalDecoder, AutoencoderOobleck, diff --git a/src/diffusers/models/autoencoders/__init__.py b/src/diffusers/models/autoencoders/__init__.py index ba45d6671252..795cc21c7963 100644 --- a/src/diffusers/models/autoencoders/__init__.py +++ b/src/diffusers/models/autoencoders/__init__.py @@ -2,6 +2,7 @@ from .autoencoder_kl import AutoencoderKL from .autoencoder_kl_allegro import AutoencoderKLAllegro from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX +from .autoencoder_kl_ltx import AutoencoderKLLTX from .autoencoder_kl_mochi import AutoencoderKLMochi from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder from 
.autoencoder_oobleck import AutoencoderOobleck diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 69ae5f746782..65dd0bbce2b7 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -92,6 +92,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class AutoencoderKLLTX(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class AutoencoderKLMochi(metaclass=DummyObject): _backends = ["torch"] From c90164160e45f58e0084c48c93a284e5a39245ea Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 26 Nov 2024 14:37:01 +0100 Subject: [PATCH 08/51] fix --- src/diffusers/models/autoencoders/autoencoder_kl_ltx.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py index 355f045fe266..775d58def0da 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py @@ -470,7 +470,7 @@ def __init__( self.mid_block = CogVideoXMidBlock3D( in_channels=output_channel, temb_channels=None, - num_layers=layers_per_block[-1], + num_layers=layers_per_block[0], norm_type=resnet_norm_type, resnet_eps=resnet_norm_eps, resnet_groups=resnet_groups, @@ -481,13 +481,13 @@ def __init__( self.up_blocks = nn.ModuleList([]) for i in range(num_block_out_channels): input_channel = output_channel - output_channel = block_out_channels[i + 1] if i + 1 < num_block_out_channels else block_out_channels[i] + output_channel = block_out_channels[i] up_block = LTXUpBlock3D( in_channels=input_channel, out_channels=output_channel, temb_channels=None, - num_layers=layers_per_block[i], + num_layers=layers_per_block[i + 1], resnet_norm_type=resnet_norm_type, resnet_eps=resnet_norm_eps, resnet_groups=resnet_groups, @@ -501,7 +501,7 @@ def __init__( self.conv_act = nn.SiLU() self.conv_out = CogVideoXCausalConv3d( in_channels=output_channel, - out_channels=out_channels, + out_channels=self.out_channels, kernel_size=3, stride=1, ) From 868cd47c8fd58e2efbaf2fbc9463e7d1bf813bf2 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 27 Nov 2024 02:45:06 +0100 Subject: [PATCH 09/51] undo cogvideox changes --- .../autoencoders/autoencoder_kl_cogvideox.py | 173 ++---------------- 1 file changed, 18 insertions(+), 155 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py index 15318d0a0311..fbcb964392f9 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, Optional, Tuple, Union import numpy as np import torch @@ -28,7 +28,6 @@ from ..downsampling import CogVideoXDownsample3D from ..modeling_outputs import AutoencoderKLOutput from ..modeling_utils import ModelMixin -from ..normalization import RMSNorm from ..upsampling import CogVideoXUpsample3D from .vae import DecoderOutput, DiagonalGaussianDistribution @@ -36,101 +35,6 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -class LayerNormNd(nn.LayerNorm): - def __init__( - self, - normalized_shape: Union[int, List[int], Tuple[int], torch.Size], - eps: float = 1e-5, - elementwise_affine: bool = True, - bias: bool = True, - device=None, - dtype=None, - channel_dim: int = -1, - ) -> None: - super().__init__( - normalized_shape=normalized_shape, - eps=eps, - elementwise_affine=elementwise_affine, - bias=bias, - device=device, - dtype=dtype, - ) - - self.channel_dim = channel_dim - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - if self.channel_dim != -1: - hidden_states = hidden_states.movedim(self.channel_dim, -1) - hidden_states = super().forward(hidden_states) - hidden_states = hidden_states.movedim(-1, self.channel_dim) - else: - hidden_states = super().forward(hidden_states) - - return hidden_states - - def extra_repr(self) -> str: - return f"{super().extra_repr()}, channel_dim={self.channel_dim}" - - -class RMSNormNd(RMSNorm): - def __init__( - self, - dim: int, - eps: float, - elementwise_affine: bool = True, - channel_dim: int = -1, - ) -> None: - super().__init__( - dim=dim, - eps=eps, - elementwise_affine=elementwise_affine, - ) - - self.channel_dim = channel_dim - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - if self.channel_dim != -1: - hidden_states = hidden_states.movedim(self.channel_dim, -1) - hidden_states = super().forward(hidden_states) - hidden_states = hidden_states.movedim(-1, self.channel_dim) - else: - hidden_states = super().forward(hidden_states) - - return hidden_states - - def extra_repr(self): - return f"{super().extra_repr()}, channel_dim={self.channel_dim}" - - -def _get_norm( - norm_type: str, - num_channels: int, - groups: int = 32, - eps: float = 1e-6, - elementwise_affine: bool = False, - bias: bool = True, - spatial_norm_dim: Optional[int] = None, - channel_dim: int = -1, -) -> nn.Module: - if norm_type == "group_norm": - norm = nn.GroupNorm(num_channels=num_channels, num_groups=groups, eps=eps) - elif norm_type == "layer_norm": - norm = LayerNormNd( - num_channels, eps=eps, elementwise_affine=elementwise_affine, bias=bias, channel_dim=channel_dim - ) - elif norm_type == "rms_norm": - norm = RMSNormNd(dim=num_channels, eps=eps, elementwise_affine=elementwise_affine, channel_dim=channel_dim) - elif norm_type == "spatial_norm": - norm = CogVideoXSpatialNorm3D( - f_channels=num_channels, - zq_channels=spatial_norm_dim, - groups=groups, - ) - else: - raise ValueError("Invalid `norm_type` specified.") - return norm - - class CogVideoXSafeConv3d(nn.Conv3d): r""" A 3D convolution layer that splits the input tensor into smaller parts to avoid OOM in CogVideoX Model. 
@@ -179,7 +83,7 @@ def __init__( in_channels: int, out_channels: int, kernel_size: Union[int, Tuple[int, int, int]], - stride: Union[int, Tuple[int, int, int]] = 1, + stride: int = 1, dilation: int = 1, pad_mode: str = "constant", ): @@ -223,10 +127,8 @@ def fake_context_parallel_forward( else: kernel_size = self.time_kernel_size if kernel_size > 1: - pad_left = ( - conv_cache if conv_cache is not None else inputs[:, :, :1].repeat(1, 1, kernel_size - 1, 1, 1) - ) - inputs = torch.cat([pad_left, inputs], dim=2) + cached_inputs = [conv_cache] if conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1) + inputs = torch.cat(cached_inputs + [inputs], dim=2) return inputs def forward(self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None) -> torch.Tensor: @@ -315,8 +217,6 @@ class CogVideoXResnetBlock3D(nn.Module): Activation function to use. conv_shortcut (bool, defaults to `False`): Whether or not to use a convolution shortcut. - norm_type (`str`, defaults to `"group_norm:`): - The type of normalization layer to use. spatial_norm_dim (`int`, *optional*): The dimension to use for spatial norm if it is to be used instead of group norm. pad_mode (str, defaults to `"first"`): @@ -329,12 +229,8 @@ def __init__( out_channels: Optional[int] = None, dropout: float = 0.0, temb_channels: int = 512, - norm_type: str = "group_norm", - final_norm_type: Optional[str] = None, groups: int = 32, eps: float = 1e-6, - elementwise_affine: bool = False, - norm_bias: bool = True, non_linearity: str = "swish", conv_shortcut: bool = False, spatial_norm_dim: Optional[int] = None, @@ -350,38 +246,26 @@ def __init__( self.use_conv_shortcut = conv_shortcut self.spatial_norm_dim = spatial_norm_dim - if spatial_norm_dim is not None and norm_type != "spatial_norm": - logger.info( - '`spatial_norm_dim` is specified but the `norm_type` is not "spatial_norm". The norm type will be overwritten.' 
- ) - norm_type = "spatial_norm" - - if norm_type == "group_norm": - self.norm1 = _get_norm(norm_type, in_channels, groups, eps, channel_dim=1) - self.norm2 = _get_norm(norm_type, out_channels, groups, eps, channel_dim=1) - elif norm_type == "rms_norm": - self.norm1 = _get_norm(norm_type, out_channels, elementwise_affine=elementwise_affine, channel_dim=1) - self.norm2 = _get_norm(norm_type, out_channels, elementwise_affine=elementwise_affine, channel_dim=1) - elif norm_type == "layer_norm": - # num_channels, eps=eps, elementwise_affine=elementwise_affine, bias=bias, channel_dim=channel_dim - self.norm1 = _get_norm( - norm_type, in_channels, eps=eps, elementwise_affine=elementwise_affine, bias=norm_bias, channel_dim=1 + if spatial_norm_dim is None: + self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps) + self.norm2 = nn.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps) + else: + self.norm1 = CogVideoXSpatialNorm3D( + f_channels=in_channels, + zq_channels=spatial_norm_dim, + groups=groups, ) - self.norm2 = _get_norm( - norm_type, in_channels, eps=eps, elementwise_affine=elementwise_affine, bias=norm_bias, channel_dim=1 + self.norm2 = CogVideoXSpatialNorm3D( + f_channels=out_channels, + zq_channels=spatial_norm_dim, + groups=groups, ) - elif norm_type == "spatial_norm": - assert spatial_norm_dim is not None - self.norm1 = _get_norm(norm_type, in_channels, groups, spatial_norm_dim=spatial_norm_dim) - self.norm2 = _get_norm(norm_type, out_channels, groups, spatial_norm_dim=spatial_norm_dim) - elif norm_type is not None: - raise ValueError("Invalid `norm_type` specified.") self.conv1 = CogVideoXCausalConv3d( in_channels=in_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode ) - if temb_channels is not None and temb_channels > 0: + if temb_channels > 0: self.temb_proj = nn.Linear(in_features=temb_channels, out_features=out_channels) self.dropout = nn.Dropout(dropout) @@ -398,14 +282,6 @@ def __init__( self.conv_shortcut = CogVideoXSafeConv3d( in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0 ) - else: - self.conv_shortcut = None - - self.norm3 = None - if final_norm_type is not None: - self.norm3 = _get_norm( - final_norm_type, in_channels, eps=1e-6, elementwise_affine=True, bias=True, channel_dim=1 - ) def forward( self, @@ -439,10 +315,7 @@ def forward( hidden_states = self.dropout(hidden_states) hidden_states, new_conv_cache["conv2"] = self.conv2(hidden_states, conv_cache=conv_cache.get("conv2")) - if self.norm3 is not None: - inputs = self.norm3(inputs) - - if self.conv_shortcut is not None: + if self.in_channels != self.out_channels: if self.use_conv_shortcut: inputs, new_conv_cache["conv_shortcut"] = self.conv_shortcut( inputs, conv_cache=conv_cache.get("conv_shortcut") @@ -593,8 +466,6 @@ class CogVideoXMidBlock3D(nn.Module): Activation function to use. resnet_groups (`int`, defaults to `32`): Number of groups to separate the channels into for group normalization. - norm_type (`str`, defaults to `"group_norm:`): - The type of normalization layer to use. spatial_norm_dim (`int`, *optional*): The dimension to use for spatial norm if it is to be used instead of group norm. 
pad_mode (str, defaults to `"first"`): @@ -609,7 +480,6 @@ def __init__( temb_channels: int, dropout: float = 0.0, num_layers: int = 1, - norm_type: str = "group_norm", resnet_eps: float = 1e-6, resnet_act_fn: str = "swish", resnet_groups: int = 32, @@ -626,7 +496,6 @@ def __init__( out_channels=in_channels, dropout=dropout, temb_channels=temb_channels, - norm_type=norm_type, groups=resnet_groups, eps=resnet_eps, spatial_norm_dim=spatial_norm_dim, @@ -693,8 +562,6 @@ class CogVideoXUpBlock3D(nn.Module): Activation function to use. resnet_groups (`int`, defaults to `32`): Number of groups to separate the channels into for group normalization. - norm_type (`str`, defaults to `"group_norm:`): - The type of normalization layer to use. spatial_norm_dim (`int`, defaults to `16`): The dimension to use for spatial norm if it is to be used instead of group norm. add_upsample (`bool`, defaults to `True`): @@ -715,7 +582,6 @@ def __init__( resnet_eps: float = 1e-6, resnet_act_fn: str = "swish", resnet_groups: int = 32, - norm_type: str = "group_norm", spatial_norm_dim: int = 16, add_upsample: bool = True, upsample_padding: int = 1, @@ -736,7 +602,6 @@ def __init__( groups=resnet_groups, eps=resnet_eps, non_linearity=resnet_act_fn, - norm_type=norm_type, spatial_norm_dim=spatial_norm_dim, pad_mode=pad_mode, ) @@ -1016,7 +881,6 @@ def __init__( resnet_eps=norm_eps, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, - norm_type="spatial_norm", spatial_norm_dim=in_channels, pad_mode=pad_mode, ) @@ -1043,7 +907,6 @@ def __init__( resnet_eps=norm_eps, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, - norm_type="spatial_norm", spatial_norm_dim=in_channels, add_upsample=not is_final_block, compress_time=compress_time, From db13a83868428df59fc5bc93c07f25f4c9116515 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 27 Nov 2024 04:12:08 +0100 Subject: [PATCH 10/51] update --- scripts/convert_ltx_to_diffusers.py | 104 ++++- .../models/autoencoders/autoencoder_kl_ltx.py | 406 +++++++++++------- src/diffusers/models/normalization.py | 68 ++- src/diffusers/pipelines/ltx/pipeline_ltx.py | 0 4 files changed, 426 insertions(+), 152 deletions(-) create mode 100644 src/diffusers/pipelines/ltx/pipeline_ltx.py diff --git a/scripts/convert_ltx_to_diffusers.py b/scripts/convert_ltx_to_diffusers.py index cf2356f7b5df..1acd16b98233 100644 --- a/scripts/convert_ltx_to_diffusers.py +++ b/scripts/convert_ltx_to_diffusers.py @@ -3,10 +3,17 @@ import torch from safetensors.torch import load_file +from transformers import T5EncoderModel, T5Tokenizer -from diffusers import LTXTransformer3DModel +from diffusers import AutoencoderKLLTX, FlowMatchEulerDiscreteScheduler, LTXTransformer3DModel +def remove_keys_inplace(key: str, state_dict: Dict[str, Any]): + state_dict.pop(key) + + +TOKENIZER_MAX_LENGTH = 128 + TRANSFORMER_KEYS_RENAME_DICT = { "q_norm": "norm_q", "k_norm": "norm_k", @@ -14,6 +21,43 @@ TRANSFORMER_SPECIAL_KEYS_REMAP = {} +VAE_KEYS_RENAME_DICT = { + # decoder + "up_blocks.0": "mid_block", + "up_blocks.1": "up_blocks.0", + "up_blocks.2": "up_blocks.1.upsamplers.0", + "up_blocks.3": "up_blocks.1", + "up_blocks.4": "up_blocks.2.conv_in", + "up_blocks.5": "up_blocks.2.upsamplers.0", + "up_blocks.6": "up_blocks.2", + "up_blocks.7": "up_blocks.3.conv_in", + "up_blocks.8": "up_blocks.3.upsamplers.0", + "up_blocks.9": "up_blocks.3", + # encoder + "down_blocks.0": "down_blocks.0", + "down_blocks.1": "down_blocks.0.downsamplers.0", + "down_blocks.2": "down_blocks.0.conv_out", + "down_blocks.3": "down_blocks.1", + 
"down_blocks.4": "down_blocks.1.downsamplers.0", + "down_blocks.5": "down_blocks.1.conv_out", + "down_blocks.6": "down_blocks.2", + "down_blocks.7": "down_blocks.2.downsamplers.0", + "down_blocks.8": "down_blocks.3", + "down_blocks.9": "mid_block", + # common + "conv_shortcut": "conv_shortcut.conv", + "res_blocks": "resnets", + "norm3.norm": "norm3", + "per_channel_statistics.mean-of-means": "latent_means", + "per_channel_statistics.std-of-means": "latent_stds", +} + +VAE_SPECIAL_KEYS_REMAP = { + "per_channel_statistics.channel": remove_keys_inplace, + "per_channel_statistics.mean-of-means": remove_keys_inplace, + "per_channel_statistics.mean-of-stds": remove_keys_inplace, +} + def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]: state_dict = saved_dict @@ -55,11 +99,36 @@ def convert_transformer( return transformer +def convert_vae(ckpt_path: str, dtype: torch.dtype): + original_state_dict = get_state_dict(load_file(ckpt_path)) + vae = AutoencoderKLLTX().to(dtype=dtype) + + for key in list(original_state_dict.keys()): + new_key = key[:] + for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items(): + new_key = new_key.replace(replace_key, rename_key) + update_state_dict_inplace(original_state_dict, key, new_key) + + for key in list(original_state_dict.keys()): + for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items(): + if special_key not in key: + continue + handler_fn_inplace(key, original_state_dict) + + vae.load_state_dict(original_state_dict, strict=True) + return vae + + def get_args(): parser = argparse.ArgumentParser() parser.add_argument( "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint" ) + parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint") + parser.add_argument( + "--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory" + ) + parser.add_argument("--save_pipeline", action="store_true") parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved") parser.add_argument("--dtype", default="fp32", help="Torch dtype to save the model in.") return parser.parse_args() @@ -83,9 +152,36 @@ def get_args(): transformer = None dtype = DTYPE_MAPPING[args.dtype] + variant = VARIANT_MAPPING[args.dtype] + + if args.save_pipeline: + assert args.transformer_ckpt_path is not None and args.vae_ckpt_path is not None if args.transformer_ckpt_path is not None: transformer: LTXTransformer3DModel = convert_transformer(args.transformer_ckpt_path, dtype) - - variant = VARIANT_MAPPING[args.dtype] - transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") + if not args.save_piipeline: + transformer.save_pretrained( + args.output_path, safe_serialization=True, max_shard_size="5GB", variant=variant + ) + + if args.vae_ckpt_path is not None: + vae: AutoencoderKLLTX = convert_vae(args.vae_ckpt_path, dtype) + if not args.save_pipeline: + vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", variant=variant) + + if args.save_pipeline: + text_encoder_id = "google/t5-v1_1-xxl" + tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH) + text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir) + + if args.typecast_text_encoder: + text_encoder = text_encoder.to(dtype=dtype) + + # Apparently, the conversion does not work anymore without this :shrug: + for 
param in text_encoder.parameters(): + param.data = param.data.contiguous() + + scheduler = FlowMatchEulerDiscreteScheduler( + shift=0.1, + use_dynamic_shifting=True, + ) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py index 775d58def0da..9ca76afe4911 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py @@ -20,27 +20,141 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...utils.accelerate_utils import apply_forward_hook +from ..activations import get_activation from ..modeling_outputs import AutoencoderKLOutput from ..modeling_utils import ModelMixin -from .autoencoder_kl_cogvideox import CogVideoXCausalConv3d, CogVideoXMidBlock3D, CogVideoXResnetBlock3D, _get_norm +from ..normalization import LayerNormNd, RMSNormNd from .vae import DecoderOutput, DiagonalGaussianDistribution -# {'_class_name': 'CausalVideoAutoencoder', 'dims': 3, 'in_channels': 3, 'out_channels': 3, 'latent_channels': 128, -# 'blocks': [ -# ['res_x', 4], ['compress_all', 1], ['res_x_y', 1], -# ['res_x', 3], ['compress_all', 1], ['res_x_y', 1], -# ['res_x', 3], ['compress_all', 1], -# ['res_x', 3], -# ['res_x', 4]], -# 'scaling_factor': 1.0, 'norm_layer': 'pixel_norm', 'patch_size': 4, 'latent_log_var': 'uniform', 'use_quant_conv': False, 'causal_decoder': False} +# Adapted from diffusers.models.autoencoders.autoencoder_kl_cogvideox.CogVideoXCausalConv3d +class LTXCausalConv3d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int, int]] = 3, + stride: Union[int, Tuple[int, int, int]] = 1, + dilation: Union[int, Tuple[int, int, int]] = 1, + groups: int = 1, + padding_mode: str = "zeros", + is_causal: bool = True, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.is_causal = is_causal + self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size, kernel_size) + + dilation = dilation if isinstance(dilation, tuple) else (dilation, 1, 1) + stride = stride if isinstance(stride, tuple) else (stride, stride, stride) + height_pad = self.kernel_size[1] // 2 + width_pad = self.kernel_size[2] // 2 + padding = (0, height_pad, width_pad) + + self.conv = nn.Conv3d( + in_channels, + out_channels, + self.kernel_size, + stride=stride, + dilation=dilation, + padding=padding, + padding_mode=padding_mode, + groups=groups, + ) -class LTXDownsampler3D(CogVideoXCausalConv3d): - pass + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + print(hidden_states.shape) + time_kernel_size = self.kernel_size[0] + if self.is_causal: + pad_left = hidden_states[:, :, :1, :, :].repeat((1, 1, time_kernel_size - 1, 1, 1)) + hidden_states = torch.concatenate([pad_left, hidden_states], dim=2) + else: + pad_left = hidden_states[:, :, :1, :, :].repeat((1, 1, (time_kernel_size - 1) // 2, 1, 1)) + pad_right = hidden_states[:, :, -1:, :, :].repeat((1, 1, (time_kernel_size - 1) // 2, 1, 1)) + hidden_states = torch.concatenate([pad_left, hidden_states, pad_right], dim=2) + + hidden_states = self.conv(hidden_states) + return hidden_states + + +# Adapted from diffusers.models.autoencoders.autoencoder_kl_cogvideox.CogVideoXResnetBlock3d +class LTXResnetBlock3d(nn.Module): + r""" + A 3D ResNet block used in the LTX model. + + Args: + in_channels (`int`): + Number of input channels. 
+ out_channels (`int`, *optional*): + Number of output channels. If None, defaults to `in_channels`. + dropout (`float`, defaults to `0.0`): + Dropout rate. + eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + elementwise_affine (`bool`, defaults to `False`): + Whether to enable elementwise affinity in the normalization layers. + non_linearity (`str`, defaults to `"swish"`): + Activation function to use. + conv_shortcut (bool, defaults to `False`): + Whether or not to use a convolution shortcut. + """ -class LTXUpsampler3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + dropout: float = 0.0, + eps: float = 1e-6, + elementwise_affine: bool = False, + non_linearity: str = "swish", + ): + super().__init__() + + out_channels = out_channels or in_channels + + self.nonlinearity = get_activation(non_linearity) + + self.norm1 = RMSNormNd(dim=in_channels, eps=eps, elementwise_affine=elementwise_affine, channel_dim=1) + self.conv1 = LTXCausalConv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=3) + + self.norm2 = RMSNormNd(dim=out_channels, eps=eps, elementwise_affine=elementwise_affine, channel_dim=1) + self.dropout = nn.Dropout(dropout) + self.conv2 = LTXCausalConv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=3) + + self.norm3 = None + self.conv_shortcut = None + if in_channels != out_channels: + self.norm3 = LayerNormNd(in_channels, eps=eps, elementwise_affine=True, bias=True, channel_dim=1) + self.conv_shortcut = LTXCausalConv3d( + in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1 + ) + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + hidden_states = inputs + + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.conv1(hidden_states) + + hidden_states = self.norm2(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.norm3 is not None: + inputs = self.norm3(inputs) + + if self.conv_shortcut is not None: + inputs = self.conv_shortcut(inputs) + + hidden_states = hidden_states + inputs + return hidden_states + + +class LTXUpsampler3d(nn.Module): def __init__( self, in_channels: int, @@ -52,7 +166,7 @@ def __init__( out_channels = in_channels * stride[0] * stride[1] * stride[2] - self.conv = CogVideoXCausalConv3d( + self.conv = LTXCausalConv3d( in_channels=in_channels, out_channels=out_channels, kernel_size=3, @@ -67,8 +181,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: batch_size, -1, self.stride[0], self.stride[1], self.stride[2], num_frames, height, width ) hidden_states = hidden_states.permute(0, 1, 5, 2, 6, 3, 7, 4).flatten(6, 7).flatten(4, 5).flatten(2, 3) - hidden_states = hidden_states[:, :, self.stride[0] - 1 :] + return hidden_states @@ -91,13 +205,11 @@ class LTXDownBlock3D(nn.Module): Epsilon value for normalization layers. resnet_act_fn (`str`, defaults to `"swish"`): Activation function to use. - resnet_groups (`int`, defaults to `32`): - Number of groups to separate the channels into for group normalization. add_downsample (`bool`, defaults to `True`): Whether or not to use a downsampling layer. If not used, output dimension would be same as input dimension. compress_time (`bool`, defaults to `False`): Whether or not to downsample across temporal dimension. 
- pad_mode (str, defaults to `"first"`): + padding_mode (str, defaults to `"zeros"`): Padding mode. """ @@ -107,15 +219,11 @@ def __init__( self, in_channels: int, out_channels: Optional[int] = None, - temb_channels: Optional[int] = None, dropout: float = 0.0, num_layers: int = 1, - resnet_norm_type: str = "rms_norm", resnet_eps: float = 1e-6, resnet_act_fn: str = "swish", - resnet_groups: int = 32, spatio_temporal_scale: bool = True, - pad_mode: str = "first", ): super().__init__() @@ -124,16 +232,12 @@ def __init__( resnets = [] for _ in range(num_layers): resnets.append( - CogVideoXResnetBlock3D( + LTXResnetBlock3d( in_channels=in_channels, out_channels=in_channels, dropout=dropout, - temb_channels=temb_channels, - norm_type=resnet_norm_type, - groups=resnet_groups, eps=resnet_eps, non_linearity=resnet_act_fn, - pad_mode=pad_mode, ) ) self.resnets = nn.ModuleList(resnets) @@ -141,29 +245,22 @@ def __init__( self.downsamplers = None if spatio_temporal_scale: self.downsamplers = nn.ModuleList( - [LTXDownsampler3D(in_channels=in_channels, out_channels=in_channels, kernel_size=3, stride=(2, 2, 2))] + [LTXCausalConv3d(in_channels=in_channels, out_channels=in_channels, kernel_size=3, stride=(2, 2, 2))] ) self.conv_out = None if in_channels != out_channels: - self.conv_out = CogVideoXResnetBlock3D( + self.conv_out = LTXResnetBlock3d( in_channels=in_channels, out_channels=out_channels, dropout=dropout, - temb_channels=temb_channels, - norm_type=resnet_norm_type, - final_norm_type="layer_norm", - groups=resnet_groups, eps=resnet_eps, non_linearity=resnet_act_fn, ) self.gradient_checkpointing = False - def forward( - self, - hidden_states: torch.Tensor, - ) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: r"""Forward method of the `LTXDownBlock3D` class.""" for i, resnet in enumerate(self.resnets): @@ -175,21 +272,85 @@ def create_forward(*inputs): return create_forward - hidden_states, _ = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states) + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states) else: - hidden_states, _ = resnet(hidden_states) + hidden_states = resnet(hidden_states) if self.downsamplers is not None: for downsampler in self.downsamplers: - hidden_states, _ = downsampler(hidden_states) + hidden_states = downsampler(hidden_states) if self.conv_out is not None: - hidden_states, _ = self.conv_out(hidden_states) + hidden_states = self.conv_out(hidden_states) return hidden_states -class LTXUpBlock3D(nn.Module): +# Adapted from diffusers.models.autoencoders.autoencoder_kl_cogvideox.CogVideoMidBlock3d +class LTXMidBlock3d(nn.Module): + r""" + A middle block used in the LTX model. + + Args: + in_channels (`int`): + Number of input channels. + dropout (`float`, defaults to `0.0`): + Dropout rate. + num_layers (`int`, defaults to `1`): + Number of resnet layers. + resnet_eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + resnet_act_fn (`str`, defaults to `"swish"`): + Activation function to use. 
+ """ + + _supports_gradient_checkpointing = True + + def __init__( + self, + in_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + ): + super().__init__() + + resnets = [] + for _ in range(num_layers): + resnets.append( + LTXResnetBlock3d( + in_channels=in_channels, + out_channels=in_channels, + dropout=dropout, + eps=resnet_eps, + non_linearity=resnet_act_fn, + ) + ) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + r"""Forward method of the `LTXMidBlock3D` class.""" + + for i, resnet in enumerate(self.resnets): + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module): + def create_forward(*inputs): + return module(*inputs) + + return create_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states) + else: + hidden_states = resnet(hidden_states) + + return hidden_states + + +class LTXUpBlock3d(nn.Module): r""" Up block used in the LTX model. @@ -198,8 +359,6 @@ class LTXUpBlock3D(nn.Module): Number of input channels. out_channels (`int`, *optional*): Number of output channels. If None, defaults to `in_channels`. - temb_channels (`int`, defaults to `512`): - Number of time embedding channels. num_layers (`int`, defaults to `1`): Number of resnet layers. dropout (`float`, defaults to `0.0`): @@ -208,14 +367,8 @@ class LTXUpBlock3D(nn.Module): Epsilon value for normalization layers. resnet_act_fn (`str`, defaults to `"swish"`): Activation function to use. - resnet_groups (`int`, defaults to `32`): - Number of groups to separate the channels into for group normalization. - add_downsample (`bool`, defaults to `True`): - Whether or not to use a downsampling layer. If not used, output dimension would be same as input dimension. - compress_time (`bool`, defaults to `False`): - Whether or not to downsample across temporal dimension. - pad_mode (str, defaults to `"first"`): - Padding mode. + spatio_temporal_scale (`bool`, defaults to `True`): + Whether or not to use a upsampling layer. If not used, output dimension would be same as input dimension. 
""" _supports_gradient_checkpointing = True @@ -224,15 +377,11 @@ def __init__( self, in_channels: int, out_channels: Optional[int] = None, - temb_channels: Optional[int] = None, - dropout: float = 0.0, num_layers: int = 1, - resnet_norm_type: str = "rms_norm", + dropout: float = 0.0, resnet_eps: float = 1e-6, resnet_act_fn: str = "swish", - resnet_groups: int = 32, spatio_temporal_scale: bool = True, - pad_mode: str = "first", ): super().__init__() @@ -240,35 +389,27 @@ def __init__( self.conv_in = None if in_channels != out_channels: - self.conv_in = CogVideoXResnetBlock3D( + self.conv_in = LTXResnetBlock3d( in_channels=in_channels, out_channels=out_channels, dropout=dropout, - temb_channels=temb_channels, - norm_type=resnet_norm_type, - final_norm_type="layer_norm", - groups=resnet_groups, eps=resnet_eps, non_linearity=resnet_act_fn, ) self.upsamplers = None if spatio_temporal_scale: - self.upsamplers = nn.ModuleList([LTXUpsampler3D(out_channels, stride=(2, 2, 2))]) + self.upsamplers = nn.ModuleList([LTXUpsampler3d(out_channels, stride=(2, 2, 2))]) resnets = [] for _ in range(num_layers): resnets.append( - CogVideoXResnetBlock3D( + LTXResnetBlock3d( in_channels=out_channels, out_channels=out_channels, dropout=dropout, - temb_channels=temb_channels, - norm_type=resnet_norm_type, - groups=resnet_groups, eps=resnet_eps, non_linearity=resnet_act_fn, - pad_mode=pad_mode, ) ) self.resnets = nn.ModuleList(resnets) @@ -279,10 +420,9 @@ def forward( self, hidden_states: torch.Tensor, ) -> torch.Tensor: - r"""Forward method of the `LTXDownBlock3D` class.""" - + print("in up block", hidden_states.shape) if self.conv_in is not None: - hidden_states, _ = self.conv_in(hidden_states) + hidden_states = self.conv_in(hidden_states) if self.upsamplers is not None: for upsampler in self.upsamplers: @@ -297,14 +437,14 @@ def create_forward(*inputs): return create_forward - hidden_states, _ = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states) + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states) else: - hidden_states, _ = resnet(hidden_states) + hidden_states = resnet(hidden_states) return hidden_states -class LTXEncoder3D(nn.Module): +class LTXEncoder3d(nn.Module): r""" The `LTXEncoder3D` layer of a variational autoencoder that encodes input video samples to its latent representation. @@ -322,8 +462,6 @@ def __init__( layers_per_block: Tuple[int, ...] 
= (4, 3, 3, 3, 4), patch_size: int = 4, patch_size_t: int = 1, - resnet_norm_type: str = "rms_norm", - resnet_groups: int = 32, resnet_norm_eps: float = 1e-6, ): super().__init__() @@ -334,11 +472,8 @@ def __init__( output_channel = block_out_channels[0] - self.conv_in = CogVideoXCausalConv3d( - in_channels=self.in_channels, - out_channels=output_channel, - kernel_size=3, - stride=1, + self.conv_in = LTXCausalConv3d( + in_channels=self.in_channels, out_channels=output_channel, kernel_size=3, stride=1 ) # down blocks @@ -351,34 +486,23 @@ def __init__( down_block = LTXDownBlock3D( in_channels=input_channel, out_channels=output_channel, - temb_channels=None, num_layers=layers_per_block[i], - resnet_norm_type=resnet_norm_type, resnet_eps=resnet_norm_eps, - resnet_groups=resnet_groups, spatio_temporal_scale=spatio_temporal_scaling[i], ) self.down_blocks.append(down_block) # mid block - self.mid_block = CogVideoXMidBlock3D( - in_channels=output_channel, - temb_channels=None, - num_layers=layers_per_block[-1], - norm_type=resnet_norm_type, - resnet_eps=resnet_norm_eps, - resnet_groups=resnet_groups, + self.mid_block = LTXMidBlock3d( + in_channels=output_channel, num_layers=layers_per_block[-1], resnet_eps=resnet_norm_eps ) # out - self.norm_out = _get_norm(resnet_norm_type, output_channel, eps=1e-6, elementwise_affine=False, channel_dim=1) + self.norm_out = RMSNormNd(dim=out_channels, eps=1e-6, elementwise_affine=False, channel_dim=1) self.conv_act = nn.SiLU() - self.conv_out = CogVideoXCausalConv3d( - in_channels=output_channel, - out_channels=out_channels + 1, - kernel_size=3, - stride=1, + self.conv_out = LTXCausalConv3d( + in_channels=output_channel, out_channels=out_channels + 1, kernel_size=3, stride=1 ) self.gradient_checkpointing = False @@ -397,7 +521,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = hidden_states.reshape( batch_size, -1, post_patch_num_frames, p_t, post_patch_height, p, post_patch_width, p ) - hidden_states, _ = self.conv_in(hidden_states) + hidden_states = hidden_states.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4) + hidden_states = self.conv_in(hidden_states) if torch.is_grad_enabled() and self.gradient_checkpointing: @@ -408,17 +533,18 @@ def create_forward(*inputs): return create_forward for down_block in self.down_blocks: - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(down_block), - hidden_states, - ) + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), hidden_states) + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), hidden_states) else: for down_block in self.down_blocks: hidden_states = down_block(hidden_states) + hidden_states = self.mid_block(hidden_states) + hidden_states = self.norm_out(hidden_states) hidden_states = self.conv_act(hidden_states) - hidden_states, _ = self.conv_out(hidden_states) + hidden_states = self.conv_out(hidden_states) last_channel = hidden_states[:, -1:] last_channel = last_channel.repeat(1, hidden_states.size(1) - 2, 1, 1, 1) @@ -427,9 +553,9 @@ def create_forward(*inputs): return hidden_states -class LTXDecoder3D(nn.Module): +class LTXDecoder3d(nn.Module): r""" - The `LTXDecoder3D` layer of a variational autoencoder that decodes its latent representation into an output sample. + The `LTXDecoder3d` layer of a variational autoencoder that decodes its latent representation into an output sample. Args: TODO(aryan) @@ -444,8 +570,6 @@ def __init__( layers_per_block: Tuple[int, ...] 
= (4, 3, 3, 3, 4), patch_size: int = 4, patch_size_t: int = 1, - resnet_norm_type: str = "rms_norm", - resnet_groups: int = 32, resnet_norm_eps: float = 1e-6, ) -> None: super().__init__() @@ -457,23 +581,14 @@ def __init__( block_out_channels = tuple(reversed(block_out_channels)) spatio_temporal_scaling = tuple(reversed(spatio_temporal_scaling)) layers_per_block = tuple(reversed(layers_per_block)) - output_channel = block_out_channels[0] - self.conv_in = CogVideoXCausalConv3d( - in_channels=in_channels, - out_channels=output_channel, - kernel_size=3, - stride=1, - ) + self.conv_in = LTXCausalConv3d(in_channels=in_channels, out_channels=output_channel, kernel_size=3, stride=1) - self.mid_block = CogVideoXMidBlock3D( + self.mid_block = LTXMidBlock3d( in_channels=output_channel, - temb_channels=None, num_layers=layers_per_block[0], - norm_type=resnet_norm_type, resnet_eps=resnet_norm_eps, - resnet_groups=resnet_groups, ) # up blocks @@ -483,23 +598,20 @@ def __init__( input_channel = output_channel output_channel = block_out_channels[i] - up_block = LTXUpBlock3D( + up_block = LTXUpBlock3d( in_channels=input_channel, out_channels=output_channel, - temb_channels=None, num_layers=layers_per_block[i + 1], - resnet_norm_type=resnet_norm_type, resnet_eps=resnet_norm_eps, - resnet_groups=resnet_groups, spatio_temporal_scale=spatio_temporal_scaling[i], ) self.up_blocks.append(up_block) # out - self.norm_out = _get_norm(resnet_norm_type, output_channel, eps=1e-6, elementwise_affine=False, channel_dim=1) + self.norm_out = RMSNormNd(dim=out_channels, eps=1e-6, elementwise_affine=False, channel_dim=1) self.conv_act = nn.SiLU() - self.conv_out = CogVideoXCausalConv3d( + self.conv_out = LTXCausalConv3d( in_channels=output_channel, out_channels=self.out_channels, kernel_size=3, @@ -509,7 +621,7 @@ def __init__( self.gradient_checkpointing = False def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - hidden_states, _ = self.conv_in(hidden_states) + hidden_states = self.conv_in(hidden_states) if torch.is_grad_enabled() and self.gradient_checkpointing: @@ -519,25 +631,26 @@ def create_forward(*inputs): return create_forward + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), hidden_states) + for up_block in self.up_blocks: - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(up_block), - hidden_states, - ) + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), hidden_states) else: + hidden_states = self.mid_block(hidden_states) + for up_block in self.up_blocks: hidden_states = up_block(hidden_states) hidden_states = self.norm_out(hidden_states) hidden_states = self.conv_act(hidden_states) - hidden_states, _ = self.conv_out(hidden_states, causal=self.causal) + hidden_states = self.conv_out(hidden_states) p = self.patch_size p_t = self.patch_size_t batch_size, num_channels, num_frames, height, width = hidden_states.shape hidden_states = hidden_states.reshape(batch_size, -1, p_t, p, p, num_frames, height, width) - hidden_states = hidden_states.permute(0, 1, 5, 2, 6, 3, 7, 4).flatten(6, 7).flatten(4, 5).flatten(2, 2) + hidden_states = hidden_states.permute(0, 1, 5, 2, 6, 3, 7, 4).flatten(6, 7).flatten(4, 5).flatten(2, 3) return hidden_states @@ -567,14 +680,12 @@ def __init__( layers_per_block: Tuple[int, ...] 
= (4, 3, 3, 3, 4), patch_size: int = 4, patch_size_t: int = 1, - resnet_norm_type: str = "rms_norm", - resnet_groups: int = 32, resnet_norm_eps: float = 1e-6, scaling_factor: float = 1.0, ) -> None: super().__init__() - self.encoder = LTXEncoder3D( + self.encoder = LTXEncoder3d( in_channels=in_channels, out_channels=latent_channels, block_out_channels=block_out_channels, @@ -582,11 +693,9 @@ def __init__( layers_per_block=layers_per_block, patch_size=patch_size, patch_size_t=patch_size_t, - resnet_norm_type=resnet_norm_type, - resnet_groups=resnet_groups, resnet_norm_eps=resnet_norm_eps, ) - self.decoder = LTXDecoder3D( + self.decoder = LTXDecoder3d( in_channels=latent_channels, out_channels=out_channels, block_out_channels=block_out_channels, @@ -594,11 +703,14 @@ def __init__( layers_per_block=layers_per_block, patch_size=patch_size, patch_size_t=patch_size_t, - resnet_norm_type=resnet_norm_type, - resnet_groups=resnet_groups, resnet_norm_eps=resnet_norm_eps, ) + latent_means = torch.zeros((latent_channels,), requires_grad=False) + latent_stds = torch.zeros((latent_channels,), requires_grad=False) + self.register_buffer("latent_means", latent_means, persistent=True) + self.register_buffer("latent_stds", latent_stds, persistent=True) + self.spatial_compression_ratio = patch_size * patch_size self.temporal_compression_ratio = patch_size_t @@ -617,21 +729,21 @@ def __init__( self.use_framewise_decoding = True # This can be configured based on the amount of GPU memory available. - # `12` for sample frames and `2` for latent frames are sensible defaults for consumer GPUs. + # `16` for sample frames and `2` for latent frames are sensible defaults for consumer GPUs. # Setting it to higher values results in higher memory usage. - self.num_sample_frames_batch_size = 8 + self.num_sample_frames_batch_size = 16 self.num_latent_frames_batch_size = 2 # The minimal tile height and width for spatial tiling to be used - self.tile_sample_min_height = 256 - self.tile_sample_min_width = 256 + self.tile_sample_min_height = 512 + self.tile_sample_min_width = 512 # The minimal distance between two spatial tiles - self.tile_sample_stride_height = 192 - self.tile_sample_stride_width = 192 + self.tile_sample_stride_height = 448 + self.tile_sample_stride_width = 448 def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, (LTXEncoder3D, LTXDecoder3D)): + if isinstance(module, (LTXEncoder3d, LTXDecoder3d)): module.gradient_checkpointing = value def enable_tiling( @@ -699,7 +811,7 @@ def _encode(self, x: torch.Tensor) -> torch.Tensor: enc.append(x_intermediate) enc = torch.cat(enc, dim=2) else: - enc, _ = self.encoder(x) + enc = self.encoder(x) return enc @@ -724,7 +836,7 @@ def encode( h = torch.cat(encoded_slices) else: h = self._encode(x) - + print("h:", h.shape) posterior = DiagonalGaussianDistribution(h) if not return_dict: diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py index 4e96086ada4e..500a19c51588 100644 --- a/src/diffusers/models/normalization.py +++ b/src/diffusers/models/normalization.py @@ -14,7 +14,7 @@ # limitations under the License. 
import numbers -from typing import Dict, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -570,3 +570,69 @@ def __init__(self, p: int = 2, dim: int = -1, eps: float = 1e-12): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return F.normalize(hidden_states, p=self.p, dim=self.dim, eps=self.eps) + + +class LayerNormNd(nn.LayerNorm): + def __init__( + self, + normalized_shape: Union[int, List[int], Tuple[int], torch.Size], + eps: float = 1e-5, + elementwise_affine: bool = True, + bias: bool = True, + device=None, + dtype=None, + channel_dim: int = -1, + ) -> None: + super().__init__( + normalized_shape=normalized_shape, + eps=eps, + elementwise_affine=elementwise_affine, + bias=bias, + device=device, + dtype=dtype, + ) + + self.channel_dim = channel_dim + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + if self.channel_dim != -1: + hidden_states = hidden_states.movedim(self.channel_dim, -1) + hidden_states = super().forward(hidden_states) + hidden_states = hidden_states.movedim(-1, self.channel_dim) + else: + hidden_states = super().forward(hidden_states) + + return hidden_states + + def extra_repr(self) -> str: + return f"{super().extra_repr()}, channel_dim={self.channel_dim}" + + +class RMSNormNd(RMSNorm): + def __init__( + self, + dim: int, + eps: float, + elementwise_affine: bool = True, + channel_dim: int = -1, + ) -> None: + super().__init__( + dim=dim, + eps=eps, + elementwise_affine=elementwise_affine, + ) + + self.channel_dim = channel_dim + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + if self.channel_dim != -1: + hidden_states = hidden_states.movedim(self.channel_dim, -1) + hidden_states = super().forward(hidden_states) + hidden_states = hidden_states.movedim(-1, self.channel_dim) + else: + hidden_states = super().forward(hidden_states) + + return hidden_states + + def extra_repr(self): + return f"{super().extra_repr()}, channel_dim={self.channel_dim}" diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx.py b/src/diffusers/pipelines/ltx/pipeline_ltx.py new file mode 100644 index 000000000000..e69de29bb2d1 From 11d2d91f01d637335c9276b3ef37a2061976d118 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 27 Nov 2024 04:13:14 +0100 Subject: [PATCH 11/51] update --- src/diffusers/models/autoencoders/autoencoder_kl_ltx.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py index 9ca76afe4911..7f64b6b05da8 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py @@ -176,7 +176,7 @@ def __init__( def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: batch_size, num_channels, num_frames, height, width = hidden_states.shape - hidden_states, _ = self.conv(hidden_states) + hidden_states = self.conv(hidden_states) hidden_states = hidden_states.reshape( batch_size, -1, self.stride[0], self.stride[1], self.stride[2], num_frames, height, width ) From d320105bf197caef43842b5358a8d938c5f25d2b Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 27 Nov 2024 09:15:04 +0100 Subject: [PATCH 12/51] match vae --- .../models/autoencoders/autoencoder_kl_ltx.py | 61 ++++++++++++------- 1 file changed, 40 insertions(+), 21 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py index 7f64b6b05da8..d8633b0e0750 
100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py @@ -59,13 +59,12 @@ def __init__( self.kernel_size, stride=stride, dilation=dilation, + groups=groups, padding=padding, padding_mode=padding_mode, - groups=groups, ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - print(hidden_states.shape) time_kernel_size = self.kernel_size[0] if self.is_causal: @@ -110,6 +109,7 @@ def __init__( eps: float = 1e-6, elementwise_affine: bool = False, non_linearity: str = "swish", + is_causal: bool = True, ): super().__init__() @@ -117,19 +117,19 @@ def __init__( self.nonlinearity = get_activation(non_linearity) - self.norm1 = RMSNormNd(dim=in_channels, eps=eps, elementwise_affine=elementwise_affine, channel_dim=1) - self.conv1 = LTXCausalConv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=3) + self.norm1 = RMSNormNd(dim=in_channels, eps=1e-8, elementwise_affine=elementwise_affine, channel_dim=1) + self.conv1 = LTXCausalConv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, is_causal=is_causal) - self.norm2 = RMSNormNd(dim=out_channels, eps=eps, elementwise_affine=elementwise_affine, channel_dim=1) + self.norm2 = RMSNormNd(dim=out_channels, eps=1e-8, elementwise_affine=elementwise_affine, channel_dim=1) self.dropout = nn.Dropout(dropout) - self.conv2 = LTXCausalConv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=3) + self.conv2 = LTXCausalConv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, is_causal=is_causal) self.norm3 = None self.conv_shortcut = None if in_channels != out_channels: self.norm3 = LayerNormNd(in_channels, eps=eps, elementwise_affine=True, bias=True, channel_dim=1) self.conv_shortcut = LTXCausalConv3d( - in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1 + in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, is_causal=is_causal ) def forward(self, inputs: torch.Tensor) -> torch.Tensor: @@ -159,6 +159,7 @@ def __init__( self, in_channels: int, stride: Union[int, Tuple[int, int, int]] = 1, + is_causal: bool = True, ) -> None: super().__init__() @@ -171,6 +172,7 @@ def __init__( out_channels=out_channels, kernel_size=3, stride=1, + is_causal=is_causal, ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @@ -224,6 +226,7 @@ def __init__( resnet_eps: float = 1e-6, resnet_act_fn: str = "swish", spatio_temporal_scale: bool = True, + is_causal: bool = True, ): super().__init__() @@ -238,6 +241,7 @@ def __init__( dropout=dropout, eps=resnet_eps, non_linearity=resnet_act_fn, + is_causal=is_causal, ) ) self.resnets = nn.ModuleList(resnets) @@ -245,7 +249,7 @@ def __init__( self.downsamplers = None if spatio_temporal_scale: self.downsamplers = nn.ModuleList( - [LTXCausalConv3d(in_channels=in_channels, out_channels=in_channels, kernel_size=3, stride=(2, 2, 2))] + [LTXCausalConv3d(in_channels=in_channels, out_channels=in_channels, kernel_size=3, stride=(2, 2, 2), is_causal=is_causal)] ) self.conv_out = None @@ -256,6 +260,7 @@ def __init__( dropout=dropout, eps=resnet_eps, non_linearity=resnet_act_fn, + is_causal=is_causal ) self.gradient_checkpointing = False @@ -313,6 +318,7 @@ def __init__( num_layers: int = 1, resnet_eps: float = 1e-6, resnet_act_fn: str = "swish", + is_causal: bool = True, ): super().__init__() @@ -325,6 +331,7 @@ def __init__( dropout=dropout, eps=resnet_eps, non_linearity=resnet_act_fn, + is_causal=is_causal, ) ) self.resnets = 
nn.ModuleList(resnets) @@ -382,6 +389,7 @@ def __init__( resnet_eps: float = 1e-6, resnet_act_fn: str = "swish", spatio_temporal_scale: bool = True, + is_causal: bool = True, ): super().__init__() @@ -395,11 +403,12 @@ def __init__( dropout=dropout, eps=resnet_eps, non_linearity=resnet_act_fn, + is_causal=is_causal, ) self.upsamplers = None if spatio_temporal_scale: - self.upsamplers = nn.ModuleList([LTXUpsampler3d(out_channels, stride=(2, 2, 2))]) + self.upsamplers = nn.ModuleList([LTXUpsampler3d(out_channels, stride=(2, 2, 2), is_causal=is_causal)]) resnets = [] for _ in range(num_layers): @@ -410,6 +419,7 @@ def __init__( dropout=dropout, eps=resnet_eps, non_linearity=resnet_act_fn, + is_causal=is_causal, ) ) self.resnets = nn.ModuleList(resnets) @@ -420,7 +430,6 @@ def forward( self, hidden_states: torch.Tensor, ) -> torch.Tensor: - print("in up block", hidden_states.shape) if self.conv_in is not None: hidden_states = self.conv_in(hidden_states) @@ -463,6 +472,7 @@ def __init__( patch_size: int = 4, patch_size_t: int = 1, resnet_norm_eps: float = 1e-6, + is_causal: bool = True, ): super().__init__() @@ -473,7 +483,7 @@ def __init__( output_channel = block_out_channels[0] self.conv_in = LTXCausalConv3d( - in_channels=self.in_channels, out_channels=output_channel, kernel_size=3, stride=1 + in_channels=self.in_channels, out_channels=output_channel, kernel_size=3, stride=1, is_causal=is_causal, ) # down blocks @@ -489,20 +499,21 @@ def __init__( num_layers=layers_per_block[i], resnet_eps=resnet_norm_eps, spatio_temporal_scale=spatio_temporal_scaling[i], + is_causal=is_causal, ) self.down_blocks.append(down_block) # mid block self.mid_block = LTXMidBlock3d( - in_channels=output_channel, num_layers=layers_per_block[-1], resnet_eps=resnet_norm_eps + in_channels=output_channel, num_layers=layers_per_block[-1], resnet_eps=resnet_norm_eps, is_causal=is_causal ) # out - self.norm_out = RMSNormNd(dim=out_channels, eps=1e-6, elementwise_affine=False, channel_dim=1) + self.norm_out = RMSNormNd(dim=out_channels, eps=1e-8, elementwise_affine=False, channel_dim=1) self.conv_act = nn.SiLU() self.conv_out = LTXCausalConv3d( - in_channels=output_channel, out_channels=out_channels + 1, kernel_size=3, stride=1 + in_channels=output_channel, out_channels=out_channels + 1, kernel_size=3, stride=1, is_causal=is_causal ) self.gradient_checkpointing = False @@ -519,9 +530,10 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: post_patch_width = width // p hidden_states = hidden_states.reshape( - batch_size, -1, post_patch_num_frames, p_t, post_patch_height, p, post_patch_width, p + batch_size, num_channels, post_patch_num_frames, p_t, post_patch_height, p, post_patch_width, p ) - hidden_states = hidden_states.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4) + # Thanks for driving me insane with the weird patching order :( + hidden_states = hidden_states.permute(0, 1, 3, 7, 5, 2, 4, 6).flatten(1, 4) hidden_states = self.conv_in(hidden_states) if torch.is_grad_enabled() and self.gradient_checkpointing: @@ -535,7 +547,7 @@ def create_forward(*inputs): for down_block in self.down_blocks: hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), hidden_states) - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), hidden_states) + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), hidden_states) else: for down_block in self.down_blocks: hidden_states = down_block(hidden_states) @@ -571,6 +583,7 @@ def 
__init__( patch_size: int = 4, patch_size_t: int = 1, resnet_norm_eps: float = 1e-6, + is_causal: bool = False, ) -> None: super().__init__() @@ -583,12 +596,13 @@ def __init__( layers_per_block = tuple(reversed(layers_per_block)) output_channel = block_out_channels[0] - self.conv_in = LTXCausalConv3d(in_channels=in_channels, out_channels=output_channel, kernel_size=3, stride=1) + self.conv_in = LTXCausalConv3d(in_channels=in_channels, out_channels=output_channel, kernel_size=3, stride=1, is_causal=is_causal) self.mid_block = LTXMidBlock3d( in_channels=output_channel, num_layers=layers_per_block[0], resnet_eps=resnet_norm_eps, + is_causal=is_causal ) # up blocks @@ -604,18 +618,20 @@ def __init__( num_layers=layers_per_block[i + 1], resnet_eps=resnet_norm_eps, spatio_temporal_scale=spatio_temporal_scaling[i], + is_causal=is_causal ) self.up_blocks.append(up_block) # out - self.norm_out = RMSNormNd(dim=out_channels, eps=1e-6, elementwise_affine=False, channel_dim=1) + self.norm_out = RMSNormNd(dim=out_channels, eps=1e-8, elementwise_affine=False, channel_dim=1) self.conv_act = nn.SiLU() self.conv_out = LTXCausalConv3d( in_channels=output_channel, out_channels=self.out_channels, kernel_size=3, stride=1, + is_causal=is_causal ) self.gradient_checkpointing = False @@ -650,7 +666,7 @@ def create_forward(*inputs): batch_size, num_channels, num_frames, height, width = hidden_states.shape hidden_states = hidden_states.reshape(batch_size, -1, p_t, p, p, num_frames, height, width) - hidden_states = hidden_states.permute(0, 1, 5, 2, 6, 3, 7, 4).flatten(6, 7).flatten(4, 5).flatten(2, 3) + hidden_states = hidden_states.permute(0, 1, 5, 2, 6, 4, 7, 3).flatten(6, 7).flatten(4, 5).flatten(2, 3) return hidden_states @@ -682,6 +698,8 @@ def __init__( patch_size_t: int = 1, resnet_norm_eps: float = 1e-6, scaling_factor: float = 1.0, + encoder_causal: bool = True, + decoder_causal: bool = False, ) -> None: super().__init__() @@ -694,6 +712,7 @@ def __init__( patch_size=patch_size, patch_size_t=patch_size_t, resnet_norm_eps=resnet_norm_eps, + is_causal=encoder_causal, ) self.decoder = LTXDecoder3d( in_channels=latent_channels, @@ -704,6 +723,7 @@ def __init__( patch_size=patch_size, patch_size_t=patch_size_t, resnet_norm_eps=resnet_norm_eps, + is_causal=decoder_causal, ) latent_means = torch.zeros((latent_channels,), requires_grad=False) @@ -836,7 +856,6 @@ def encode( h = torch.cat(encoded_slices) else: h = self._encode(x) - print("h:", h.shape) posterior = DiagonalGaussianDistribution(h) if not return_dict: From 755e29cb6700aa71deadceae7e75157a7518485d Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 27 Nov 2024 09:24:56 +0100 Subject: [PATCH 13/51] add docs --- .../models/autoencoders/autoencoder_kl_ltx.py | 96 +++++++++++++++---- 1 file changed, 80 insertions(+), 16 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py index d8633b0e0750..4ede9529cbef 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py @@ -197,8 +197,6 @@ class LTXDownBlock3D(nn.Module): Number of input channels. out_channels (`int`, *optional*): Number of output channels. If None, defaults to `in_channels`. - temb_channels (`int`, defaults to `512`): - Number of time embedding channels. num_layers (`int`, defaults to `1`): Number of resnet layers. dropout (`float`, defaults to `0.0`): @@ -207,12 +205,11 @@ class LTXDownBlock3D(nn.Module): Epsilon value for normalization layers. 
resnet_act_fn (`str`, defaults to `"swish"`): Activation function to use. - add_downsample (`bool`, defaults to `True`): + spatio_temporal_scale (`bool`, defaults to `True`): Whether or not to use a downsampling layer. If not used, output dimension would be same as input dimension. - compress_time (`bool`, defaults to `False`): Whether or not to downsample across temporal dimension. - padding_mode (str, defaults to `"zeros"`): - Padding mode. + is_causal (`bool`, defaults to `True`): + Whether this layer behaves causally (future frames depend only on past frames) or not. """ _supports_gradient_checkpointing = True @@ -221,8 +218,8 @@ def __init__( self, in_channels: int, out_channels: Optional[int] = None, - dropout: float = 0.0, num_layers: int = 1, + dropout: float = 0.0, resnet_eps: float = 1e-6, resnet_act_fn: str = "swish", spatio_temporal_scale: bool = True, @@ -299,14 +296,16 @@ class LTXMidBlock3d(nn.Module): Args: in_channels (`int`): Number of input channels. - dropout (`float`, defaults to `0.0`): - Dropout rate. num_layers (`int`, defaults to `1`): Number of resnet layers. + dropout (`float`, defaults to `0.0`): + Dropout rate. resnet_eps (`float`, defaults to `1e-6`): Epsilon value for normalization layers. resnet_act_fn (`str`, defaults to `"swish"`): Activation function to use. + is_causal (`bool`, defaults to `True`): + Whether this layer behaves causally (future frames depend only on past frames) or not. """ _supports_gradient_checkpointing = True @@ -314,12 +313,12 @@ class LTXMidBlock3d(nn.Module): def __init__( self, in_channels: int, - dropout: float = 0.0, num_layers: int = 1, + dropout: float = 0.0, resnet_eps: float = 1e-6, resnet_act_fn: str = "swish", is_causal: bool = True, - ): + ) -> None: super().__init__() resnets = [] @@ -375,7 +374,10 @@ class LTXUpBlock3d(nn.Module): resnet_act_fn (`str`, defaults to `"swish"`): Activation function to use. spatio_temporal_scale (`bool`, defaults to `True`): - Whether or not to use a upsampling layer. If not used, output dimension would be same as input dimension. + Whether or not to use a downsampling layer. If not used, output dimension would be same as input dimension. + Whether or not to downsample across temporal dimension. + is_causal (`bool`, defaults to `True`): + Whether this layer behaves causally (future frames depend only on past frames) or not. """ _supports_gradient_checkpointing = True @@ -428,7 +430,7 @@ def __init__( def forward( self, - hidden_states: torch.Tensor, + hidden_states: torch.Tensor ) -> torch.Tensor: if self.conv_in is not None: hidden_states = self.conv_in(hidden_states) @@ -459,7 +461,24 @@ class LTXEncoder3d(nn.Module): representation. Args: - TODO(aryan) + in_channels (`int`): + Number of input channels. + out_channels (`int`): + Number of latent channels. + block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512)`): + The number of output channels for each block. + spatio_temporal_scaling (`Tuple[bool, ...], defaults to `(True, True, True, False)`: + Whether a block should contain spatio-temporal downscaling layers or not. + layers_per_block (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512)`): + The number of layers per block. + patch_size (`int`, defaults to `4`): + The size of spatial patches. + patch_size_t (`int`, defaults to `1`): + The size of temporal patches. + resnet_norm_eps (`float`, defaults to `1e-6`): + Epsilon value for ResNet normalization layers. 
+ is_causal (`bool`, defaults to `True`): + Whether this layer behaves causally (future frames depend only on past frames) or not. """ def __init__( @@ -570,7 +589,24 @@ class LTXDecoder3d(nn.Module): The `LTXDecoder3d` layer of a variational autoencoder that decodes its latent representation into an output sample. Args: - TODO(aryan) + in_channels (`int`): + Number of latent channels. + out_channels (`int`): + Number of output channels. + block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512)`): + The number of output channels for each block. + spatio_temporal_scaling (`Tuple[bool, ...]`, defaults to `(True, True, True, False)`): + Whether a block should contain spatio-temporal upscaling layers or not. + layers_per_block (`Tuple[int, ...]`, defaults to `(4, 3, 3, 3, 4)`): + The number of layers per block. + patch_size (`int`, defaults to `4`): + The size of spatial patches. + patch_size_t (`int`, defaults to `1`): + The size of temporal patches. + resnet_norm_eps (`float`, defaults to `1e-6`): + Epsilon value for ResNet normalization layers. + is_causal (`bool`, defaults to `False`): + Whether this layer behaves causally (future frames depend only on past frames) or not. """ def __init__( @@ -680,7 +716,35 @@ class AutoencoderKLLTX(ModelMixin, ConfigMixin): for all models (such as downloading or saving). Args: - TODO(aryan) + in_channels (`int`, defaults to `3`): + Number of input channels. + out_channels (`int`, defaults to `3`): + Number of output channels. + latent_channels (`int`, defaults to `128`): + Number of latent channels. + block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512)`): + The number of output channels for each block. + spatio_temporal_scaling (`Tuple[bool, ...]`, defaults to `(True, True, True, False)`): + Whether a block should contain spatio-temporal downscaling or not. + layers_per_block (`Tuple[int, ...]`, defaults to `(4, 3, 3, 3, 4)`): + The number of layers per block. + patch_size (`int`, defaults to `4`): + The size of spatial patches. + patch_size_t (`int`, defaults to `1`): + The size of temporal patches. + resnet_norm_eps (`float`, defaults to `1e-6`): + Epsilon value for ResNet normalization layers. + scaling_factor (`float`, *optional*, defaults to `1.0`): + The component-wise standard deviation of the trained latent space computed using the first batch of the + training set. This is used to scale the latent space to have unit variance when training the diffusion + model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the + diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 + / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image + Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. + encoder_causal (`bool`, defaults to `True`): + Whether the encoder should behave causally (future frames depend only on past frames) or not. + decoder_causal (`bool`, defaults to `False`): + Whether the decoder should behave causally (future frames depend only on past frames) or not. 
""" _supports_gradient_checkpointing = True From ac95930042a1b8452e5397832a0fca4ce946dcee Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 27 Nov 2024 14:53:22 +0100 Subject: [PATCH 14/51] t2v pipeline working; scheduler needs to be checked --- scripts/convert_ltx_to_diffusers.py | 30 +- src/diffusers/__init__.py | 2 + .../models/autoencoders/autoencoder_kl_ltx.py | 69 +- .../models/transformers/transformer_ltx.py | 145 +++- src/diffusers/pipelines/__init__.py | 2 + src/diffusers/pipelines/ltx/__init__.py | 48 ++ src/diffusers/pipelines/ltx/pipeline_ltx.py | 700 ++++++++++++++++++ .../pipelines/ltx/pipeline_output.py | 20 + .../scheduling_flow_match_euler_discrete.py | 10 + 9 files changed, 984 insertions(+), 42 deletions(-) create mode 100644 src/diffusers/pipelines/ltx/__init__.py create mode 100644 src/diffusers/pipelines/ltx/pipeline_output.py diff --git a/scripts/convert_ltx_to_diffusers.py b/scripts/convert_ltx_to_diffusers.py index 1acd16b98233..e801a7a8535a 100644 --- a/scripts/convert_ltx_to_diffusers.py +++ b/scripts/convert_ltx_to_diffusers.py @@ -5,7 +5,7 @@ from safetensors.torch import load_file from transformers import T5EncoderModel, T5Tokenizer -from diffusers import AutoencoderKLLTX, FlowMatchEulerDiscreteScheduler, LTXTransformer3DModel +from diffusers import AutoencoderKLLTX, FlowMatchEulerDiscreteScheduler, LTXPipeline, LTXTransformer3DModel def remove_keys_inplace(key: str, state_dict: Dict[str, Any]): @@ -48,8 +48,8 @@ def remove_keys_inplace(key: str, state_dict: Dict[str, Any]): "conv_shortcut": "conv_shortcut.conv", "res_blocks": "resnets", "norm3.norm": "norm3", - "per_channel_statistics.mean-of-means": "latent_means", - "per_channel_statistics.std-of-means": "latent_stds", + "per_channel_statistics.mean-of-means": "latents_mean", + "per_channel_statistics.std-of-means": "latents_std", } VAE_SPECIAL_KEYS_REMAP = { @@ -128,6 +128,12 @@ def get_args(): parser.add_argument( "--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory" ) + parser.add_argument( + "--typecast_text_encoder", + action="store_true", + default=False, + help="Whether or not to apply fp16/bf16 precision to text_encoder", + ) parser.add_argument("--save_pipeline", action="store_true") parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved") parser.add_argument("--dtype", default="fp32", help="Torch dtype to save the model in.") @@ -159,7 +165,7 @@ def get_args(): if args.transformer_ckpt_path is not None: transformer: LTXTransformer3DModel = convert_transformer(args.transformer_ckpt_path, dtype) - if not args.save_piipeline: + if not args.save_pipeline: transformer.save_pretrained( args.output_path, safe_serialization=True, max_shard_size="5GB", variant=variant ) @@ -182,6 +188,20 @@ def get_args(): param.data = param.data.contiguous() scheduler = FlowMatchEulerDiscreteScheduler( - shift=0.1, use_dynamic_shifting=True, + base_shift=0.95, + max_shift=2.05, + base_image_seq_len=1024, + max_image_seq_len=4096, + shift_terminal=0.1, ) + + pipe = LTXPipeline( + scheduler=scheduler, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + ) + + pipe.save_pretrained(args.output_path, safe_serialization=True, variant=variant, max_shard_size="5GB") diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index b4ae36d9c5fa..5db324f22119 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -317,6 +317,7 @@ "LDMTextToImagePipeline", 
"LEditsPPPipelineStableDiffusion", "LEditsPPPipelineStableDiffusionXL", + "LTXPipeline", "LuminaText2ImgPipeline", "MarigoldDepthPipeline", "MarigoldNormalsPipeline", @@ -789,6 +790,7 @@ LDMTextToImagePipeline, LEditsPPPipelineStableDiffusion, LEditsPPPipelineStableDiffusionXL, + LTXPipeline, LuminaText2ImgPipeline, MarigoldDepthPipeline, MarigoldNormalsPipeline, diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py index 4ede9529cbef..de2a9288afe8 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py @@ -118,11 +118,15 @@ def __init__( self.nonlinearity = get_activation(non_linearity) self.norm1 = RMSNormNd(dim=in_channels, eps=1e-8, elementwise_affine=elementwise_affine, channel_dim=1) - self.conv1 = LTXCausalConv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, is_causal=is_causal) + self.conv1 = LTXCausalConv3d( + in_channels=in_channels, out_channels=out_channels, kernel_size=3, is_causal=is_causal + ) self.norm2 = RMSNormNd(dim=out_channels, eps=1e-8, elementwise_affine=elementwise_affine, channel_dim=1) self.dropout = nn.Dropout(dropout) - self.conv2 = LTXCausalConv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, is_causal=is_causal) + self.conv2 = LTXCausalConv3d( + in_channels=out_channels, out_channels=out_channels, kernel_size=3, is_causal=is_causal + ) self.norm3 = None self.conv_shortcut = None @@ -246,7 +250,15 @@ def __init__( self.downsamplers = None if spatio_temporal_scale: self.downsamplers = nn.ModuleList( - [LTXCausalConv3d(in_channels=in_channels, out_channels=in_channels, kernel_size=3, stride=(2, 2, 2), is_causal=is_causal)] + [ + LTXCausalConv3d( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=3, + stride=(2, 2, 2), + is_causal=is_causal, + ) + ] ) self.conv_out = None @@ -257,7 +269,7 @@ def __init__( dropout=dropout, eps=resnet_eps, non_linearity=resnet_act_fn, - is_causal=is_causal + is_causal=is_causal, ) self.gradient_checkpointing = False @@ -428,10 +440,7 @@ def __init__( self.gradient_checkpointing = False - def forward( - self, - hidden_states: torch.Tensor - ) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if self.conv_in is not None: hidden_states = self.conv_in(hidden_states) @@ -502,7 +511,11 @@ def __init__( output_channel = block_out_channels[0] self.conv_in = LTXCausalConv3d( - in_channels=self.in_channels, out_channels=output_channel, kernel_size=3, stride=1, is_causal=is_causal, + in_channels=self.in_channels, + out_channels=output_channel, + kernel_size=3, + stride=1, + is_causal=is_causal, ) # down blocks @@ -525,7 +538,10 @@ def __init__( # mid block self.mid_block = LTXMidBlock3d( - in_channels=output_channel, num_layers=layers_per_block[-1], resnet_eps=resnet_norm_eps, is_causal=is_causal + in_channels=output_channel, + num_layers=layers_per_block[-1], + resnet_eps=resnet_norm_eps, + is_causal=is_causal, ) # out @@ -632,13 +648,12 @@ def __init__( layers_per_block = tuple(reversed(layers_per_block)) output_channel = block_out_channels[0] - self.conv_in = LTXCausalConv3d(in_channels=in_channels, out_channels=output_channel, kernel_size=3, stride=1, is_causal=is_causal) + self.conv_in = LTXCausalConv3d( + in_channels=in_channels, out_channels=output_channel, kernel_size=3, stride=1, is_causal=is_causal + ) self.mid_block = LTXMidBlock3d( - in_channels=output_channel, - num_layers=layers_per_block[0], 
- resnet_eps=resnet_norm_eps, - is_causal=is_causal + in_channels=output_channel, num_layers=layers_per_block[0], resnet_eps=resnet_norm_eps, is_causal=is_causal ) # up blocks @@ -654,7 +669,7 @@ def __init__( num_layers=layers_per_block[i + 1], resnet_eps=resnet_norm_eps, spatio_temporal_scale=spatio_temporal_scaling[i], - is_causal=is_causal + is_causal=is_causal, ) self.up_blocks.append(up_block) @@ -663,11 +678,7 @@ def __init__( self.norm_out = RMSNormNd(dim=out_channels, eps=1e-8, elementwise_affine=False, channel_dim=1) self.conv_act = nn.SiLU() self.conv_out = LTXCausalConv3d( - in_channels=output_channel, - out_channels=self.out_channels, - kernel_size=3, - stride=1, - is_causal=is_causal + in_channels=output_channel, out_channels=self.out_channels, kernel_size=3, stride=1, is_causal=is_causal ) self.gradient_checkpointing = False @@ -790,13 +801,13 @@ def __init__( is_causal=decoder_causal, ) - latent_means = torch.zeros((latent_channels,), requires_grad=False) - latent_stds = torch.zeros((latent_channels,), requires_grad=False) - self.register_buffer("latent_means", latent_means, persistent=True) - self.register_buffer("latent_stds", latent_stds, persistent=True) + latents_mean = torch.zeros((latent_channels,), requires_grad=False) + latents_std = torch.zeros((latent_channels,), requires_grad=False) + self.register_buffer("latents_mean", latents_mean, persistent=True) + self.register_buffer("latents_std", latents_std, persistent=True) - self.spatial_compression_ratio = patch_size * patch_size - self.temporal_compression_ratio = patch_size_t + self.spatial_compression_ratio = patch_size * 2 ** sum(spatio_temporal_scaling) + self.temporal_compression_ratio = patch_size_t * 2 ** sum(spatio_temporal_scaling) # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension # to perform decoding of a single video latent at a time. @@ -809,8 +820,8 @@ def __init__( # When decoding temporally long video latents, the memory requirement is very high. By decoding latent frames # at a fixed frame batch size (based on `self.num_latent_frames_batch_sizes`), the memory requirement can be lowered. - self.use_framewise_encoding = True - self.use_framewise_decoding = True + self.use_framewise_encoding = False + self.use_framewise_decoding = False # This can be configured based on the amount of GPU memory available. # `16` for sample frames and `2` for latent frames are sensible defaults for consumer GPUs. diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py index d0aefda47076..12f5dbb8da9f 100644 --- a/src/diffusers/models/transformers/transformer_ltx.py +++ b/src/diffusers/models/transformers/transformer_ltx.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import math from typing import Any, Dict, Optional, Tuple import torch @@ -45,6 +46,15 @@ def __init__(self): "LTXAttentionProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." 
) + def _apply_rotary_emb(self, x: torch.Tensor, image_rotary_emb: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: + cos, sin = image_rotary_emb + + x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1) # [B, S, H, D//2] + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(2) + + out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + return out + def __call__( self, attn: Attention, @@ -59,12 +69,12 @@ def __call__( if attention_mask is not None: attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - # scaled_dot_product_attention expects attention_mask shape to be - # (batch, heads, source_length, target_length) attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + apply_rotary_emb = False if encoder_hidden_states is None: encoder_hidden_states = hidden_states + apply_rotary_emb = True query = attn.to_q(hidden_states) key = attn.to_k(encoder_hidden_states) @@ -73,6 +83,10 @@ def __call__( query = attn.norm_q(query) key = attn.norm_k(key) + if image_rotary_emb is not None and apply_rotary_emb: + query = self._apply_rotary_emb(query, image_rotary_emb) + key = self._apply_rotary_emb(key, image_rotary_emb) + query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2) key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2) value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2) @@ -89,13 +103,96 @@ def __call__( return hidden_states +class LTXRoPE(nn.Module): + def __init__( + self, + dim: int, + base_num_frames: int = 20, + base_height: int = 2048, + base_width: int = 2048, + patch_size: int = 1, + patch_size_t: int = 1, + theta: float = 10000.0, + ) -> None: + super().__init__() + + self.dim = dim + self.base_num_frames = base_num_frames + self.base_height = base_height + self.base_width = base_width + self.patch_size = patch_size + self.patch_size_t = patch_size_t + self.theta = theta + + def forward( + self, hidden_states: torch.Tensor, rope_interpolation_scale: Optional[Tuple[torch.Tensor, float, float]] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + batch_size, num_channels, num_frames, height, width = hidden_states.shape + post_patch_num_frames = num_frames // self.patch_size_t + post_patch_height = height // self.patch_size + post_patch_width = width // self.patch_size + + # Always compute rope in fp32 + grid_h = torch.arange(post_patch_height, dtype=torch.float32, device=hidden_states.device) + grid_w = torch.arange(post_patch_width, dtype=torch.float32, device=hidden_states.device) + grid_f = torch.arange(post_patch_num_frames, dtype=torch.float32, device=hidden_states.device) + grid = torch.meshgrid(grid_f, grid_h, grid_w, indexing="ij") + grid = torch.stack(grid, dim=0) + grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1) + + if rope_interpolation_scale is not None: + grid[:, 0:1] = grid[:, 0:1] * rope_interpolation_scale[0] * self.patch_size_t / self.base_num_frames + grid[:, 1:2] = grid[:, 1:2] * rope_interpolation_scale[1] * self.patch_size / self.base_height + grid[:, 2:3] = grid[:, 2:3] * rope_interpolation_scale[2] * self.patch_size / self.base_width + + grid = grid.flatten(2, 4).transpose(1, 2) + + start = 1.0 + end = self.theta + freqs = self.theta ** torch.linspace( + math.log(start, self.theta), + math.log(end, self.theta), + self.dim // 6, + device=hidden_states.device, + dtype=torch.float32, + ) + freqs = freqs * math.pi / 2.0 + freqs = freqs * (grid.unsqueeze(-1) * 2 - 1) + freqs = freqs.transpose(-1, -2).flatten(2) + + cos_freqs = 
freqs.cos().repeat_interleave(2, dim=-1) + sin_freqs = freqs.sin().repeat_interleave(2, dim=-1) + + if self.dim % 6 != 0: + cos_padding = torch.ones_like(cos_freqs[:, :, : self.dim % 6]) + sin_padding = torch.zeros_like(cos_freqs[:, :, : self.dim % 6]) + cos_freqs = torch.cat([cos_padding, cos_freqs], dim=-1) + sin_freqs = torch.cat([sin_padding, sin_freqs], dim=-1) + + cos_freqs = cos_freqs.to(dtype=hidden_states.dtype) + sin_freqs = sin_freqs.to(dtype=hidden_states.dtype) + + return cos_freqs, sin_freqs + + @maybe_allow_in_graph class LTXTransformerBlock(nn.Module): r""" Transformer block used in [LTX](https://huggingface.co/Lightricks/LTX-Video). Args: - TODO(aryan) + dim (`int`): + The number of channels in the input and output. + num_attention_heads (`int`): + The number of heads to use for multi-head attention. + attention_head_dim (`int`): + The number of channels in each head. + qk_norm (`str`, defaults to `"rms_norm"`): + The normalization layer to use. + activation_fn (`str`, defaults to `"swiglu"`): + Activation function to use in feed-forward. + eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. """ def __init__( @@ -141,7 +238,8 @@ def __init__( self.ff = FeedForward(dim, activation_fn=activation_fn) - self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) + # TODO(aryan): Create a layer for this + self.scale_shift_table = nn.Parameter(torch.randn(6, dim)) def forward( self, @@ -188,7 +286,26 @@ class LTXTransformer3DModel(ModelMixin, ConfigMixin): A Transformer model for video-like data used in [LTX](https://huggingface.co/Lightricks/LTX-Video). Args: - TODO(aryan) + in_channels (`int`, defaults to `128`): + The number of channels in the input. + out_channels (`int`, defaults to `128`): + The number of channels in the output. + patch_size (`int`, defaults to `1`): + The size of the spatial patches to use in the patch embedding layer. + patch_size_t (`int`, defaults to `1`): + The size of the temporal patches to use in the patch embedding layer. + num_attention_heads (`int`, defaults to `32`): + The number of heads to use for multi-head attention. + attention_head_dim (`int`, defaults to `64`): + The number of channels in each head. + cross_attention_dim (`int`, defaults to `64`): + The number of channels for cross attention heads. + num_layers (`int`, defaults to `28`): + The number of layers of Transformer blocks to use. + activation_fn (`str`, defaults to `"swiglu"`): + Activation function to use in feed-forward. + qk_norm (`str`, defaults to `"rms_norm_across_heads"`): + The normalization layer to use. 
""" _supports_gradient_checkpointing = True @@ -219,6 +336,16 @@ def __init__( self.patchify_proj = nn.Linear(in_channels, inner_dim) + self.rope = LTXRoPE( + dim=inner_dim, + base_num_frames=20, + base_height=2048, + base_width=2048, + patch_size=patch_size, + patch_size_t=patch_size_t, + theta=10000.0, + ) + self.transformer_blocks = nn.ModuleList( [ LTXTransformerBlock( @@ -240,7 +367,8 @@ def __init__( self.norm_out = nn.LayerNorm(inner_dim, eps=1e-6, elementwise_affine=False) self.proj_out = nn.Linear(inner_dim, out_channels) - self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5) + # TODO(aryan): create a layer for this + self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim)) self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=False) self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim) @@ -257,9 +385,11 @@ def forward( encoder_hidden_states: torch.Tensor, timestep: torch.LongTensor, encoder_attention_mask: torch.Tensor, - image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + rope_interpolation_scale: Optional[Tuple[float, float, float]] = None, return_dict: bool = True, ) -> torch.Tensor: + image_rotary_emb = self.rope(hidden_states, rope_interpolation_scale) + # convert encoder_attention_mask to a bias the same way we do for attention_mask if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 @@ -281,7 +411,6 @@ def forward( temb, embedded_timestep = self.adaln_single( timestep.flatten(), - {"resolution": None, "aspect_ratio": None}, batch_size=batch_size, hidden_dtype=hidden_states.dtype, ) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 5143b1114fd3..621bcce6ae80 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -245,6 +245,7 @@ ] ) _import_structure["latte"] = ["LattePipeline"] + _import_structure["ltx"] = ["LTXPipeline"] _import_structure["lumina"] = ["LuminaText2ImgPipeline"] _import_structure["marigold"].extend( [ @@ -577,6 +578,7 @@ LEditsPPPipelineStableDiffusion, LEditsPPPipelineStableDiffusionXL, ) + from .ltx import LTXPipeline from .lumina import LuminaText2ImgPipeline from .marigold import ( MarigoldDepthPipeline, diff --git a/src/diffusers/pipelines/ltx/__init__.py b/src/diffusers/pipelines/ltx/__init__.py new file mode 100644 index 000000000000..96fc7b3c24cd --- /dev/null +++ b/src/diffusers/pipelines/ltx/__init__.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_ltx"] = ["LTXPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from 
.pipeline_ltx import LTXPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx.py b/src/diffusers/pipelines/ltx/pipeline_ltx.py index e69de29bb2d1..cb7210f0f530 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx.py @@ -0,0 +1,700 @@ +# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from transformers import T5EncoderModel, T5TokenizerFast + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...models.autoencoders import AutoencoderKLLTX +from ...models.transformers import LTXTransformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import LTXPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import LTXPipeline + >>> from diffusers.utils import export_to_video + + >>> pipe = LTXPipeline.from_pretrained("a-r-r-o-w/LTX-Video-diffusers", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage" + >>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted" + + >>> video = pipe( + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... width=704, + ... height=480, + ... num_frames=161, + ... num_inference_steps=50, + ... 
).frames[0] + >>> export_to_video(video, "output.mp4", fps=24) + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.16, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class LTXPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-video generation. + + Reference: https://github.com/Lightricks/LTX-Video + + Args: + transformer ([`MochiTransformer3DModel`]): + Conditional Transformer architecture to denoise the encoded video latents. 
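As a rough, self-contained illustration of `calculate_shift` above: longer latent token sequences map linearly to a larger `mu`, i.e. a stronger timestep shift. The latent sizes below are assumptions based on the 704x480, 161-frame example and the 32x spatial / 8x temporal compression used elsewhere in this series.

```python
def calculate_shift(image_seq_len, base_seq_len=256, max_seq_len=4096, base_shift=0.5, max_shift=1.16):
    # Linear interpolation between base_shift and max_shift in sequence length.
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    return image_seq_len * m + b

# Assumed latent grid: (161 - 1) // 8 + 1 frames, 480 // 32 by 704 // 32 spatial tokens.
latent_frames, latent_height, latent_width = 21, 15, 22
seq_len = latent_frames * latent_height * latent_width    # 6930

print(seq_len, round(calculate_shift(seq_len), 3))        # 6930 1.647
```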
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLLTX, + text_encoder: T5EncoderModel, + tokenizer: T5TokenizerFast, + transformer: LTXTransformer3DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + + self.vae_spatial_scale_factor = self.vae.spatial_compression_ratio if hasattr(self, "vae") else 32 + self.vae_temporal_scale_factor = self.vae.temporal_compression_ratio if hasattr(self, "vae") else 8 + self.transformer_spatial_patch_size = self.transformer.config.patch_size if hasattr(self, "transformer") else 1 + self.transformer_temporal_patch_size = ( + self.transformer.config.patch_size_t if hasattr(self, "transformer") else 1 + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_scale_factor) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 128 + ) + + self.default_height = 512 + self.default_width = 704 + self.default_frames = 121 + + # Copied from diffusers.pipelines.cogvideo.pipeline_mochi.MochiPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 128, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + prompt_attention_mask = prompt_attention_mask.bool().to(device) + + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder(text_input_ids.to(device))[0] + prompt_embeds = 
prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + prompt_attention_mask = prompt_attention_mask.view(batch_size, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1) + + return prompt_embeds, prompt_attention_mask + + # Copied from diffusers.pipelines.cogvideo.pipeline_mochi.MochiPipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + max_sequence_length: int = 128, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. 
Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + def check_inputs( + self, + prompt, + height, + width, + callback_on_step_end_tensor_inputs=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + ): + if height % 32 != 0 or width % 32 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." + ) + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + num_frames, + dtype, + device, + generator, + latents=None, + ): + height = height // self.vae_spatial_scale_factor + width = width // self.vae_spatial_scale_factor + num_frames = (num_frames - 1) // self.vae_temporal_scale_factor + 1 + + shape = (batch_size, num_channels_latents, num_frames, height, width) + + if latents is not None: + return latents.to(device=device, dtype=dtype) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. 
Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + def decode_latents(self, latents: torch.Tensor): + # unscale/denormalize the latents + latents_mean = self.vae.latents_mean.view(1, self.vae.config.latent_channels, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents_std = self.vae.latents_std.view(1, self.vae.config.latent_channels, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean + video = self.vae.decode(latents, return_dict=False)[0] + return video + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_frames: int = 81, + frame_rate: int = 25, + num_inference_steps: int = 50, + timesteps: List[int] = None, + guidance_scale: float = 3, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 128, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to `self.default_height`): + The height in pixels of the generated image. This is set to 480 by default for the best results. + width (`int`, *optional*, defaults to `self.default_width`): + The width in pixels of the generated image. This is set to 848 by default for the best results. + num_frames (`int`, defaults to `19`): + The number of video frames to generate + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, defaults to `4.5`): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. 
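A minimal sketch of the denormalization in `decode_latents` above. The statistics here are random stand-ins; in the pipeline, `latents_mean` and `latents_std` are per-channel buffers stored on the LTX VAE and `scaling_factor` comes from its config. The point is that the decode-side formula is the exact inverse of normalizing a latent with those statistics, which the round trip below checks.

```python
import torch

latent_channels = 128
latents_mean = torch.randn(1, latent_channels, 1, 1, 1)      # stand-in for vae.latents_mean
latents_std = torch.rand(1, latent_channels, 1, 1, 1) + 0.5  # stand-in for vae.latents_std
scaling_factor = 1.0                                         # assumed vae.config.scaling_factor

raw = torch.randn(2, latent_channels, 21, 15, 22)

# Forward normalization implied by the decode-side formula above ...
normalized = (raw - latents_mean) * scaling_factor / latents_std
# ... and the denormalization performed in decode_latents before vae.decode.
denormalized = normalized * latents_std / scaling_factor + latents_mean

print(torch.allclose(denormalized, raw, atol=1e-5))  # expected: True
```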
Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + negative_prompt_attention_mask (`torch.FloatTensor`, *optional*): + Pre-generated attention mask for negative text embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ltx.LTXPipelineOutput`] instead of a plain tuple. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to `256`): + Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.ltx.LTXPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ltx.LTXPipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + height = height or self.default_height + width = width or self.default_width + latent_frame_rate = frame_rate // self.vae_temporal_scale_factor + + # 1. Check inputs. 
Raise error if not correct + self.check_inputs( + prompt=prompt, + height=height, + width=width, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + ) + + self._guidance_scale = guidance_scale + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Prepare text embeddings + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + max_sequence_length=max_sequence_length, + device=device, + ) + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + image_seq_len = latents.size(1) + mu = calculate_shift( + image_seq_len, + self.scheduler.config.base_image_seq_len, + self.scheduler.config.max_image_seq_len, + self.scheduler.config.base_shift, + self.scheduler.config.max_shift, + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + print(timesteps) + print(self.scheduler.sigmas) + + # 6. Prepare micro-conditions + rope_interpolation_scale = ( + 1 / latent_frame_rate, + self.vae_spatial_scale_factor, + self.vae_spatial_scale_factor, + ) + + # 7. 
Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + encoder_attention_mask=prompt_attention_mask, + rope_interpolation_scale=rope_interpolation_scale, + return_dict=False, + )[0] + noise_pred = noise_pred.to(dtype=torch.float32) + + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = latents.to(dtype=torch.float32) + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + latents = latents.to(dtype=latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + video = latents + else: + video = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return LTXPipelineOutput(frames=video) diff --git a/src/diffusers/pipelines/ltx/pipeline_output.py b/src/diffusers/pipelines/ltx/pipeline_output.py new file mode 100644 index 000000000000..36ec3ea884a2 --- /dev/null +++ b/src/diffusers/pipelines/ltx/pipeline_output.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass + +import torch + +from diffusers.utils import BaseOutput + + +@dataclass +class LTXPipelineOutput(BaseOutput): + r""" + Output class for LTX pipelines. + + Args: + frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. 
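A small sketch of the classifier-free guidance pattern used in the denoising loop above: the negative and positive streams are run through the model as one batch, split back apart, and recombined with `guidance_scale`. The `fake_transformer` below is a stand-in and not the real `LTXTransformer3DModel` call.

```python
import torch

def fake_transformer(latent_model_input, prompt_embeds):
    # Stand-in model: any function of (latents, text embeddings) works for the sketch.
    return latent_model_input * 0.1 + prompt_embeds.mean(dim=-1, keepdim=True)

guidance_scale = 3.0
latents = torch.randn(1, 8)
prompt_embeds = torch.randn(2, 8)  # [negative, positive] rows, pre-concatenated as in the pipeline

latent_model_input = torch.cat([latents] * 2)  # duplicate latents for the two streams
noise_pred = fake_transformer(latent_model_input, prompt_embeds)

noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
guided = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
print(guided.shape)  # torch.Size([1, 8]): one guided prediction per original latent
```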
+ """ + + frames: torch.Tensor diff --git a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py index c1096dbe0c29..4cb6d01e9560 100644 --- a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py @@ -72,6 +72,7 @@ def __init__( base_image_seq_len: Optional[int] = 256, max_image_seq_len: Optional[int] = 4096, invert_sigmas: bool = False, + shift_terminal: Optional[float] = None, ): timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy() timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32) @@ -169,6 +170,12 @@ def _sigma_to_t(self, sigma): def time_shift(self, mu: float, sigma: float, t: torch.Tensor): return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + def stretch_shift_to_terminal(self, t: torch.Tensor) -> torch.Tensor: + one_minus_z = 1 - t + scale_factor = one_minus_z[-1] / (1 - self.config.shift_terminal) + stretched_t = 1 - (one_minus_z / scale_factor) + return stretched_t + def set_timesteps( self, num_inference_steps: int = None, @@ -202,6 +209,9 @@ def set_timesteps( else: sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas) + if self.config.shift_terminal: + sigmas = self.stretch_shift_to_terminal(sigmas) + sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device) timesteps = sigmas * self.config.num_train_timesteps From 5f185cd2c1474bfd8661ecf09b26ef6f083d95f4 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 27 Nov 2024 14:56:27 +0100 Subject: [PATCH 15/51] docs --- docs/source/en/_toctree.yml | 6 +++ .../source/en/api/models/autoencoderkl_ltx.md | 37 +++++++++++++++++++ .../source/en/api/models/ltx_transformer3d.md | 30 +++++++++++++++ docs/source/en/api/pipelines/ltx.md | 35 ++++++++++++++++++ 4 files changed, 108 insertions(+) create mode 100644 docs/source/en/api/models/autoencoderkl_ltx.md create mode 100644 docs/source/en/api/models/ltx_transformer3d.md create mode 100644 docs/source/en/api/pipelines/ltx.md diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 2faabfec30ce..b18c2e722da2 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -272,6 +272,8 @@ title: LatteTransformer3DModel - local: api/models/lumina_nextdit2d title: LuminaNextDiT2DModel + - local: api/models/ltx_transformer3d + title: LTXTransformer3DModel - local: api/models/mochi_transformer3d title: MochiTransformer3DModel - local: api/models/pixart_transformer2d @@ -310,6 +312,8 @@ title: AutoencoderKLAllegro - local: api/models/autoencoderkl_cogvideox title: AutoencoderKLCogVideoX + - local: api/models/autoencoderkl_ltx + title: AutoencoderKLLTX - local: api/models/autoencoderkl_mochi title: AutoencoderKLMochi - local: api/models/asymmetricautoencoderkl @@ -402,6 +406,8 @@ title: Latte - local: api/pipelines/ledits_pp title: LEDITS++ + - local: api/pipelines/ltx + title: LTX - local: api/pipelines/lumina title: Lumina-T2X - local: api/pipelines/marigold diff --git a/docs/source/en/api/models/autoencoderkl_ltx.md b/docs/source/en/api/models/autoencoderkl_ltx.md new file mode 100644 index 000000000000..3a9549372ba1 --- /dev/null +++ b/docs/source/en/api/models/autoencoderkl_ltx.md @@ -0,0 +1,37 @@ + + +# AutoencoderKLLTX + +The 3D variational autoencoder (VAE) model with KL loss used in [LTX](https://huggingface.co/Lightricks/LTX-Video) was introduced in by Lightricks. 
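Returning to the `shift_terminal` option added to `FlowMatchEulerDiscreteScheduler` in the scheduler hunk above: `stretch_shift_to_terminal` linearly rescales `1 - sigma` so that the final sigma lands exactly on `shift_terminal` while the first sigma stays at 1. A numeric sketch with an assumed `shift_terminal` of 0.1, ignoring the preceding shift transform for brevity:

```python
import numpy as np

shift_terminal = 0.1                    # assumed config value, for illustration only
sigmas = np.linspace(1.0, 1 / 50, 50)   # same construction the pipeline uses for 50 steps

one_minus_z = 1 - sigmas
scale_factor = one_minus_z[-1] / (1 - shift_terminal)
stretched = 1 - one_minus_z / scale_factor

print(sigmas[0], stretched[0])    # 1.0 1.0 -> the start of the schedule is unchanged
print(sigmas[-1], stretched[-1])  # 0.02 ~0.1 -> the terminal sigma is pulled up to shift_terminal
```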
+ +The model can be loaded with the following code snippet. + +```python +from diffusers import AutoencoderKLLTX + +vae = AutoencoderKLLTX.from_pretrained("TODO/TODO", subfolder="vae", torch_dtype=torch.float32).to("cuda") +``` + +## AutoencoderKLLTX + +[[autodoc]] AutoencoderKLLTX + - decode + - encode + - all + +## AutoencoderKLOutput + +[[autodoc]] models.autoencoders.autoencoder_kl.AutoencoderKLOutput + +## DecoderOutput + +[[autodoc]] models.autoencoders.vae.DecoderOutput diff --git a/docs/source/en/api/models/ltx_transformer3d.md b/docs/source/en/api/models/ltx_transformer3d.md new file mode 100644 index 000000000000..e70a03ad7ea7 --- /dev/null +++ b/docs/source/en/api/models/ltx_transformer3d.md @@ -0,0 +1,30 @@ + + +# LTXTransformer3DModel + +A Diffusion Transformer model for 3D data from [LTX](https://huggingface.co/Lightricks/LTX-Video) was introduced by Lightricks. + +The model can be loaded with the following code snippet. + +```python +from diffusers import LTXTransformer3DModel + +transformer = LTXTransformer3DModel.from_pretrained("TODO/TODO", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda") +``` + +## LTXTransformer3DModel + +[[autodoc]] LTXTransformer3DModel + +## Transformer2DModelOutput + +[[autodoc]] models.modeling_outputs.Transformer2DModelOutput diff --git a/docs/source/en/api/pipelines/ltx.md b/docs/source/en/api/pipelines/ltx.md new file mode 100644 index 000000000000..d1c238eb35c2 --- /dev/null +++ b/docs/source/en/api/pipelines/ltx.md @@ -0,0 +1,35 @@ + + +# LTX + +[LTX Video](https://huggingface.co/Lightricks/LTX-Video) from Genmo. + +*LTX-Video is the first DiT-based video generation model capable of generating high-quality videos in real-time. It produces 24 FPS videos at a 768x512 resolution faster than they can be watched. Trained on a large-scale dataset of diverse videos, the model generates high-resolution videos with realistic and varied content. We provide a model for both text-to-video as well as image + text-to-video usecases.* + + + +Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines. + + + +## LTXPipeline + +[[autodoc]] LTXPipeline + - all + - __call__ + +## LTXPipelineOutput + +[[autodoc]] pipelines.ltx.pipeline_output.LTXPipelineOutput From e580b6ba136dd09cbcf052456d00cb2a4e2fad27 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 27 Nov 2024 15:19:01 +0100 Subject: [PATCH 16/51] add pipeline test --- tests/pipelines/ltx/__init__.py | 0 tests/pipelines/ltx/test_ltx.py | 259 ++++++++++++++++++++++++++++++++ 2 files changed, 259 insertions(+) create mode 100644 tests/pipelines/ltx/__init__.py create mode 100644 tests/pipelines/ltx/test_ltx.py diff --git a/tests/pipelines/ltx/__init__.py b/tests/pipelines/ltx/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/ltx/test_ltx.py b/tests/pipelines/ltx/test_ltx.py new file mode 100644 index 000000000000..2d4cac70c40b --- /dev/null +++ b/tests/pipelines/ltx/test_ltx.py @@ -0,0 +1,259 @@ +# Copyright 2024 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import unittest + +import numpy as np +import torch +from transformers import AutoTokenizer, T5EncoderModel + +from diffusers import AutoencoderKLLTX, FlowMatchEulerDiscreteScheduler, LTXPipeline, LTXTransformer3DModel +from diffusers.utils.testing_utils import ( + enable_full_determinism, + torch_device, +) + +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineTesterMixin, to_np + + +enable_full_determinism() + + +class LTXPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = LTXPipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + test_xformers_attention = False + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = LTXTransformer3DModel( + in_channels=8, + out_channels=8, + patch_size=1, + patch_size_t=1, + num_attention_heads=4, + attention_head_dim=8, + cross_attention_dim=32, + num_layers=1, + caption_channels=32, + ) + + torch.manual_seed(0) + vae = AutoencoderKLLTX( + latent_channels=8, + block_out_channels=(8, 8, 8, 8), + spatio_temporal_scaling=(True, True, False, False), + layers_per_block=(1, 1, 1, 1, 1), + patch_size=1, + patch_size_t=1, + encoder_causal=True, + decoder_causal=False, + ) + vae.use_framewise_encoding = False + vae.use_framewise_decoding = False + + torch.manual_seed(0) + scheduler = FlowMatchEulerDiscreteScheduler() + text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + components = { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + inputs = { + "prompt": "dance monkey", + "negative_prompt": "", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 3.0, + "height": 32, + "width": 32, + # 8 * k + 1 is the recommendation + "num_frames": 9, + "max_sequence_length": 16, + "output_type": "pt", + } + + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + video = pipe(**inputs).frames + generated_video = video[0] + + self.assertEqual(generated_video.shape, (9, 3, 32, 32)) + expected_video = torch.randn(9, 3, 32, 32) + max_diff = np.abs(generated_video - expected_video).max() + self.assertLessEqual(max_diff, 1e10) + + def 
test_callback_inputs(self): + sig = inspect.signature(self.pipeline_class.__call__) + has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters + has_callback_step_end = "callback_on_step_end" in sig.parameters + + if not (has_callback_tensor_inputs and has_callback_step_end): + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + self.assertTrue( + hasattr(pipe, "_callback_tensor_inputs"), + f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs", + ) + + def callback_inputs_subset(pipe, i, t, callback_kwargs): + # iterate over callback args + for tensor_name, tensor_value in callback_kwargs.items(): + # check that we're only passing in allowed tensor inputs + assert tensor_name in pipe._callback_tensor_inputs + + return callback_kwargs + + def callback_inputs_all(pipe, i, t, callback_kwargs): + for tensor_name in pipe._callback_tensor_inputs: + assert tensor_name in callback_kwargs + + # iterate over callback args + for tensor_name, tensor_value in callback_kwargs.items(): + # check that we're only passing in allowed tensor inputs + assert tensor_name in pipe._callback_tensor_inputs + + return callback_kwargs + + inputs = self.get_dummy_inputs(torch_device) + + # Test passing in a subset + inputs["callback_on_step_end"] = callback_inputs_subset + inputs["callback_on_step_end_tensor_inputs"] = ["latents"] + output = pipe(**inputs)[0] + + # Test passing in a everything + inputs["callback_on_step_end"] = callback_inputs_all + inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs + output = pipe(**inputs)[0] + + def callback_inputs_change_tensor(pipe, i, t, callback_kwargs): + is_last = i == (pipe.num_timesteps - 1) + if is_last: + callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"]) + return callback_kwargs + + inputs["callback_on_step_end"] = callback_inputs_change_tensor + inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs + output = pipe(**inputs)[0] + assert output.abs().sum() < 1e10 + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-3) + + def test_attention_slicing_forward_pass( + self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3 + ): + if not self.test_attention_slicing: + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + generator_device = "cpu" + inputs = self.get_dummy_inputs(generator_device) + output_without_slicing = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=1) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing1 = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=2) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing2 = pipe(**inputs)[0] + + if test_max_difference: + max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max() + max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max() + self.assertLess( + max(max_diff1, max_diff2), + expected_max_diff, + "Attention slicing should 
not affect the inference results", + ) + + def test_vae_tiling(self, expected_diff_max: float = 0.2): + generator_device = "cpu" + components = self.get_dummy_components() + + pipe = self.pipeline_class(**components) + pipe.to("cpu") + pipe.set_progress_bar_config(disable=None) + + # Without tiling + inputs = self.get_dummy_inputs(generator_device) + inputs["height"] = inputs["width"] = 128 + output_without_tiling = pipe(**inputs)[0] + + # With tiling + pipe.vae.enable_tiling( + tile_sample_min_height=96, + tile_sample_min_width=96, + tile_sample_stride_height=64, + tile_sample_stride_width=64, + ) + inputs = self.get_dummy_inputs(generator_device) + inputs["height"] = inputs["width"] = 128 + output_with_tiling = pipe(**inputs)[0] + + self.assertLess( + (to_np(output_without_tiling) - to_np(output_with_tiling)).max(), + expected_diff_max, + "VAE tiling should not affect the inference results", + ) From 13adf3f76ba1635ee03ed4dfa8f1280158510c87 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 27 Nov 2024 15:19:20 +0100 Subject: [PATCH 17/51] update --- tests/pipelines/ltx/test_ltx.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/pipelines/ltx/test_ltx.py b/tests/pipelines/ltx/test_ltx.py index 2d4cac70c40b..18e1cee38fe8 100644 --- a/tests/pipelines/ltx/test_ltx.py +++ b/tests/pipelines/ltx/test_ltx.py @@ -20,10 +20,7 @@ from transformers import AutoTokenizer, T5EncoderModel from diffusers import AutoencoderKLLTX, FlowMatchEulerDiscreteScheduler, LTXPipeline, LTXTransformer3DModel -from diffusers.utils.testing_utils import ( - enable_full_determinism, - torch_device, -) +from diffusers.utils.testing_utils import enable_full_determinism, torch_device from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS from ..test_pipelines_common import PipelineTesterMixin, to_np From c8dfa989014cc6d2e4dd08338e132027155eadcb Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 27 Nov 2024 15:19:35 +0100 Subject: [PATCH 18/51] update --- src/diffusers/models/autoencoders/autoencoder_kl_ltx.py | 3 --- src/diffusers/pipelines/ltx/pipeline_ltx.py | 7 ++----- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py index de2a9288afe8..682fc267c027 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py @@ -1115,9 +1115,6 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod else: time = self.decoder(z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width]) - if self.drop_last_temporal_frames and time.size(2) >= self.temporal_compression_ratio: - time = time[:, :, self.temporal_compression_ratio - 1 :] - row.append(time) rows.append(row) diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx.py b/src/diffusers/pipelines/ltx/pipeline_ltx.py index cb7210f0f530..d7eea8272344 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx.py @@ -628,8 +628,6 @@ def __call__( ) num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) self._num_timesteps = len(timesteps) - print(timesteps) - print(self.scheduler.sigmas) # 6. 
Prepare micro-conditions rope_interpolation_scale = ( @@ -657,7 +655,7 @@ def __call__( rope_interpolation_scale=rope_interpolation_scale, return_dict=False, )[0] - noise_pred = noise_pred.to(dtype=torch.float32) + noise_pred = noise_pred.float() if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) @@ -665,8 +663,7 @@ def __call__( # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype - latents = latents.to(dtype=torch.float32) - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + latents = self.scheduler.step(noise_pred, t, latents.float(), return_dict=False)[0] latents = latents.to(dtype=latents_dtype) if callback_on_step_end is not None: From b234394be62839db6539fc8303b565077049fbd5 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 27 Nov 2024 15:21:08 +0100 Subject: [PATCH 19/51] make fix-copies --- src/diffusers/pipelines/ltx/pipeline_ltx.py | 4 ++-- .../utils/dummy_torch_and_transformers_objects.py | 15 +++++++++++++++ 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx.py b/src/diffusers/pipelines/ltx/pipeline_ltx.py index d7eea8272344..8eb7d41285ea 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx.py @@ -201,7 +201,7 @@ def __init__( self.default_width = 704 self.default_frames = 121 - # Copied from diffusers.pipelines.cogvideo.pipeline_mochi.MochiPipeline._get_t5_prompt_embeds + # Copied from diffusers.pipelines.mochi.pipeline_mochi.MochiPipeline._get_t5_prompt_embeds with 256->128 def _get_t5_prompt_embeds( self, prompt: Union[str, List[str]] = None, @@ -250,7 +250,7 @@ def _get_t5_prompt_embeds( return prompt_embeds, prompt_attention_mask - # Copied from diffusers.pipelines.cogvideo.pipeline_mochi.MochiPipeline.encode_prompt + # Copied from diffusers.pipelines.mochi.pipeline_mochi.MochiPipeline.encode_prompt with 256->128 def encode_prompt( self, prompt: Union[str, List[str]], diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index b76ea3824060..4b54638fc652 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1067,6 +1067,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class LTXPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class LuminaText2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] From 6544fcc60854e9d55f7807cf8f3da79b228aaaad Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 27 Nov 2024 21:22:07 +0530 Subject: [PATCH 20/51] Apply suggestions from code review Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/api/models/autoencoderkl_ltx.md | 2 +- docs/source/en/api/pipelines/ltx.md | 4 +--- .../models/autoencoders/autoencoder_kl_ltx.py | 16 ++++++++-------- .../models/transformers/transformer_ltx.py | 6 +++--- src/diffusers/pipelines/ltx/pipeline_ltx.py | 6 +++--- 5 files changed, 16 insertions(+), 18 deletions(-) diff --git 
a/docs/source/en/api/models/autoencoderkl_ltx.md b/docs/source/en/api/models/autoencoderkl_ltx.md index 3a9549372ba1..7c2519866077 100644 --- a/docs/source/en/api/models/autoencoderkl_ltx.md +++ b/docs/source/en/api/models/autoencoderkl_ltx.md @@ -11,7 +11,7 @@ specific language governing permissions and limitations under the License. --> # AutoencoderKLLTX -The 3D variational autoencoder (VAE) model with KL loss used in [LTX](https://huggingface.co/Lightricks/LTX-Video) was introduced in by Lightricks. +The 3D variational autoencoder (VAE) model with KL loss used in [LTX](https://huggingface.co/Lightricks/LTX-Video) was introduced by Lightricks. The model can be loaded with the following code snippet. diff --git a/docs/source/en/api/pipelines/ltx.md b/docs/source/en/api/pipelines/ltx.md index d1c238eb35c2..18de92fe2804 100644 --- a/docs/source/en/api/pipelines/ltx.md +++ b/docs/source/en/api/pipelines/ltx.md @@ -14,9 +14,7 @@ # LTX -[LTX Video](https://huggingface.co/Lightricks/LTX-Video) from Genmo. - -*LTX-Video is the first DiT-based video generation model capable of generating high-quality videos in real-time. It produces 24 FPS videos at a 768x512 resolution faster than they can be watched. Trained on a large-scale dataset of diverse videos, the model generates high-resolution videos with realistic and varied content. We provide a model for both text-to-video as well as image + text-to-video usecases.* +[LTX Video](https://huggingface.co/Lightricks/LTX-Video) is the first DiT-based video generation model capable of generating high-quality videos in real-time. It produces 24 FPS videos at a 768x512 resolution faster than they can be watched. Trained on a large-scale dataset of diverse videos, the model generates high-resolution videos with realistic and varied content. We provide a model for both text-to-video as well as image + text-to-video usecases. diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py index 682fc267c027..2ce3b9d83540 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py @@ -1,4 +1,4 @@ -# Copyright 2024 The Mochi team and The HuggingFace Team. +# Copyright 2024 The Lightricks team and The HuggingFace Team. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -470,15 +470,15 @@ class LTXEncoder3d(nn.Module): representation. Args: - in_channels (`int`): + in_channels (`int`, defaults to 3): Number of input channels. - out_channels (`int`): + out_channels (`int`, defaults to 128): Number of latent channels. block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512)`): The number of output channels for each block. spatio_temporal_scaling (`Tuple[bool, ...], defaults to `(True, True, True, False)`: Whether a block should contain spatio-temporal downscaling layers or not. - layers_per_block (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512)`): + layers_per_block (`Tuple[int, ...]`, defaults to `(4, 3, 3, 3, 4)`): The number of layers per block. patch_size (`int`, defaults to `4`): The size of spatial patches. @@ -605,15 +605,15 @@ class LTXDecoder3d(nn.Module): The `LTXDecoder3d` layer of a variational autoencoder that decodes its latent representation into an output sample. Args: - in_channels (`int`): + in_channels (`int`, defaults to 128): Number of latent channels. - out_channels (`int`): + out_channels (`int`, defaults to 3): Number of output channels. 
block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512)`): The number of output channels for each block. spatio_temporal_scaling (`Tuple[bool, ...], defaults to `(True, True, True, False)`: Whether a block should contain spatio-temporal upscaling layers or not. - layers_per_block (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512)`): + layers_per_block (`Tuple[int, ...]`, defaults to `(4, 3, 3, 3, 4)`): The number of layers per block. patch_size (`int`, defaults to `4`): The size of spatial patches. @@ -737,7 +737,7 @@ class AutoencoderKLLTX(ModelMixin, ConfigMixin): The number of output channels for each block. spatio_temporal_scaling (`Tuple[bool, ...], defaults to `(True, True, True, False)`: Whether a block should contain spatio-temporal downscaling or not. - layers_per_block (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512)`): + layers_per_block (`Tuple[int, ...]`, defaults to `(4, 3, 3, 3, 4)`): The number of layers per block. patch_size (`int`, defaults to `4`): The size of spatial patches. diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py index 12f5dbb8da9f..de5322ea724d 100644 --- a/src/diffusers/models/transformers/transformer_ltx.py +++ b/src/diffusers/models/transformers/transformer_ltx.py @@ -189,7 +189,7 @@ class LTXTransformerBlock(nn.Module): The number of channels in each head. qk_norm (`str`, defaults to `"rms_norm"`): The normalization layer to use. - activation_fn (`str`, defaults to `"swiglu"`): + activation_fn (`str`, defaults to `"gelu-approximate"`): Activation function to use in feed-forward. eps (`float`, defaults to `1e-6`): Epsilon value for normalization layers. @@ -298,11 +298,11 @@ class LTXTransformer3DModel(ModelMixin, ConfigMixin): The number of heads to use for multi-head attention. attention_head_dim (`int`, defaults to `64`): The number of channels in each head. - cross_attention_dim (`int`, defaults to `64`): + cross_attention_dim (`int`, defaults to `2048 `): The number of channels for cross attention heads. num_layers (`int`, defaults to `28`): The number of layers of Transformer blocks to use. - activation_fn (`str`, defaults to `"swiglu"`): + activation_fn (`str`, defaults to `"gelu-approximate"`): Activation function to use in feed-forward. qk_norm (`str`, defaults to `"rms_norm_across_heads"`): The normalization layer to use. diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx.py b/src/diffusers/pipelines/ltx/pipeline_ltx.py index 8eb7d41285ea..01914cbf0164 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx.py @@ -481,7 +481,7 @@ def __call__( The height in pixels of the generated image. This is set to 480 by default for the best results. width (`int`, *optional*, defaults to `self.default_width`): The width in pixels of the generated image. This is set to 848 by default for the best results. - num_frames (`int`, defaults to `19`): + num_frames (`int`, defaults to `81 `): The number of video frames to generate num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the @@ -490,7 +490,7 @@ def __call__( Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. Must be in descending order. 
- guidance_scale (`float`, defaults to `4.5`): + guidance_scale (`float`, defaults to `3 `): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > @@ -529,7 +529,7 @@ def __call__( The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. - max_sequence_length (`int` defaults to `256`): + max_sequence_length (`int` defaults to `128 `): Maximum sequence length to use with the `prompt`. Examples: From 7134e2dcfb9e46ec2fd9725df63fc4505dc09d33 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 28 Nov 2024 04:54:22 +0100 Subject: [PATCH 21/51] update --- src/diffusers/models/transformers/transformer_ltx.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py index 12f5dbb8da9f..279ea4de7841 100644 --- a/src/diffusers/models/transformers/transformer_ltx.py +++ b/src/diffusers/models/transformers/transformer_ltx.py @@ -201,7 +201,7 @@ def __init__( num_attention_heads: int, attention_head_dim: int, cross_attention_dim: int, - qk_norm: str = "rms_norm", + qk_norm: str = "rms_norm_across_heads", activation_fn: str = "gelu-approximate", attention_bias: bool = True, attention_out_bias: bool = True, @@ -238,8 +238,7 @@ def __init__( self.ff = FeedForward(dim, activation_fn=activation_fn) - # TODO(aryan): Create a layer for this - self.scale_shift_table = nn.Parameter(torch.randn(6, dim)) + self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) def forward( self, @@ -368,7 +367,7 @@ def __init__( self.proj_out = nn.Linear(inner_dim, out_channels) # TODO(aryan): create a layer for this - self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim)) + self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5) self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=False) self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim) From d4a0f8ea400169ec804d409b31c66abc36309814 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 28 Nov 2024 04:54:50 +0100 Subject: [PATCH 22/51] copy t2v to i2v pipeline --- .../pipelines/ltx/pipeline_ltx_image2video.py | 697 ++++++++++++++++++ 1 file changed, 697 insertions(+) create mode 100644 src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py new file mode 100644 index 000000000000..310d279e4461 --- /dev/null +++ b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py @@ -0,0 +1,697 @@ +# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from transformers import T5EncoderModel, T5TokenizerFast + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...models.autoencoders import AutoencoderKLLTX +from ...models.transformers import LTXTransformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import LTXPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import LTXPipeline + >>> from diffusers.utils import export_to_video + + >>> pipe = LTXPipeline.from_pretrained("a-r-r-o-w/LTX-Video-diffusers", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage" + >>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted" + + >>> video = pipe( + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... width=704, + ... height=480, + ... num_frames=161, + ... num_inference_steps=50, + ... ).frames[0] + >>> export_to_video(video, "output.mp4", fps=24) + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.16, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. 
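As an aside on the `calculate_shift` helper copied in above: it is nothing more than a linear interpolation of the scheduler shift `mu` between `base_shift` and `max_shift` as the token count grows from `base_seq_len` to `max_seq_len`. A minimal sanity-check sketch, using only the default values visible in the signature above (everything below is illustrative and not part of the patch):

```py
# Linear interpolation performed by calculate_shift, checked against its own defaults.
def calculate_shift(image_seq_len, base_seq_len=256, max_seq_len=4096, base_shift=0.5, max_shift=1.16):
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    return image_seq_len * m + b

assert abs(calculate_shift(256) - 0.5) < 1e-6    # shortest sequences keep base_shift
assert abs(calculate_shift(4096) - 1.16) < 1e-6  # longest sequences reach max_shift
# Token counts beyond max_seq_len simply extrapolate along the same line (mu > 1.16).
```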
+ sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class LTXImageToVideoPipeline(DiffusionPipeline): + r""" + Pipeline for image-to-video generation. + + Reference: https://github.com/Lightricks/LTX-Video + + Args: + transformer ([`MochiTransformer3DModel`]): + Conditional Transformer architecture to denoise the encoded video latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). 
+ """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLLTX, + text_encoder: T5EncoderModel, + tokenizer: T5TokenizerFast, + transformer: LTXTransformer3DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + + self.vae_spatial_scale_factor = self.vae.spatial_compression_ratio if hasattr(self, "vae") else 32 + self.vae_temporal_scale_factor = self.vae.temporal_compression_ratio if hasattr(self, "vae") else 8 + self.transformer_spatial_patch_size = self.transformer.config.patch_size if hasattr(self, "transformer") else 1 + self.transformer_temporal_patch_size = ( + self.transformer.config.patch_size_t if hasattr(self, "transformer") else 1 + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_scale_factor) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 128 + ) + + self.default_height = 512 + self.default_width = 704 + self.default_frames = 121 + + # Copied from diffusers.pipelines.mochi.pipeline_mochi.MochiPipeline._get_t5_prompt_embeds with 256->128 + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 128, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + prompt_attention_mask = prompt_attention_mask.bool().to(device) + + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder(text_input_ids.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + prompt_attention_mask = prompt_attention_mask.view(batch_size, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1) + + return prompt_embeds, prompt_attention_mask + + # Copied from diffusers.pipelines.mochi.pipeline_mochi.MochiPipeline.encode_prompt with 256->128 + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: 
Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + max_sequence_length: int = 128, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." 
+ ) + + negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + def check_inputs( + self, + prompt, + height, + width, + callback_on_step_end_tensor_inputs=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + ): + if height % 32 != 0 or width % 32 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." + ) + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + num_frames, + dtype, + device, + generator, + latents=None, + ): + height = height // self.vae_spatial_scale_factor + width = width // self.vae_spatial_scale_factor + num_frames = (num_frames - 1) // self.vae_temporal_scale_factor + 1 + + shape = (batch_size, num_channels_latents, num_frames, height, width) + + if latents is not None: + return latents.to(device=device, dtype=dtype) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + def decode_latents(self, latents: torch.Tensor): + # unscale/denormalize the latents + latents_mean = self.vae.latents_mean.view(1, self.vae.config.latent_channels, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents_std = self.vae.latents_std.view(1, self.vae.config.latent_channels, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean + video = self.vae.decode(latents, return_dict=False)[0] + return video + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_frames: int = 81, + frame_rate: int = 25, + num_inference_steps: int = 50, + timesteps: List[int] = None, + guidance_scale: float = 3, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 128, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to `self.default_height`): + The height in pixels of the generated image. This is set to 480 by default for the best results. + width (`int`, *optional*, defaults to `self.default_width`): + The width in pixels of the generated image. This is set to 848 by default for the best results. + num_frames (`int`, defaults to `19`): + The number of video frames to generate + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, defaults to `4.5`): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. 
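The `guidance_scale` description above maps directly onto the classifier-free guidance update performed in the denoising loop further down in this file. A minimal sketch of that combination, with placeholder tensor shapes rather than the pipeline's real packed latent shape:

```py
import torch

# The batch is doubled into [unconditional, text-conditioned] halves before the
# transformer call; guidance then pushes the prediction away from the unconditional half.
guidance_scale = 3.0
noise_pred = torch.randn(2, 128, 64)  # placeholder shape for the transformer output

noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
```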
+ num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + negative_prompt_attention_mask (`torch.FloatTensor`, *optional*): + Pre-generated attention mask for negative text embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ltx.LTXPipelineOutput`] instead of a plain tuple. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to `256`): + Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.ltx.LTXPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ltx.LTXPipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + height = height or self.default_height + width = width or self.default_width + latent_frame_rate = frame_rate // self.vae_temporal_scale_factor + + # 1. Check inputs. 
Raise error if not correct + self.check_inputs( + prompt=prompt, + height=height, + width=width, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + ) + + self._guidance_scale = guidance_scale + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Prepare text embeddings + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + max_sequence_length=max_sequence_length, + device=device, + ) + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + image_seq_len = latents.size(1) + mu = calculate_shift( + image_seq_len, + self.scheduler.config.base_image_seq_len, + self.scheduler.config.max_image_seq_len, + self.scheduler.config.base_shift, + self.scheduler.config.max_shift, + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 6. Prepare micro-conditions + rope_interpolation_scale = ( + 1 / latent_frame_rate, + self.vae_spatial_scale_factor, + self.vae_spatial_scale_factor, + ) + + # 7. 
Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + encoder_attention_mask=prompt_attention_mask, + rope_interpolation_scale=rope_interpolation_scale, + return_dict=False, + )[0] + noise_pred = noise_pred.float() + + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents.float(), return_dict=False)[0] + latents = latents.to(dtype=latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + video = latents + else: + video = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return LTXPipelineOutput(frames=video) From 06db66b4021dfc7068769119f32e46000d5495a0 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 28 Nov 2024 10:08:22 +0100 Subject: [PATCH 23/51] update --- src/diffusers/models/normalization.py | 6 --- .../models/transformers/transformer_ltx.py | 27 +++++----- src/diffusers/pipelines/ltx/pipeline_ltx.py | 54 ++++++++++--------- 3 files changed, 43 insertions(+), 44 deletions(-) diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py index 500a19c51588..10fc9ee00383 100644 --- a/src/diffusers/models/normalization.py +++ b/src/diffusers/models/normalization.py @@ -604,9 +604,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return hidden_states - def extra_repr(self) -> str: - return f"{super().extra_repr()}, channel_dim={self.channel_dim}" - class RMSNormNd(RMSNorm): def __init__( @@ -633,6 +630,3 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = super().forward(hidden_states) return hidden_states - - def extra_repr(self): - return f"{super().extra_repr()}, channel_dim={self.channel_dim}" diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py index a7d79fb09338..743b0882f5b9 100644 --- a/src/diffusers/models/transformers/transformer_ltx.py +++ b/src/diffusers/models/transformers/transformer_ltx.py @@ -46,15 +46,6 @@ def __init__(self): "LTXAttentionProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." 
) - def _apply_rotary_emb(self, x: torch.Tensor, image_rotary_emb: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: - cos, sin = image_rotary_emb - - x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1) # [B, S, H, D//2] - x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(2) - - out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) - return out - def __call__( self, attn: Attention, @@ -84,8 +75,18 @@ def __call__( key = attn.norm_k(key) if image_rotary_emb is not None and apply_rotary_emb: - query = self._apply_rotary_emb(query, image_rotary_emb) - key = self._apply_rotary_emb(key, image_rotary_emb) + + def apply_rotary_emb(x, freqs): + cos, sin = freqs + + x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1) # [B, S, H, D//2] + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(2) + + out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + return out + + query = self.apply_rotary_emb(query, image_rotary_emb) + key = self.apply_rotary_emb(key, image_rotary_emb) query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2) key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2) @@ -103,7 +104,7 @@ def __call__( return hidden_states -class LTXRoPE(nn.Module): +class LTXRotaryPosEmbed(nn.Module): def __init__( self, dim: int, @@ -335,7 +336,7 @@ def __init__( self.patchify_proj = nn.Linear(in_channels, inner_dim) - self.rope = LTXRoPE( + self.rope = LTXRotaryPosEmbed( dim=inner_dim, base_num_frames=20, base_height=2048, diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx.py b/src/diffusers/pipelines/ltx/pipeline_ltx.py index 01914cbf0164..1a30680de45c 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx.py @@ -146,11 +146,11 @@ class LTXPipeline(DiffusionPipeline): Reference: https://github.com/Lightricks/LTX-Video Args: - transformer ([`MochiTransformer3DModel`]): + transformer ([`LTXTransformer3DModel`]): Conditional Transformer architecture to denoise the encoded video latents. scheduler ([`FlowMatchEulerDiscreteScheduler`]): A scheduler to be used in combination with `transformer` to denoise the encoded image latents. - vae ([`AutoencoderKL`]): + vae ([`AutoencoderKLLTX`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. 
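Returning briefly to the `apply_rotary_emb` helper factored out of the attention processor in the transformer hunk above: it rotates consecutive channel pairs by the angles encoded in `cos`/`sin`, i.e. a real-valued version of complex rotation. A small self-contained sketch with made-up shapes (in the model the `cos`/`sin` tensors come from `LTXRotaryPosEmbed`):

```py
import torch

def apply_rotary_emb(x, freqs):
    # x: [batch, seq, dim]; freqs: (cos, sin), each broadcastable to x.
    cos, sin = freqs
    x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1)            # pairs (x0, x1), (x2, x3), ...
    x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(2)  # (-x1, x0), (-x3, x2), ...
    return (x.float() * cos + x_rotated.float() * sin).to(x.dtype)

x = torch.randn(1, 4, 8)
cos, sin = torch.ones(1, 4, 8), torch.zeros(1, 4, 8)  # zero angle leaves x unchanged
assert torch.allclose(apply_rotary_emb(x, (cos, sin)), x)
```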
text_encoder ([`T5EncoderModel`]): [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically @@ -185,14 +185,14 @@ def __init__( scheduler=scheduler, ) - self.vae_spatial_scale_factor = self.vae.spatial_compression_ratio if hasattr(self, "vae") else 32 - self.vae_temporal_scale_factor = self.vae.temporal_compression_ratio if hasattr(self, "vae") else 8 + self.vae_spatial_compression_ratio = self.vae.spatial_compression_ratio if hasattr(self, "vae") else 32 + self.vae_temporal_compression_ratio = self.vae.temporal_compression_ratio if hasattr(self, "vae") else 8 self.transformer_spatial_patch_size = self.transformer.config.patch_size if hasattr(self, "transformer") else 1 self.transformer_temporal_patch_size = ( self.transformer.config.patch_size_t if hasattr(self, "transformer") else 1 ) - self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_scale_factor) + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio) self.tokenizer_max_length = ( self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 128 ) @@ -389,24 +389,25 @@ def check_inputs( def prepare_latents( self, - batch_size, - num_channels_latents, - height, - width, - num_frames, - dtype, - device, - generator, - latents=None, - ): - height = height // self.vae_spatial_scale_factor - width = width // self.vae_spatial_scale_factor - num_frames = (num_frames - 1) // self.vae_temporal_scale_factor + 1 + batch_size: int = 1, + num_channels_latents: int = 128, + height: int = 512, + width: int = 704, + num_frames: int = 161, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) + + height = height // self.vae_spatial_compression_ratio + width = width // self.vae_spatial_compression_ratio + num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 shape = (batch_size, num_channels_latents, num_frames, height, width) - if latents is not None: - return latents.to(device=device, dtype=dtype) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" @@ -545,7 +546,7 @@ def __call__( height = height or self.default_height width = width or self.default_width - latent_frame_rate = frame_rate // self.vae_temporal_scale_factor + latent_frame_rate = frame_rate // self.vae_temporal_compression_ratio # 1. Check inputs. Raise error if not correct self.check_inputs( @@ -609,10 +610,13 @@ def __call__( ) # 5. Prepare timesteps + latent_frames = latents.size(2) + latent_height = height // self.vae_spatial_compression_ratio + latent_width = width // self.vae_spatial_compression_ratio + video_sequence_length = latent_height * latent_width * latent_frames sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) - image_seq_len = latents.size(1) mu = calculate_shift( - image_seq_len, + video_sequence_length, self.scheduler.config.base_image_seq_len, self.scheduler.config.max_image_seq_len, self.scheduler.config.base_shift, @@ -632,8 +636,8 @@ def __call__( # 6. 
Prepare micro-conditions rope_interpolation_scale = ( 1 / latent_frame_rate, - self.vae_spatial_scale_factor, - self.vae_spatial_scale_factor, + self.vae_spatial_compression_ratio, + self.vae_spatial_compression_ratio, ) # 7. Denoising loop From f8f30a59db0f00c29073a1ce0299b6d57bdfd097 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 28 Nov 2024 11:00:01 +0100 Subject: [PATCH 24/51] apply review suggestions --- scripts/convert_ltx_to_diffusers.py | 9 +++--- src/diffusers/models/normalization.py | 29 ++++++----------- .../models/transformers/transformer_ltx.py | 31 +++++++++---------- src/diffusers/pipelines/ltx/pipeline_ltx.py | 2 ++ 4 files changed, 31 insertions(+), 40 deletions(-) diff --git a/scripts/convert_ltx_to_diffusers.py b/scripts/convert_ltx_to_diffusers.py index e801a7a8535a..758936b9e4f3 100644 --- a/scripts/convert_ltx_to_diffusers.py +++ b/scripts/convert_ltx_to_diffusers.py @@ -8,7 +8,7 @@ from diffusers import AutoencoderKLLTX, FlowMatchEulerDiscreteScheduler, LTXPipeline, LTXTransformer3DModel -def remove_keys_inplace(key: str, state_dict: Dict[str, Any]): +def remove_keys_(key: str, state_dict: Dict[str, Any]): state_dict.pop(key) @@ -47,15 +47,14 @@ def remove_keys_inplace(key: str, state_dict: Dict[str, Any]): # common "conv_shortcut": "conv_shortcut.conv", "res_blocks": "resnets", - "norm3.norm": "norm3", "per_channel_statistics.mean-of-means": "latents_mean", "per_channel_statistics.std-of-means": "latents_std", } VAE_SPECIAL_KEYS_REMAP = { - "per_channel_statistics.channel": remove_keys_inplace, - "per_channel_statistics.mean-of-means": remove_keys_inplace, - "per_channel_statistics.mean-of-stds": remove_keys_inplace, + "per_channel_statistics.channel": remove_keys_, + "per_channel_statistics.mean-of-means": remove_keys_, + "per_channel_statistics.mean-of-stds": remove_keys_, } diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py index 10fc9ee00383..9e0c21e4d500 100644 --- a/src/diffusers/models/normalization.py +++ b/src/diffusers/models/normalization.py @@ -572,7 +572,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return F.normalize(hidden_states, p=self.p, dim=self.dim, eps=self.eps) -class LayerNormNd(nn.LayerNorm): +class LayerNormNd(nn.Module): def __init__( self, normalized_shape: Union[int, List[int], Tuple[int], torch.Size], @@ -583,29 +583,23 @@ def __init__( dtype=None, channel_dim: int = -1, ) -> None: - super().__init__( - normalized_shape=normalized_shape, - eps=eps, - elementwise_affine=elementwise_affine, - bias=bias, - device=device, - dtype=dtype, - ) + super().__init__() + self.norm = nn.LayerNorm(normalized_shape, eps, elementwise_affine, bias, device, dtype) self.channel_dim = channel_dim def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if self.channel_dim != -1: hidden_states = hidden_states.movedim(self.channel_dim, -1) - hidden_states = super().forward(hidden_states) + hidden_states = self.norm(hidden_states) hidden_states = hidden_states.movedim(-1, self.channel_dim) else: - hidden_states = super().forward(hidden_states) + hidden_states = self.norm(hidden_states) return hidden_states -class RMSNormNd(RMSNorm): +class RMSNormNd(nn.Module): def __init__( self, dim: int, @@ -613,20 +607,17 @@ def __init__( elementwise_affine: bool = True, channel_dim: int = -1, ) -> None: - super().__init__( - dim=dim, - eps=eps, - elementwise_affine=elementwise_affine, - ) + super().__init__() + self.norm = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine) self.channel_dim = 
channel_dim def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if self.channel_dim != -1: hidden_states = hidden_states.movedim(self.channel_dim, -1) - hidden_states = super().forward(hidden_states) + hidden_states = self.norm(hidden_states) hidden_states = hidden_states.movedim(-1, self.channel_dim) else: - hidden_states = super().forward(hidden_states) + hidden_states = self.norm(hidden_states) return hidden_states diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py index 743b0882f5b9..e9bb3282e31d 100644 --- a/src/diffusers/models/transformers/transformer_ltx.py +++ b/src/diffusers/models/transformers/transformer_ltx.py @@ -62,10 +62,10 @@ def __call__( attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) - apply_rotary_emb = False + use_rotary_emb = False if encoder_hidden_states is None: encoder_hidden_states = hidden_states - apply_rotary_emb = True + use_rotary_emb = True query = attn.to_q(hidden_states) key = attn.to_k(encoder_hidden_states) @@ -74,19 +74,9 @@ def __call__( query = attn.norm_q(query) key = attn.norm_k(key) - if image_rotary_emb is not None and apply_rotary_emb: - - def apply_rotary_emb(x, freqs): - cos, sin = freqs - - x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1) # [B, S, H, D//2] - x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(2) - - out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) - return out - - query = self.apply_rotary_emb(query, image_rotary_emb) - key = self.apply_rotary_emb(key, image_rotary_emb) + if image_rotary_emb is not None and use_rotary_emb: + query = apply_rotary_emb(query, image_rotary_emb) + key = apply_rotary_emb(key, image_rotary_emb) query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2) key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2) @@ -367,7 +357,6 @@ def __init__( self.norm_out = nn.LayerNorm(inner_dim, eps=1e-6, elementwise_affine=False) self.proj_out = nn.Linear(inner_dim, out_channels) - # TODO(aryan): create a layer for this self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5) self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=False) @@ -467,3 +456,13 @@ def custom_forward(*inputs): if not return_dict: return (output,) return Transformer2DModelOutput(sample=output) + + +def apply_rotary_emb(x, freqs): + cos, sin = freqs + + x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1) # [B, S, H, D//2] + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(2) + + out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + return out diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx.py b/src/diffusers/pipelines/ltx/pipeline_ltx.py index 1a30680de45c..955330290ae5 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx.py @@ -632,6 +632,8 @@ def __call__( ) num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) self._num_timesteps = len(timesteps) + print(self.scheduler.sigmas) + print(len(self.scheduler.sigmas)) # 6. 
Prepare micro-conditions rope_interpolation_scale = ( From 4e89c8d3a69630615507acf14374850040a94c0a Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 28 Nov 2024 11:10:40 +0100 Subject: [PATCH 25/51] update --- scripts/convert_ltx_to_diffusers.py | 2 ++ .../models/transformers/transformer_ltx.py | 16 ++++++++-------- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/scripts/convert_ltx_to_diffusers.py b/scripts/convert_ltx_to_diffusers.py index 758936b9e4f3..75d13b7b4755 100644 --- a/scripts/convert_ltx_to_diffusers.py +++ b/scripts/convert_ltx_to_diffusers.py @@ -15,6 +15,8 @@ def remove_keys_(key: str, state_dict: Dict[str, Any]): TOKENIZER_MAX_LENGTH = 128 TRANSFORMER_KEYS_RENAME_DICT = { + "patchify_proj": "proj_in", + "adaln_single": "time_embed", "q_norm": "norm_q", "k_norm": "norm_k", } diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py index e9bb3282e31d..9e21a9d9e664 100644 --- a/src/diffusers/models/transformers/transformer_ltx.py +++ b/src/diffusers/models/transformers/transformer_ltx.py @@ -324,8 +324,13 @@ def __init__( out_channels = out_channels or in_channels inner_dim = num_attention_heads * attention_head_dim - self.patchify_proj = nn.Linear(in_channels, inner_dim) + self.proj_in = nn.Linear(in_channels, inner_dim) + self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5) + self.time_embed = AdaLayerNormSingle(inner_dim, use_additional_conditions=False) + + self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim) + self.rope = LTXRotaryPosEmbed( dim=inner_dim, base_num_frames=20, @@ -357,11 +362,6 @@ def __init__( self.norm_out = nn.LayerNorm(inner_dim, eps=1e-6, elementwise_affine=False) self.proj_out = nn.Linear(inner_dim, out_channels) - self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5) - self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=False) - - self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim) - self.gradient_checkpointing = False def _set_gradient_checkpointing(self, module, value=False): @@ -396,9 +396,9 @@ def forward( batch_size, -1, post_patch_num_frames, p_t, post_patch_height, p, post_patch_width, p ) hidden_states = hidden_states.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3) - hidden_states = self.patchify_proj(hidden_states) + hidden_states = self.proj_in(hidden_states) - temb, embedded_timestep = self.adaln_single( + temb, embedded_timestep = self.time_embed( timestep.flatten(), batch_size=batch_size, hidden_dtype=hidden_states.dtype, From 5391cebe542d22672b9073df1c4d273e7a4e466b Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 28 Nov 2024 11:12:56 +0100 Subject: [PATCH 26/51] make style --- src/diffusers/models/transformers/transformer_ltx.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py index 9e21a9d9e664..67f486ae1ace 100644 --- a/src/diffusers/models/transformers/transformer_ltx.py +++ b/src/diffusers/models/transformers/transformer_ltx.py @@ -330,7 +330,7 @@ def __init__( self.time_embed = AdaLayerNormSingle(inner_dim, use_additional_conditions=False) self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim) - + self.rope = LTXRotaryPosEmbed( dim=inner_dim, base_num_frames=20, From 
c2018808ac16920e7558b4a843184900de04322b Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 29 Nov 2024 04:58:46 +0100 Subject: [PATCH 27/51] remove framewise encoding/decoding --- .../models/autoencoders/autoencoder_kl_ltx.py | 59 ++++++++----------- src/diffusers/models/normalization.py | 3 - src/diffusers/pipelines/ltx/pipeline_ltx.py | 2 - 3 files changed, 24 insertions(+), 40 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py index 2ce3b9d83540..5765f9b9403d 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py @@ -899,12 +899,12 @@ def _encode(self, x: torch.Tensor) -> torch.Tensor: return self.tiled_encode(x) if self.use_framewise_encoding: - enc = [] - for i in range(0, num_frames, self.num_sample_frames_batch_size): - x_intermediate = x[:, :, i : i + self.num_sample_frames_batch_size] - x_intermediate = self.encoder(x_intermediate) - enc.append(x_intermediate) - enc = torch.cat(enc, dim=2) + # TODO(aryan): requires investigation + raise NotImplementedError( + "Frame-wise encoding has not been implemented for AutoencoderKLLTX, at the moment, due to " + "quality issues caused by splitting inference across frame dimension. If you believe this " + "should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls." + ) else: enc = self.encoder(x) @@ -946,12 +946,12 @@ def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOut return self.tiled_decode(z, return_dict=return_dict) if self.use_framewise_decoding: - dec = [] - for i in range(0, num_frames, self.num_latent_frames_batch_size): - z_intermediate = z[:, :, i : i + self.num_latent_frames_batch_size] - z_intermediate = self.decoder(z_intermediate) - dec.append(z_intermediate) - dec = torch.cat(dec, dim=2) + # TODO(aryan): requires investigation + raise NotImplementedError( + "Frame-wise decoding has not been implemented for AutoencoderKLLTX, at the moment, due to " + "quality issues caused by splitting inference across frame dimension. If you believe this " + "should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls." + ) else: dec = self.decoder(z) @@ -1031,17 +1031,12 @@ def tiled_encode(self, x: torch.Tensor) -> torch.Tensor: row = [] for j in range(0, width, self.tile_sample_stride_width): if self.use_framewise_encoding: - time = [] - for k in range(0, num_frames, self.num_sample_frames_batch_size): - tile = x[ - :, - :, - k : k + self.num_sample_frames_batch_size, - i : i + self.tile_sample_min_height, - j : j + self.tile_sample_min_width, - ] - tile = self.encoder(tile) - time.append(tile) + # TODO(aryan): requires investigation + raise NotImplementedError( + "Frame-wise encoding has not been implemented for AutoencoderKLLTX, at the moment, due to " + "quality issues caused by splitting inference across frame dimension. If you believe this " + "should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls." 
+ ) else: time = self.encoder( x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width] @@ -1100,18 +1095,12 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod row = [] for j in range(0, width, tile_latent_stride_width): if self.use_framewise_decoding: - time = [] - for k in range(0, num_frames, self.num_latent_frames_batch_size): - tile = z[ - :, - :, - k : k + self.num_latent_frames_batch_size, - i : i + tile_latent_min_height, - j : j + tile_latent_min_width, - ] - tile = self.decoder(tile) - time.append(tile) - time = torch.cat(time, dim=2) + # TODO(aryan): requires investigation + raise NotImplementedError( + "Frame-wise decoding has not been implemented for AutoencoderKLLTX, at the moment, due to " + "quality issues caused by splitting inference across frame dimension. If you believe this " + "should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls." + ) else: time = self.decoder(z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width]) diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py index 9e0c21e4d500..f13cea59da06 100644 --- a/src/diffusers/models/normalization.py +++ b/src/diffusers/models/normalization.py @@ -543,9 +543,6 @@ def forward(self, hidden_states): return hidden_states - def extra_repr(self) -> str: - return f"features={self.dim}, eps={self.eps}, elementwise_affine={self.elementwise_affine}" - class GlobalResponseNorm(nn.Module): # Taken from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105 diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx.py b/src/diffusers/pipelines/ltx/pipeline_ltx.py index 955330290ae5..1a30680de45c 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx.py @@ -632,8 +632,6 @@ def __call__( ) num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) self._num_timesteps = len(timesteps) - print(self.scheduler.sigmas) - print(len(self.scheduler.sigmas)) # 6. 
Prepare micro-conditions rope_interpolation_scale = ( From 30a3bb723494c32a51a32cd2efedd4d1d5f62523 Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 29 Nov 2024 05:29:26 +0100 Subject: [PATCH 28/51] pack/unpack latents --- .../models/transformers/transformer_ltx.py | 38 ++++-------- src/diffusers/pipelines/ltx/pipeline_ltx.py | 59 ++++++++++++++----- 2 files changed, 55 insertions(+), 42 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py index 67f486ae1ace..5ec4b04a544d 100644 --- a/src/diffusers/models/transformers/transformer_ltx.py +++ b/src/diffusers/models/transformers/transformer_ltx.py @@ -116,17 +116,14 @@ def __init__( self.theta = theta def forward( - self, hidden_states: torch.Tensor, rope_interpolation_scale: Optional[Tuple[torch.Tensor, float, float]] = None + self, hidden_states: torch.Tensor, num_frames: int, height: int, width: int, rope_interpolation_scale: Optional[Tuple[torch.Tensor, float, float]] = None ) -> Tuple[torch.Tensor, torch.Tensor]: - batch_size, num_channels, num_frames, height, width = hidden_states.shape - post_patch_num_frames = num_frames // self.patch_size_t - post_patch_height = height // self.patch_size - post_patch_width = width // self.patch_size + batch_size = hidden_states.size(0) # Always compute rope in fp32 - grid_h = torch.arange(post_patch_height, dtype=torch.float32, device=hidden_states.device) - grid_w = torch.arange(post_patch_width, dtype=torch.float32, device=hidden_states.device) - grid_f = torch.arange(post_patch_num_frames, dtype=torch.float32, device=hidden_states.device) + grid_h = torch.arange(height, dtype=torch.float32, device=hidden_states.device) + grid_w = torch.arange(width, dtype=torch.float32, device=hidden_states.device) + grid_f = torch.arange(num_frames, dtype=torch.float32, device=hidden_states.device) grid = torch.meshgrid(grid_f, grid_h, grid_w, indexing="ij") grid = torch.stack(grid, dim=0) grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1) @@ -374,28 +371,20 @@ def forward( encoder_hidden_states: torch.Tensor, timestep: torch.LongTensor, encoder_attention_mask: torch.Tensor, + num_frames: int, + height: int, + width: int, rope_interpolation_scale: Optional[Tuple[float, float, float]] = None, return_dict: bool = True, ) -> torch.Tensor: - image_rotary_emb = self.rope(hidden_states, rope_interpolation_scale) + image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale) # convert encoder_attention_mask to a bias the same way we do for attention_mask if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 encoder_attention_mask = encoder_attention_mask.unsqueeze(1) - batch_size, num_channels, num_frames, height, width = hidden_states.shape - p = self.config.patch_size - p_t = self.config.patch_size_t - - post_patch_height = height // p - post_patch_width = width // p - post_patch_num_frames = num_frames // p_t - - hidden_states = hidden_states.reshape( - batch_size, -1, post_patch_num_frames, p_t, post_patch_height, p, post_patch_width, p - ) - hidden_states = hidden_states.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3) + batch_size = hidden_states.size(0) hidden_states = self.proj_in(hidden_states) temb, embedded_timestep = self.time_embed( @@ -446,12 +435,7 @@ def custom_forward(*inputs): hidden_states = self.norm_out(hidden_states) hidden_states = hidden_states * (1 + scale) 
+ shift - hidden_states = self.proj_out(hidden_states) - - hidden_states = hidden_states.reshape( - batch_size, post_patch_num_frames, post_patch_height, post_patch_width, -1, p_t, p, p - ) - output = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3) + output = self.proj_out(hidden_states) if not return_dict: return (output,) diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx.py b/src/diffusers/pipelines/ltx/pipeline_ltx.py index 1a30680de45c..a9b18cb31a33 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx.py @@ -387,6 +387,34 @@ def check_inputs( f" {negative_prompt_attention_mask.shape}." ) + @staticmethod + def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = latents.shape + post_patch_num_frames = num_frames // patch_size_t + post_patch_height = height // patch_size + post_patch_width = width // patch_size + latents = latents.reshape( + batch_size, + -1, + post_patch_num_frames, + patch_size_t, + post_patch_height, + patch_size, + post_patch_width, + patch_size, + ) + latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3) + return latents + + @staticmethod + def _unpack_latents( + latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1 + ) -> torch.Tensor: + batch_size, num_channels, video_sequence_length = latents.shape + latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size) + latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3) + return latents + def prepare_latents( self, batch_size: int = 1, @@ -415,20 +443,9 @@ def prepare_latents( ) latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents(latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size) return latents - def decode_latents(self, latents: torch.Tensor): - # unscale/denormalize the latents - latents_mean = self.vae.latents_mean.view(1, self.vae.config.latent_channels, 1, 1, 1).to( - latents.device, latents.dtype - ) - latents_std = self.vae.latents_std.view(1, self.vae.config.latent_channels, 1, 1, 1).to( - latents.device, latents.dtype - ) - latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean - video = self.vae.decode(latents, return_dict=False)[0] - return video - @property def guidance_scale(self): return self._guidance_scale @@ -610,10 +627,10 @@ def __call__( ) # 5. 
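# A minimal, self-contained sketch (not part of the patch) of the token layout that the
# _pack_latents/_unpack_latents helpers added above implement. With the LTX defaults
# patch_size = patch_size_t = 1 every latent "pixel" becomes one token and the two calls
# are exact inverses; function names and shapes below are local stand-ins, not the
# pipeline staticmethods themselves.
import torch

def pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
    # (B, C, F, H, W) -> (B, num_tokens, token_dim) with frame-major token ordering
    batch_size, _, num_frames, height, width = latents.shape
    latents = latents.reshape(
        batch_size, -1,
        num_frames // patch_size_t, patch_size_t,
        height // patch_size, patch_size,
        width // patch_size, patch_size,
    )
    return latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)

def unpack_latents(latents: torch.Tensor, num_frames: int, height: int, width: int,
                   patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
    # (B, num_tokens, token_dim) -> (B, C, F, H, W), inverse of pack_latents
    batch_size = latents.size(0)
    latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size)
    return latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)

x = torch.randn(2, 128, 3, 16, 22)      # (B, C, latent frames, latent height, latent width)
tokens = pack_latents(x)
print(tokens.shape)                     # torch.Size([2, 1056, 128]) = (B, 3*16*22, C)
assert torch.equal(unpack_latents(tokens, 3, 16, 22), x)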
Prepare timesteps - latent_frames = latents.size(2) + latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 latent_height = height // self.vae_spatial_compression_ratio latent_width = width // self.vae_spatial_compression_ratio - video_sequence_length = latent_height * latent_width * latent_frames + video_sequence_length = latent_num_frames * latent_height * latent_width sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) mu = calculate_shift( video_sequence_length, @@ -656,6 +673,9 @@ def __call__( encoder_hidden_states=prompt_embeds, timestep=timestep, encoder_attention_mask=prompt_attention_mask, + num_frames=latent_num_frames, + height=latent_height, + width=latent_width, rope_interpolation_scale=rope_interpolation_scale, return_dict=False, )[0] @@ -689,7 +709,16 @@ def __call__( if output_type == "latent": video = latents else: - video = self.decode_latents(latents) + latents = self._unpack_latents(latents, latent_num_frames, latent_height, latent_width, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size) + # unscale/denormalize the latents + latents_mean = self.vae.latents_mean.view(1, self.vae.config.latent_channels, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents_std = self.vae.latents_std.view(1, self.vae.config.latent_channels, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean + video = self.vae.decode(latents, return_dict=False)[0] video = self.video_processor.postprocess_video(video, output_type=output_type) # Offload all models From 1f008fc93a10d78a679a508c0617489fbf865427 Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 29 Nov 2024 09:54:51 +0100 Subject: [PATCH 29/51] image2video --- docs/source/en/api/pipelines/ltx.md | 6 + src/diffusers/__init__.py | 2 + .../models/autoencoders/autoencoder_kl_ltx.py | 2 +- .../models/transformers/transformer_ltx.py | 7 +- src/diffusers/pipelines/__init__.py | 4 +- src/diffusers/pipelines/ltx/__init__.py | 2 + src/diffusers/pipelines/ltx/pipeline_ltx.py | 38 ++- .../pipelines/ltx/pipeline_ltx_image2video.py | 240 +++++++++++++--- tests/pipelines/ltx/test_ltx_image2video.py | 260 ++++++++++++++++++ 9 files changed, 504 insertions(+), 57 deletions(-) create mode 100644 tests/pipelines/ltx/test_ltx_image2video.py diff --git a/docs/source/en/api/pipelines/ltx.md b/docs/source/en/api/pipelines/ltx.md index 18de92fe2804..17032fede952 100644 --- a/docs/source/en/api/pipelines/ltx.md +++ b/docs/source/en/api/pipelines/ltx.md @@ -28,6 +28,12 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.m - all - __call__ +## LTXImageToVideoPipeline + +[[autodoc]] LTXImageToVideoPipeline + - all + - __call__ + ## LTXPipelineOutput [[autodoc]] pipelines.ltx.pipeline_output.LTXPipelineOutput diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 5db324f22119..908dc3caad97 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -317,6 +317,7 @@ "LDMTextToImagePipeline", "LEditsPPPipelineStableDiffusion", "LEditsPPPipelineStableDiffusionXL", + "LTXImageToVideoPipeline", "LTXPipeline", "LuminaText2ImgPipeline", "MarigoldDepthPipeline", @@ -790,6 +791,7 @@ LDMTextToImagePipeline, LEditsPPPipelineStableDiffusion, LEditsPPPipelineStableDiffusionXL, + LTXImageToVideoPipeline, LTXPipeline, LuminaText2ImgPipeline, MarigoldDepthPipeline, diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py 
b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py index 5765f9b9403d..ab06cbbb2f82 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py @@ -802,7 +802,7 @@ def __init__( ) latents_mean = torch.zeros((latent_channels,), requires_grad=False) - latents_std = torch.zeros((latent_channels,), requires_grad=False) + latents_std = torch.ones((latent_channels,), requires_grad=False) self.register_buffer("latents_mean", latents_mean, persistent=True) self.register_buffer("latents_std", latents_std, persistent=True) diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py index 5ec4b04a544d..5472e4814d71 100644 --- a/src/diffusers/models/transformers/transformer_ltx.py +++ b/src/diffusers/models/transformers/transformer_ltx.py @@ -116,7 +116,12 @@ def __init__( self.theta = theta def forward( - self, hidden_states: torch.Tensor, num_frames: int, height: int, width: int, rope_interpolation_scale: Optional[Tuple[torch.Tensor, float, float]] = None + self, + hidden_states: torch.Tensor, + num_frames: int, + height: int, + width: int, + rope_interpolation_scale: Optional[Tuple[torch.Tensor, float, float]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: batch_size = hidden_states.size(0) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 621bcce6ae80..2e5c86bdf950 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -245,7 +245,7 @@ ] ) _import_structure["latte"] = ["LattePipeline"] - _import_structure["ltx"] = ["LTXPipeline"] + _import_structure["ltx"] = ["LTXPipeline", "LTXImageToVideoPipeline"] _import_structure["lumina"] = ["LuminaText2ImgPipeline"] _import_structure["marigold"].extend( [ @@ -578,7 +578,7 @@ LEditsPPPipelineStableDiffusion, LEditsPPPipelineStableDiffusionXL, ) - from .ltx import LTXPipeline + from .ltx import LTXImageToVideoPipeline, LTXPipeline from .lumina import LuminaText2ImgPipeline from .marigold import ( MarigoldDepthPipeline, diff --git a/src/diffusers/pipelines/ltx/__init__.py b/src/diffusers/pipelines/ltx/__init__.py index 96fc7b3c24cd..20cc1c216522 100644 --- a/src/diffusers/pipelines/ltx/__init__.py +++ b/src/diffusers/pipelines/ltx/__init__.py @@ -23,6 +23,7 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["pipeline_ltx"] = ["LTXPipeline"] + _import_structure["pipeline_ltx_image2video"] = ["LTXImageToVideoPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: @@ -33,6 +34,7 @@ from ...utils.dummy_torch_and_transformers_objects import * else: from .pipeline_ltx import LTXPipeline + from .pipeline_ltx_image2video import LTXImageToVideoPipeline else: import sys diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx.py b/src/diffusers/pipelines/ltx/pipeline_ltx.py index a9b18cb31a33..ad62466ceb98 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx.py @@ -415,6 +415,24 @@ def _unpack_latents( latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3) return latents + @staticmethod + def _normalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 + ) -> torch.Tensor: + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + 
latents = (latents - latents_mean) * scaling_factor / latents_std + return latents + + @staticmethod + def _denormalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 + ) -> torch.Tensor: + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = latents * latents_std / scaling_factor + latents_mean + return latents + def prepare_latents( self, batch_size: int = 1, @@ -443,7 +461,9 @@ def prepare_latents( ) latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - latents = self._pack_latents(latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size) + latents = self._pack_latents( + latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) return latents @property @@ -709,15 +729,17 @@ def __call__( if output_type == "latent": video = latents else: - latents = self._unpack_latents(latents, latent_num_frames, latent_height, latent_width, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size) - # unscale/denormalize the latents - latents_mean = self.vae.latents_mean.view(1, self.vae.config.latent_channels, 1, 1, 1).to( - latents.device, latents.dtype + latents = self._unpack_latents( + latents, + latent_num_frames, + latent_height, + latent_width, + self.transformer_spatial_patch_size, + self.transformer_temporal_patch_size, ) - latents_std = self.vae.latents_std.view(1, self.vae.config.latent_channels, 1, 1, 1).to( - latents.device, latents.dtype + latents = self._denormalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor ) - latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean video = self.vae.decode(latents, return_dict=False)[0] video = self.video_processor.postprocess_video(video, output_type=output_type) diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py index 310d279e4461..350222629dba 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py @@ -20,6 +20,7 @@ from transformers import T5EncoderModel, T5TokenizerFast from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput from ...models.autoencoders import AutoencoderKLLTX from ...models.transformers import LTXTransformer3DModel from ...schedulers import FlowMatchEulerDiscreteScheduler @@ -139,6 +140,20 @@ def retrieve_timesteps( return timesteps, num_inference_steps +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + class LTXImageToVideoPipeline(DiffusionPipeline): r""" Pipeline for image-to-video generation. 
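# Rough, self-contained sketch (not from the patch) of the per-channel latent
# (de)normalization performed by the _normalize_latents/_denormalize_latents helpers
# above; `mean`/`std` stand in for the VAE's registered latents_mean/latents_std buffers
# (zeros and ones after the autoencoder fix above) and `scaling_factor` for
# vae.config.scaling_factor.
import torch

def normalize(latents, mean, std, scaling_factor=1.0):
    mean = mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
    std = std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
    return (latents - mean) * scaling_factor / std

def denormalize(latents, mean, std, scaling_factor=1.0):
    mean = mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
    std = std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
    return latents * std / scaling_factor + mean

latents = torch.randn(1, 128, 3, 16, 22)
mean, std = torch.zeros(128), torch.ones(128)
assert torch.allclose(denormalize(normalize(latents, mean, std), mean, std), latents)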
@@ -146,11 +161,11 @@ class LTXImageToVideoPipeline(DiffusionPipeline): Reference: https://github.com/Lightricks/LTX-Video Args: - transformer ([`MochiTransformer3DModel`]): + transformer ([`LTXTransformer3DModel`]): Conditional Transformer architecture to denoise the encoded video latents. scheduler ([`FlowMatchEulerDiscreteScheduler`]): A scheduler to be used in combination with `transformer` to denoise the encoded image latents. - vae ([`AutoencoderKL`]): + vae ([`AutoencoderKLLTX`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. text_encoder ([`T5EncoderModel`]): [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically @@ -185,14 +200,14 @@ def __init__( scheduler=scheduler, ) - self.vae_spatial_scale_factor = self.vae.spatial_compression_ratio if hasattr(self, "vae") else 32 - self.vae_temporal_scale_factor = self.vae.temporal_compression_ratio if hasattr(self, "vae") else 8 + self.vae_spatial_compression_ratio = self.vae.spatial_compression_ratio if hasattr(self, "vae") else 32 + self.vae_temporal_compression_ratio = self.vae.temporal_compression_ratio if hasattr(self, "vae") else 8 self.transformer_spatial_patch_size = self.transformer.config.patch_size if hasattr(self, "transformer") else 1 self.transformer_temporal_patch_size = ( self.transformer.config.patch_size_t if hasattr(self, "transformer") else 1 ) - self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_scale_factor) + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio) self.tokenizer_max_length = ( self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 128 ) @@ -387,46 +402,122 @@ def check_inputs( f" {negative_prompt_attention_mask.shape}." 
) + @staticmethod + # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._pack_latents + def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = latents.shape + post_patch_num_frames = num_frames // patch_size_t + post_patch_height = height // patch_size + post_patch_width = width // patch_size + latents = latents.reshape( + batch_size, + -1, + post_patch_num_frames, + patch_size_t, + post_patch_height, + patch_size, + post_patch_width, + patch_size, + ) + latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3) + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._unpack_latents + def _unpack_latents( + latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1 + ) -> torch.Tensor: + batch_size, num_channels, video_sequence_length = latents.shape + latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size) + latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3) + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._normalize_latents + def _normalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 + ) -> torch.Tensor: + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = (latents - latents_mean) * scaling_factor / latents_std + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._denormalize_latents + def _denormalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 + ) -> torch.Tensor: + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = latents * latents_std / scaling_factor + latents_mean + return latents + def prepare_latents( self, - batch_size, - num_channels_latents, - height, - width, - num_frames, - dtype, - device, - generator, - latents=None, - ): - height = height // self.vae_spatial_scale_factor - width = width // self.vae_spatial_scale_factor - num_frames = (num_frames - 1) // self.vae_temporal_scale_factor + 1 + image: Optional[torch.Tensor] = None, + batch_size: int = 1, + num_channels_latents: int = 128, + height: int = 512, + width: int = 704, + num_frames: int = 161, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + height = height // self.vae_spatial_compression_ratio + width = width // self.vae_spatial_compression_ratio + num_frames = ( + (num_frames - 1) // self.vae_temporal_compression_ratio + 1 if latents is None else latents.size(2) + ) shape = (batch_size, num_channels_latents, num_frames, height, width) + mask_shape = (batch_size, 1, num_frames, height, width) if latents is not None: - return latents.to(device=device, dtype=dtype) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of 
{batch_size}. Make sure the batch size matches the length of the generators." + conditioning_mask = latents.new_zeros(shape) + conditioning_mask[:, :, 0] = 1.0 + conditioning_mask = self._pack_latents( + conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size ) + return latents.to(device=device, dtype=dtype), conditioning_mask - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - return latents + if isinstance(generator, list): + if len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + init_latents = [ + retrieve_latents(self.vae.encode(image[i].unsqueeze(0).unsqueeze(2)), generator[i]) + for i in range(batch_size) + ] + else: + init_latents = [ + retrieve_latents(self.vae.encode(img.unsqueeze(0).unsqueeze(2)), generator) for img in image + ] + + init_latents = torch.cat(init_latents, dim=0).to(dtype) + init_latents = self._normalize_latents(init_latents, self.vae.latents_mean, self.vae.latents_std) + init_latents = init_latents.repeat(1, 1, num_frames, 1, 1) + conditioning_mask = torch.zeros(mask_shape, device=device, dtype=dtype) + conditioning_mask[:, :, 0] = 1.0 + + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = init_latents * conditioning_mask + noise * (1 - conditioning_mask) - def decode_latents(self, latents: torch.Tensor): - # unscale/denormalize the latents - latents_mean = self.vae.latents_mean.view(1, self.vae.config.latent_channels, 1, 1, 1).to( - latents.device, latents.dtype + init_latents = self._pack_latents( + init_latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size ) - latents_std = self.vae.latents_std.view(1, self.vae.config.latent_channels, 1, 1, 1).to( - latents.device, latents.dtype + conditioning_mask = self._pack_latents( + conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ).squeeze(-1) + latents = self._pack_latents( + latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size ) - latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean - video = self.vae.decode(latents, return_dict=False)[0] - return video + + return latents, conditioning_mask @property def guidance_scale(self): @@ -448,6 +539,7 @@ def interrupt(self): @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, + image: PipelineImageInput = None, prompt: Union[str, List[str]] = None, negative_prompt: Optional[Union[str, List[str]]] = None, height: Optional[int] = None, @@ -481,7 +573,7 @@ def __call__( The height in pixels of the generated image. This is set to 480 by default for the best results. width (`int`, *optional*, defaults to `self.default_width`): The width in pixels of the generated image. This is set to 848 by default for the best results. - num_frames (`int`, defaults to `19`): + num_frames (`int`, defaults to `81 `): The number of video frames to generate num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the @@ -490,7 +582,7 @@ def __call__( Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their `set_timesteps` method. 
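# Hedged sketch (toy tensors, not from the patch) of what the reworked prepare_latents
# above does for image-to-video: the encoded image latent is broadcast across every
# latent frame, but the conditioning mask keeps only the first frame clean while the
# remaining frames start from pure noise before everything is packed into tokens.
import torch

batch, channels, frames, height, width = 1, 128, 5, 16, 22
image_latents = torch.randn(batch, channels, 1, height, width)   # stand-in for retrieve_latents(vae.encode(image))
init_latents = image_latents.repeat(1, 1, frames, 1, 1)

conditioning_mask = torch.zeros(batch, 1, frames, height, width)
conditioning_mask[:, :, 0] = 1.0

noise = torch.randn(batch, channels, frames, height, width)
latents = init_latents * conditioning_mask + noise * (1 - conditioning_mask)

assert torch.equal(latents[:, :, 0], image_latents[:, :, 0])   # frame 0 holds the image latent
# every later frame is init_latents * 0 + noise * 1, i.e. pure noise to be denoised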
If not defined, the default behavior when `num_inference_steps` is passed will be used. Must be in descending order. - guidance_scale (`float`, defaults to `4.5`): + guidance_scale (`float`, defaults to `3 `): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > @@ -529,7 +621,7 @@ def __call__( The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. - max_sequence_length (`int` defaults to `256`): + max_sequence_length (`int` defaults to `128 `): Maximum sequence length to use with the `prompt`. Examples: @@ -545,7 +637,7 @@ def __call__( height = height or self.default_height width = width or self.default_width - latent_frame_rate = frame_rate // self.vae_temporal_scale_factor + latent_frame_rate = frame_rate // self.vae_temporal_compression_ratio # 1. Check inputs. Raise error if not correct self.check_inputs( @@ -595,8 +687,13 @@ def __call__( prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) # 4. Prepare latent variables + if latents is None: + image = self.video_processor.preprocess(image, height=height, width=width) + image = image.to(device=device, dtype=prompt_embeds.dtype) + num_channels_latents = self.transformer.config.in_channels - latents = self.prepare_latents( + latents, conditioning_mask = self.prepare_latents( + image, batch_size * num_videos_per_prompt, num_channels_latents, height, @@ -608,11 +705,17 @@ def __call__( latents, ) + if self.do_classifier_free_guidance: + conditioning_mask = torch.cat([conditioning_mask, conditioning_mask]) + # 5. Prepare timesteps + latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 + latent_height = height // self.vae_spatial_compression_ratio + latent_width = width // self.vae_spatial_compression_ratio + video_sequence_length = latent_num_frames * latent_height * latent_width sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) - image_seq_len = latents.size(1) mu = calculate_shift( - image_seq_len, + video_sequence_length, self.scheduler.config.base_image_seq_len, self.scheduler.config.max_image_seq_len, self.scheduler.config.base_shift, @@ -632,8 +735,8 @@ def __call__( # 6. Prepare micro-conditions rope_interpolation_scale = ( 1 / latent_frame_rate, - self.vae_spatial_scale_factor, - self.vae_spatial_scale_factor, + self.vae_spatial_compression_ratio, + self.vae_spatial_compression_ratio, ) # 7. 
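# Worked example (illustrative numbers only) of the sequence-length bookkeeping above,
# using the prepare_latents defaults from this patch (512x704, 161 frames) and the
# fallback compression ratios (32x spatial, 8x temporal) with patch sizes of 1:
num_frames, height, width = 161, 512, 704
vae_temporal_compression_ratio, vae_spatial_compression_ratio = 8, 32

latent_num_frames = (num_frames - 1) // vae_temporal_compression_ratio + 1   # 21
latent_height = height // vae_spatial_compression_ratio                      # 16
latent_width = width // vae_spatial_compression_ratio                        # 22
video_sequence_length = latent_num_frames * latent_height * latent_width     # 7392 packed tokens

# This is the value handed to calculate_shift() as the sequence length, and it matches
# the number of tokens _pack_latents produces at this resolution.
print(video_sequence_length)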
Denoising loop @@ -646,12 +749,16 @@ def __call__( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype) + timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask) noise_pred = self.transformer( hidden_states=latent_model_input, encoder_hidden_states=prompt_embeds, timestep=timestep, encoder_attention_mask=prompt_attention_mask, + num_frames=latent_num_frames, + height=latent_height, + width=latent_width, rope_interpolation_scale=rope_interpolation_scale, return_dict=False, )[0] @@ -660,11 +767,43 @@ def __call__( if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + timestep, _ = timestep.chunk(2) # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype - latents = self.scheduler.step(noise_pred, t, latents.float(), return_dict=False)[0] + # latents = self.scheduler.step(noise_pred, t, latents.float(), return_dict=False)[0] + # latents = latents.to(dtype=latents_dtype) + + # ============= TODO(aryan): needs a look by YiYi + latents = latents.float() + + noise_pred = self._unpack_latents( + noise_pred, + latent_num_frames, + latent_height, + latent_width, + self.transformer_spatial_patch_size, + self.transformer_temporal_patch_size, + ) + latents = self._unpack_latents( + latents, + latent_num_frames, + latent_height, + latent_width, + self.transformer_spatial_patch_size, + self.transformer_temporal_patch_size, + ) + + noise_pred = noise_pred[:, :, 1:] + noise_latents = latents[:, :, 1:] + pred_latents = self.scheduler.step(noise_pred, t, noise_latents, return_dict=False)[0] + + latents = torch.cat([latents[:, :, :1], pred_latents], dim=2) + latents = self._pack_latents( + latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) latents = latents.to(dtype=latents_dtype) + # ============= if callback_on_step_end is not None: callback_kwargs = {} @@ -685,7 +824,18 @@ def __call__( if output_type == "latent": video = latents else: - video = self.decode_latents(latents) + latents = self._unpack_latents( + latents, + latent_num_frames, + latent_height, + latent_width, + self.transformer_spatial_patch_size, + self.transformer_temporal_patch_size, + ) + latents = self._denormalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + video = self.vae.decode(latents, return_dict=False)[0] video = self.video_processor.postprocess_video(video, output_type=output_type) # Offload all models diff --git a/tests/pipelines/ltx/test_ltx_image2video.py b/tests/pipelines/ltx/test_ltx_image2video.py new file mode 100644 index 000000000000..c5350b662ccf --- /dev/null +++ b/tests/pipelines/ltx/test_ltx_image2video.py @@ -0,0 +1,260 @@ +# Copyright 2024 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
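# Minimal sketch (not from the patch) of the per-token timestep masking used in the
# denoising loop above: because _pack_latents orders tokens frame-major, the first
# latent_height * latent_width tokens belong to the conditioning frame and receive
# timestep 0, so the transformer treats them as already-clean context.
import torch

batch, latent_num_frames, latent_height, latent_width = 2, 5, 16, 22
tokens_per_frame = latent_height * latent_width

conditioning_mask = torch.zeros(batch, latent_num_frames * tokens_per_frame)
conditioning_mask[:, :tokens_per_frame] = 1.0           # packed first-frame tokens

t = torch.tensor(800.0)                                 # current scheduler timestep
timestep = t.expand(batch).unsqueeze(-1) * (1 - conditioning_mask)   # (B, num_tokens)

assert timestep[:, :tokens_per_frame].eq(0.0).all()     # conditioning tokens
assert timestep[:, tokens_per_frame:].eq(800.0).all()   # tokens still being denoised

# The TODO block in the loop above then drops the first latent frame from the scheduler
# step and concatenates it back afterwards, so the conditioning frame is never updated.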
+ +import inspect +import unittest + +import numpy as np +import torch +from transformers import AutoTokenizer, T5EncoderModel +from PIL import Image + +from diffusers import AutoencoderKLLTX, FlowMatchEulerDiscreteScheduler, LTXImageToVideoPipeline, LTXTransformer3DModel +from diffusers.utils.testing_utils import enable_full_determinism, torch_device + +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineTesterMixin, to_np + + +enable_full_determinism() + + +class LTXImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = LTXImageToVideoPipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS.union({"image"}) + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + test_xformers_attention = False + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = LTXTransformer3DModel( + in_channels=8, + out_channels=8, + patch_size=1, + patch_size_t=1, + num_attention_heads=4, + attention_head_dim=8, + cross_attention_dim=32, + num_layers=1, + caption_channels=32, + ) + + torch.manual_seed(0) + vae = AutoencoderKLLTX( + latent_channels=8, + block_out_channels=(8, 8, 8, 8), + spatio_temporal_scaling=(True, True, False, False), + layers_per_block=(1, 1, 1, 1, 1), + patch_size=1, + patch_size_t=1, + encoder_causal=True, + decoder_causal=False, + ) + vae.use_framewise_encoding = False + vae.use_framewise_decoding = False + + torch.manual_seed(0) + scheduler = FlowMatchEulerDiscreteScheduler() + text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + components = { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + image = torch.randn((1, 3, 32, 32), generator=generator, device=device) + + inputs = { + "image": image, + "prompt": "dance monkey", + "negative_prompt": "", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 3.0, + "height": 32, + "width": 32, + # 8 * k + 1 is the recommendation + "num_frames": 9, + "max_sequence_length": 16, + "output_type": "pt", + } + + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + video = pipe(**inputs).frames + generated_video = video[0] + + self.assertEqual(generated_video.shape, (9, 3, 32, 32)) + expected_video = torch.randn(9, 3, 32, 32) + max_diff = np.abs(generated_video - expected_video).max() + self.assertLessEqual(max_diff, 1e10) + + def test_callback_inputs(self): + sig = inspect.signature(self.pipeline_class.__call__) + has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters + has_callback_step_end = "callback_on_step_end" in sig.parameters + + if not 
(has_callback_tensor_inputs and has_callback_step_end): + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + self.assertTrue( + hasattr(pipe, "_callback_tensor_inputs"), + f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs", + ) + + def callback_inputs_subset(pipe, i, t, callback_kwargs): + # iterate over callback args + for tensor_name, tensor_value in callback_kwargs.items(): + # check that we're only passing in allowed tensor inputs + assert tensor_name in pipe._callback_tensor_inputs + + return callback_kwargs + + def callback_inputs_all(pipe, i, t, callback_kwargs): + for tensor_name in pipe._callback_tensor_inputs: + assert tensor_name in callback_kwargs + + # iterate over callback args + for tensor_name, tensor_value in callback_kwargs.items(): + # check that we're only passing in allowed tensor inputs + assert tensor_name in pipe._callback_tensor_inputs + + return callback_kwargs + + inputs = self.get_dummy_inputs(torch_device) + + # Test passing in a subset + inputs["callback_on_step_end"] = callback_inputs_subset + inputs["callback_on_step_end_tensor_inputs"] = ["latents"] + output = pipe(**inputs)[0] + + # Test passing in a everything + inputs["callback_on_step_end"] = callback_inputs_all + inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs + output = pipe(**inputs)[0] + + def callback_inputs_change_tensor(pipe, i, t, callback_kwargs): + is_last = i == (pipe.num_timesteps - 1) + if is_last: + callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"]) + return callback_kwargs + + inputs["callback_on_step_end"] = callback_inputs_change_tensor + inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs + output = pipe(**inputs)[0] + assert output.abs().sum() < 1e10 + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-3) + + def test_attention_slicing_forward_pass( + self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3 + ): + if not self.test_attention_slicing: + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + generator_device = "cpu" + inputs = self.get_dummy_inputs(generator_device) + output_without_slicing = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=1) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing1 = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=2) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing2 = pipe(**inputs)[0] + + if test_max_difference: + max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max() + max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max() + self.assertLess( + max(max_diff1, max_diff2), + expected_max_diff, + "Attention slicing should not affect the inference results", + ) + + def test_vae_tiling(self, expected_diff_max: float = 0.2): + generator_device = "cpu" + components = self.get_dummy_components() + + pipe = self.pipeline_class(**components) + pipe.to("cpu") + 
pipe.set_progress_bar_config(disable=None) + + # Without tiling + inputs = self.get_dummy_inputs(generator_device) + inputs["height"] = inputs["width"] = 128 + output_without_tiling = pipe(**inputs)[0] + + # With tiling + pipe.vae.enable_tiling( + tile_sample_min_height=96, + tile_sample_min_width=96, + tile_sample_stride_height=64, + tile_sample_stride_width=64, + ) + inputs = self.get_dummy_inputs(generator_device) + inputs["height"] = inputs["width"] = 128 + output_with_tiling = pipe(**inputs)[0] + + self.assertLess( + (to_np(output_without_tiling) - to_np(output_with_tiling)).max(), + expected_diff_max, + "VAE tiling should not affect the inference results", + ) From 8e16389c2ef6e4fd30496fe3835daa339baac01a Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 29 Nov 2024 09:55:00 +0100 Subject: [PATCH 30/51] update --- tests/pipelines/ltx/test_ltx_image2video.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/pipelines/ltx/test_ltx_image2video.py b/tests/pipelines/ltx/test_ltx_image2video.py index c5350b662ccf..c5f4009fcf54 100644 --- a/tests/pipelines/ltx/test_ltx_image2video.py +++ b/tests/pipelines/ltx/test_ltx_image2video.py @@ -18,7 +18,6 @@ import numpy as np import torch from transformers import AutoTokenizer, T5EncoderModel -from PIL import Image from diffusers import AutoencoderKLLTX, FlowMatchEulerDiscreteScheduler, LTXImageToVideoPipeline, LTXTransformer3DModel from diffusers.utils.testing_utils import enable_full_determinism, torch_device @@ -97,7 +96,7 @@ def get_dummy_inputs(self, device, seed=0): generator = torch.Generator(device=device).manual_seed(seed) image = torch.randn((1, 3, 32, 32), generator=generator, device=device) - + inputs = { "image": image, "prompt": "dance monkey", From 57c41dfa968b2b470b8cedcc8a9bcd13b107b9ed Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 29 Nov 2024 09:59:31 +0100 Subject: [PATCH 31/51] make fix-copies --- .../utils/dummy_torch_and_transformers_objects.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 4b54638fc652..bec3ebe9601f 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1067,6 +1067,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class LTXImageToVideoPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class LTXPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] From 606e6b25bed90a6d14247f4f2be1f15926391c76 Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 29 Nov 2024 13:43:53 +0100 Subject: [PATCH 32/51] update --- .../pipelines/ltx/pipeline_ltx_image2video.py | 12 ++++++++---- .../transformers/test_models_transformer_ltx.py | 11 +++++++---- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py index 350222629dba..ba7be0432c64 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py +++ 
b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py @@ -44,16 +44,20 @@ Examples: ```py >>> import torch - >>> from diffusers import LTXPipeline - >>> from diffusers.utils import export_to_video + >>> from diffusers import LTXImageToVideoPipeline + >>> from diffusers.utils import export_to_video, load_image - >>> pipe = LTXPipeline.from_pretrained("a-r-r-o-w/LTX-Video-diffusers", torch_dtype=torch.bfloat16) + >>> pipe = LTXImageToVideoPipeline.from_pretrained("a-r-r-o-w/LTX-Video-diffusers", torch_dtype=torch.bfloat16) >>> pipe.to("cuda") - >>> prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage" + >>> image = load_image( + ... "https://huggingface.co/datasets/a-r-r-o-w/tiny-meme-dataset-captioned/resolve/main/images/8.png" + ... ) + >>> prompt = "A young girl stands calmly in the foreground, looking directly at the camera, as a house fire rages in the background. Flames engulf the structure, with smoke billowing into the air. Firefighters in protective gear rush to the scene, a fire truck labeled '38' visible behind them. The girl's neutral expression contrasts sharply with the chaos of the fire, creating a poignant and emotionally charged scene." >>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted" >>> video = pipe( + ... image=image, ... prompt=prompt, ... negative_prompt=negative_prompt, ... width=704, diff --git a/tests/models/transformers/test_models_transformer_ltx.py b/tests/models/transformers/test_models_transformer_ltx.py index 72a77a23f03c..1b0e4f0e4181 100644 --- a/tests/models/transformers/test_models_transformer_ltx.py +++ b/tests/models/transformers/test_models_transformer_ltx.py @@ -26,7 +26,7 @@ enable_full_determinism() -class MochiTransformerTests(ModelTesterMixin, unittest.TestCase): +class LTXTransformerTests(ModelTesterMixin, unittest.TestCase): model_class = LTXTransformer3DModel main_input_name = "hidden_states" uses_custom_attn_processor = True @@ -41,7 +41,7 @@ def dummy_input(self): embedding_dim = 16 sequence_length = 16 - hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device) + hidden_states = torch.randn((batch_size, num_frames * height * width, num_channels)).to(torch_device) encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device) encoder_attention_mask = torch.ones((batch_size, sequence_length)).bool().to(torch_device) timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device) @@ -51,15 +51,18 @@ def dummy_input(self): "encoder_hidden_states": encoder_hidden_states, "timestep": timestep, "encoder_attention_mask": encoder_attention_mask, + "num_frames": num_frames, + "height": height, + "width": width, } @property def input_shape(self): - return (4, 2, 16, 16) + return (512, 4) @property def output_shape(self): - return (4, 2, 16, 16) + return (512, 4) def prepare_init_args_and_inputs_for_common(self): init_dict = { From f4b5341087f65c229d81b81aa395682f419e4108 Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 29 Nov 2024 14:03:17 +0100 Subject: [PATCH 33/51] update --- src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py | 5 ++--- 1 file changed, 2 
insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py index ba7be0432c64..785bceaef484 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py @@ -511,9 +511,6 @@ def prepare_latents( noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) latents = init_latents * conditioning_mask + noise * (1 - conditioning_mask) - init_latents = self._pack_latents( - init_latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size - ) conditioning_mask = self._pack_latents( conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size ).squeeze(-1) @@ -570,6 +567,8 @@ def __call__( Function invoked when calling the pipeline for generation. Args: + image (`PipelineImageInput`): + The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`. prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. From d556b7feb526ff5ee0b2eb81fdf5f470a0cd8e53 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sat, 30 Nov 2024 06:33:53 +0100 Subject: [PATCH 34/51] rope scale fix --- src/diffusers/pipelines/ltx/pipeline_ltx.py | 2 +- src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx.py b/src/diffusers/pipelines/ltx/pipeline_ltx.py index ad62466ceb98..c4d1f2d565fc 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx.py @@ -583,7 +583,7 @@ def __call__( height = height or self.default_height width = width or self.default_width - latent_frame_rate = frame_rate // self.vae_temporal_compression_ratio + latent_frame_rate = frame_rate / self.vae_temporal_compression_ratio # 1. Check inputs. Raise error if not correct self.check_inputs( diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py index 785bceaef484..e225ce36dc8a 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py @@ -640,7 +640,7 @@ def __call__( height = height or self.default_height width = width or self.default_width - latent_frame_rate = frame_rate // self.vae_temporal_compression_ratio + latent_frame_rate = frame_rate / self.vae_temporal_compression_ratio # 1. Check inputs. 
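# Why the `//` -> `/` change above matters (illustrative numbers, assuming the
# pipeline's default 25 fps input and the 8x temporal VAE compression fallback):
frame_rate, vae_temporal_compression_ratio = 25, 8

print(frame_rate // vae_temporal_compression_ratio)        # 3     (old: truncated latent frame rate)
print(frame_rate / vae_temporal_compression_ratio)         # 3.125 (patched)
print(1 / (frame_rate / vae_temporal_compression_ratio))   # 0.32  -> rope_interpolation_scale[0]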
Raise error if not correct self.check_inputs( From 42ca5e6f0946497b8f454bc3bfb1647a802674d6 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sat, 30 Nov 2024 14:56:21 +0100 Subject: [PATCH 35/51] debug layerwise code --- .../models/transformers/transformer_ltx.py | 38 +++++++++++++++++-- src/diffusers/pipelines/ltx/pipeline_ltx.py | 5 ++- 2 files changed, 37 insertions(+), 6 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py index 5472e4814d71..5db60b8d83c3 100644 --- a/src/diffusers/models/transformers/transformer_ltx.py +++ b/src/diffusers/models/transformers/transformer_ltx.py @@ -73,10 +73,14 @@ def __call__( query = attn.norm_q(query) key = attn.norm_k(key) + # torch.save(query, "query.pt") + # torch.save(key, "key.pt") if image_rotary_emb is not None and use_rotary_emb: query = apply_rotary_emb(query, image_rotary_emb) key = apply_rotary_emb(key, image_rotary_emb) + # torch.save(query, "query_rope.pt") + # torch.save(key, "key_rope.pt") query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2) key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2) @@ -85,11 +89,13 @@ def __call__( hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) + # torch.save(hidden_states, "sdpa.pt") hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) hidden_states = hidden_states.to(query.dtype) hidden_states = attn.to_out[0](hidden_states) hidden_states = attn.to_out[1](hidden_states) + # torch.save(hidden_states, "to_out.pt") return hidden_states @@ -243,19 +249,22 @@ def forward( ) -> torch.Tensor: batch_size = hidden_states.size(0) norm_hidden_states = self.norm1(hidden_states) + # torch.save(norm_hidden_states, "block_norm1.pt") num_ada_params = self.scale_shift_table.shape[0] ada_values = self.scale_shift_table[None, None] + temb.reshape(batch_size, temb.size(1), num_ada_params, -1) - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2) norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + # torch.save(norm_hidden_states, "block_scale_shift1.pt") attn_hidden_states = self.attn1( hidden_states=norm_hidden_states, encoder_hidden_states=None, image_rotary_emb=image_rotary_emb, ) + # torch.save(attn_hidden_states, "block_attn1.pt") hidden_states = hidden_states + attn_hidden_states * gate_msa + # torch.save(hidden_states, "block_attn1_result.pt") attn_hidden_states = self.attn2( hidden_states, @@ -264,10 +273,13 @@ def forward( attention_mask=encoder_attention_mask, ) hidden_states = hidden_states + attn_hidden_states + # torch.save(hidden_states, "block_attn2.pt") norm_hidden_states = self.norm2(hidden_states) * (1 + scale_mlp) + shift_mlp + # torch.save(norm_hidden_states, "block_scale_shift2.pt") ff_output = self.ff(norm_hidden_states) hidden_states = hidden_states + ff_output * gate_mlp + # torch.save(hidden_states, "block_out.pt") return hidden_states @@ -383,14 +395,20 @@ def forward( return_dict: bool = True, ) -> torch.Tensor: image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale) + # torch.save(image_rotary_emb, "rope.pt") # convert encoder_attention_mask to a bias the same way we do for attention_mask if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + # 
torch.save(hidden_states, "input_hidden_states.pt") + # torch.save(encoder_hidden_states, "input_encoder_hidden_states.pt") + # torch.save(encoder_attention_mask, "input_encoder_attention_mask.pt") + batch_size = hidden_states.size(0) hidden_states = self.proj_in(hidden_states) + # torch.save(hidden_states, "patchify_proj.pt") temb, embedded_timestep = self.time_embed( timestep.flatten(), @@ -400,11 +418,14 @@ def forward( temb = temb.view(batch_size, -1, temb.size(-1)) embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.size(-1)) + # torch.save(temb, "timestep.pt") + # torch.save(embedded_timestep, "embedded_timestep.pt") encoder_hidden_states = self.caption_projection(encoder_hidden_states) encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1)) + # torch.save(encoder_hidden_states, "caption_projection.pt") - for block in self.transformer_blocks: + for i, block in enumerate(self.transformer_blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): @@ -435,12 +456,19 @@ def custom_forward(*inputs): encoder_attention_mask=encoder_attention_mask, ) + # print(f"block_{i}:", hidden_states.flatten().mean(), hidden_states.flatten().std()) + # torch.save(hidden_states, f"block_{i}.pt") + scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None] shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] hidden_states = self.norm_out(hidden_states) + # torch.save(hidden_states, "norm_out.pt") hidden_states = hidden_states * (1 + scale) + shift + # torch.save(hidden_states, "scale_shift.pt") output = self.proj_out(hidden_states) + # torch.save(output, "proj_out.pt") + # exit() if not return_dict: return (output,) @@ -450,8 +478,10 @@ def custom_forward(*inputs): def apply_rotary_emb(x, freqs): cos, sin = freqs - x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1) # [B, S, H, D//2] + x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1) # [B, S, D//2, 2] x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(2) - out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + # breakpoint() + # out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + out = x * cos + x_rotated * sin return out diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx.py b/src/diffusers/pipelines/ltx/pipeline_ltx.py index c4d1f2d565fc..d509b164742f 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx.py @@ -699,7 +699,7 @@ def __call__( rope_interpolation_scale=rope_interpolation_scale, return_dict=False, )[0] - noise_pred = noise_pred.float() + # noise_pred = noise_pred.float() if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) @@ -707,7 +707,8 @@ def __call__( # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype - latents = self.scheduler.step(noise_pred, t, latents.float(), return_dict=False)[0] + # latents = self.scheduler.step(noise_pred, t, latents.float(), return_dict=False)[0] + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] latents = latents.to(dtype=latents_dtype) if callback_on_step_end is not None: From 25023994529c06e3bcf5d2e97cea73ffa3f8f36c Mon Sep 17 00:00:00 2001 From: Aryan Date: Sat, 30 Nov 2024 17:17:29 +0100 Subject: [PATCH 36/51] remove debug --- .../models/transformers/transformer_ltx.py | 40 ++----------------- src/diffusers/pipelines/ltx/pipeline_ltx.py | 5 +-- 2 files changed, 
5 insertions(+), 40 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py index 5db60b8d83c3..81ffde29d65c 100644 --- a/src/diffusers/models/transformers/transformer_ltx.py +++ b/src/diffusers/models/transformers/transformer_ltx.py @@ -73,14 +73,10 @@ def __call__( query = attn.norm_q(query) key = attn.norm_k(key) - # torch.save(query, "query.pt") - # torch.save(key, "key.pt") if image_rotary_emb is not None and use_rotary_emb: query = apply_rotary_emb(query, image_rotary_emb) key = apply_rotary_emb(key, image_rotary_emb) - # torch.save(query, "query_rope.pt") - # torch.save(key, "key_rope.pt") query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2) key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2) @@ -89,14 +85,11 @@ def __call__( hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) - # torch.save(hidden_states, "sdpa.pt") hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) hidden_states = hidden_states.to(query.dtype) hidden_states = attn.to_out[0](hidden_states) hidden_states = attn.to_out[1](hidden_states) - # torch.save(hidden_states, "to_out.pt") - return hidden_states @@ -249,22 +242,18 @@ def forward( ) -> torch.Tensor: batch_size = hidden_states.size(0) norm_hidden_states = self.norm1(hidden_states) - # torch.save(norm_hidden_states, "block_norm1.pt") num_ada_params = self.scale_shift_table.shape[0] ada_values = self.scale_shift_table[None, None] + temb.reshape(batch_size, temb.size(1), num_ada_params, -1) shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2) norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa - # torch.save(norm_hidden_states, "block_scale_shift1.pt") attn_hidden_states = self.attn1( hidden_states=norm_hidden_states, encoder_hidden_states=None, image_rotary_emb=image_rotary_emb, ) - # torch.save(attn_hidden_states, "block_attn1.pt") hidden_states = hidden_states + attn_hidden_states * gate_msa - # torch.save(hidden_states, "block_attn1_result.pt") attn_hidden_states = self.attn2( hidden_states, @@ -273,13 +262,10 @@ def forward( attention_mask=encoder_attention_mask, ) hidden_states = hidden_states + attn_hidden_states - # torch.save(hidden_states, "block_attn2.pt") norm_hidden_states = self.norm2(hidden_states) * (1 + scale_mlp) + shift_mlp - # torch.save(norm_hidden_states, "block_scale_shift2.pt") ff_output = self.ff(norm_hidden_states) hidden_states = hidden_states + ff_output * gate_mlp - # torch.save(hidden_states, "block_out.pt") return hidden_states @@ -395,20 +381,14 @@ def forward( return_dict: bool = True, ) -> torch.Tensor: image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale) - # torch.save(image_rotary_emb, "rope.pt") # convert encoder_attention_mask to a bias the same way we do for attention_mask if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 encoder_attention_mask = encoder_attention_mask.unsqueeze(1) - # torch.save(hidden_states, "input_hidden_states.pt") - # torch.save(encoder_hidden_states, "input_encoder_hidden_states.pt") - # torch.save(encoder_attention_mask, "input_encoder_attention_mask.pt") - batch_size = hidden_states.size(0) hidden_states = self.proj_in(hidden_states) - # torch.save(hidden_states, "patchify_proj.pt") temb, embedded_timestep = 
self.time_embed( timestep.flatten(), @@ -418,14 +398,11 @@ def forward( temb = temb.view(batch_size, -1, temb.size(-1)) embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.size(-1)) - # torch.save(temb, "timestep.pt") - # torch.save(embedded_timestep, "embedded_timestep.pt") encoder_hidden_states = self.caption_projection(encoder_hidden_states) encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1)) - # torch.save(encoder_hidden_states, "caption_projection.pt") - for i, block in enumerate(self.transformer_blocks): + for block in self.transformer_blocks: if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): @@ -456,19 +433,12 @@ def custom_forward(*inputs): encoder_attention_mask=encoder_attention_mask, ) - # print(f"block_{i}:", hidden_states.flatten().mean(), hidden_states.flatten().std()) - # torch.save(hidden_states, f"block_{i}.pt") - scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None] shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] hidden_states = self.norm_out(hidden_states) - # torch.save(hidden_states, "norm_out.pt") hidden_states = hidden_states * (1 + scale) + shift - # torch.save(hidden_states, "scale_shift.pt") output = self.proj_out(hidden_states) - # torch.save(output, "proj_out.pt") - # exit() if not return_dict: return (output,) @@ -477,11 +447,7 @@ def custom_forward(*inputs): def apply_rotary_emb(x, freqs): cos, sin = freqs - - x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1) # [B, S, D//2, 2] + x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1) # [B, S, D // 2, 2] x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(2) - - # breakpoint() - # out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) - out = x * cos + x_rotated * sin + out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) return out diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx.py b/src/diffusers/pipelines/ltx/pipeline_ltx.py index d509b164742f..c4d1f2d565fc 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx.py @@ -699,7 +699,7 @@ def __call__( rope_interpolation_scale=rope_interpolation_scale, return_dict=False, )[0] - # noise_pred = noise_pred.float() + noise_pred = noise_pred.float() if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) @@ -707,8 +707,7 @@ def __call__( # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype - # latents = self.scheduler.step(noise_pred, t, latents.float(), return_dict=False)[0] - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + latents = self.scheduler.step(noise_pred, t, latents.float(), return_dict=False)[0] latents = latents.to(dtype=latents_dtype) if callback_on_step_end is not None: From 8c9d3d0f0e6798ce1416a6cb60545e12d0fd5d17 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 1 Dec 2024 17:18:54 +0530 Subject: [PATCH 37/51] Apply suggestions from code review Co-authored-by: YiYi Xu --- src/diffusers/pipelines/ltx/pipeline_ltx.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx.py b/src/diffusers/pipelines/ltx/pipeline_ltx.py index c4d1f2d565fc..0777cc21e66f 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx.py @@ -640,7 +640,7 @@ def __call__( height, width, num_frames, - prompt_embeds.dtype, + 
torch.float32, device, generator, latents, @@ -684,9 +684,10 @@ def __call__( continue latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = latent_model_input.to(prompt_embeds.type) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype) + timestep = t.expand(latent_model_input.shape[0]) noise_pred = self.transformer( hidden_states=latent_model_input, @@ -707,7 +708,7 @@ def __call__( # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype - latents = self.scheduler.step(noise_pred, t, latents.float(), return_dict=False)[0] + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] latents = latents.to(dtype=latents_dtype) if callback_on_step_end is not None: @@ -740,6 +741,7 @@ def __call__( latents = self._denormalize_latents( latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor ) + latents = latents.to(prompt_embeds.dtype) video = self.vae.decode(latents, return_dict=False)[0] video = self.video_processor.postprocess_video(video, output_type=output_type) From eb962d189ae44a8dc8d0769fae3f1702a54665c2 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 1 Dec 2024 13:05:22 +0100 Subject: [PATCH 38/51] propagate precision changes to i2v pipeline --- src/diffusers/pipelines/ltx/pipeline_ltx.py | 4 +--- .../pipelines/ltx/pipeline_ltx_image2video.py | 13 ++++--------- 2 files changed, 5 insertions(+), 12 deletions(-) diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx.py b/src/diffusers/pipelines/ltx/pipeline_ltx.py index 0777cc21e66f..96c66d4f48e2 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx.py @@ -684,7 +684,7 @@ def __call__( continue latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents - latent_model_input = latent_model_input.to(prompt_embeds.type) + latent_model_input = latent_model_input.to(prompt_embeds.dtype) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latent_model_input.shape[0]) @@ -707,9 +707,7 @@ def __call__( noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 - latents_dtype = latents.dtype latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] - latents = latents.to(dtype=latents_dtype) if callback_on_step_end is not None: callback_kwargs = {} diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py index e225ce36dc8a..2d4567078b36 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py @@ -702,7 +702,7 @@ def __call__( height, width, num_frames, - prompt_embeds.dtype, + torch.float32, device, generator, latents, @@ -749,9 +749,10 @@ def __call__( continue latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = latent_model_input.to(prompt_embeds.dtype) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype) + timestep = t.expand(latent_model_input.shape[0]) timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask) noise_pred = self.transformer( @@ -773,13 +774,7 @@ def __call__( timestep, _ = 
timestep.chunk(2) # compute the previous noisy sample x_t -> x_t-1 - latents_dtype = latents.dtype - # latents = self.scheduler.step(noise_pred, t, latents.float(), return_dict=False)[0] - # latents = latents.to(dtype=latents_dtype) - # ============= TODO(aryan): needs a look by YiYi - latents = latents.float() - noise_pred = self._unpack_latents( noise_pred, latent_num_frames, @@ -805,7 +800,6 @@ def __call__( latents = self._pack_latents( latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size ) - latents = latents.to(dtype=latents_dtype) # ============= if callback_on_step_end is not None: @@ -838,6 +832,7 @@ def __call__( latents = self._denormalize_latents( latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor ) + latents = latents.to(prompt_embeds.dtype) video = self.vae.decode(latents, return_dict=False)[0] video = self.video_processor.postprocess_video(video, output_type=output_type) From 5196b2a9b1e0d3818f3387adc1a156b3ef995110 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 1 Dec 2024 16:30:23 +0100 Subject: [PATCH 39/51] remove downcast --- src/diffusers/models/transformers/transformer_ltx.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py index 81ffde29d65c..34f3faf03a83 100644 --- a/src/diffusers/models/transformers/transformer_ltx.py +++ b/src/diffusers/models/transformers/transformer_ltx.py @@ -161,9 +161,6 @@ def forward( cos_freqs = torch.cat([cos_padding, cos_freqs], dim=-1) sin_freqs = torch.cat([sin_padding, sin_freqs], dim=-1) - cos_freqs = cos_freqs.to(dtype=hidden_states.dtype) - sin_freqs = sin_freqs.to(dtype=hidden_states.dtype) - return cos_freqs, sin_freqs From d76232df7ad685c0b7dd366c99fa3dce1080cd63 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 2 Dec 2024 21:09:10 +0100 Subject: [PATCH 40/51] address review comments --- scripts/convert_ltx_to_diffusers.py | 1 + .../models/autoencoders/autoencoder_kl_ltx.py | 22 ++++---- src/diffusers/models/normalization.py | 53 +------------------ .../models/transformers/transformer_ltx.py | 6 +-- .../pipelines/ltx/pipeline_ltx_image2video.py | 2 - .../scheduling_flow_match_euler_discrete.py | 15 ++++++ 6 files changed, 30 insertions(+), 69 deletions(-) diff --git a/scripts/convert_ltx_to_diffusers.py b/scripts/convert_ltx_to_diffusers.py index 75d13b7b4755..053f33253751 100644 --- a/scripts/convert_ltx_to_diffusers.py +++ b/scripts/convert_ltx_to_diffusers.py @@ -49,6 +49,7 @@ def remove_keys_(key: str, state_dict: Dict[str, Any]): # common "conv_shortcut": "conv_shortcut.conv", "res_blocks": "resnets", + "norm3.norm": "norm3", "per_channel_statistics.mean-of-means": "latents_mean", "per_channel_statistics.std-of-means": "latents_std", } diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py index ab06cbbb2f82..1afa19091b74 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py @@ -23,7 +23,7 @@ from ..activations import get_activation from ..modeling_outputs import AutoencoderKLOutput from ..modeling_utils import ModelMixin -from ..normalization import LayerNormNd, RMSNormNd +from ..normalization import RMSNorm from .vae import DecoderOutput, DiagonalGaussianDistribution @@ -117,12 +117,12 @@ def __init__( self.nonlinearity = get_activation(non_linearity) - self.norm1 = RMSNormNd(dim=in_channels, eps=1e-8, 
elementwise_affine=elementwise_affine, channel_dim=1) + self.norm1 = RMSNorm(in_channels, eps=1e-8, elementwise_affine=elementwise_affine) self.conv1 = LTXCausalConv3d( in_channels=in_channels, out_channels=out_channels, kernel_size=3, is_causal=is_causal ) - self.norm2 = RMSNormNd(dim=out_channels, eps=1e-8, elementwise_affine=elementwise_affine, channel_dim=1) + self.norm2 = RMSNorm(out_channels, eps=1e-8, elementwise_affine=elementwise_affine) self.dropout = nn.Dropout(dropout) self.conv2 = LTXCausalConv3d( in_channels=out_channels, out_channels=out_channels, kernel_size=3, is_causal=is_causal @@ -131,7 +131,7 @@ def __init__( self.norm3 = None self.conv_shortcut = None if in_channels != out_channels: - self.norm3 = LayerNormNd(in_channels, eps=eps, elementwise_affine=True, bias=True, channel_dim=1) + self.norm3 = nn.LayerNorm(in_channels, eps=eps, elementwise_affine=True, bias=True) self.conv_shortcut = LTXCausalConv3d( in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, is_causal=is_causal ) @@ -139,17 +139,17 @@ def __init__( def forward(self, inputs: torch.Tensor) -> torch.Tensor: hidden_states = inputs - hidden_states = self.norm1(hidden_states) + hidden_states = self.norm1(hidden_states.movedim(1, -1)).movedim(-1, 1) hidden_states = self.nonlinearity(hidden_states) hidden_states = self.conv1(hidden_states) - hidden_states = self.norm2(hidden_states) + hidden_states = self.norm2(hidden_states.movedim(1, -1)).movedim(-1, 1) hidden_states = self.nonlinearity(hidden_states) hidden_states = self.dropout(hidden_states) hidden_states = self.conv2(hidden_states) if self.norm3 is not None: - inputs = self.norm3(inputs) + inputs = self.norm3(inputs.movedim(1, -1)).movedim(-1, 1) if self.conv_shortcut is not None: inputs = self.conv_shortcut(inputs) @@ -545,7 +545,7 @@ def __init__( ) # out - self.norm_out = RMSNormNd(dim=out_channels, eps=1e-8, elementwise_affine=False, channel_dim=1) + self.norm_out = RMSNorm(out_channels, eps=1e-8, elementwise_affine=False) self.conv_act = nn.SiLU() self.conv_out = LTXCausalConv3d( in_channels=output_channel, out_channels=out_channels + 1, kernel_size=3, stride=1, is_causal=is_causal @@ -589,7 +589,7 @@ def create_forward(*inputs): hidden_states = self.mid_block(hidden_states) - hidden_states = self.norm_out(hidden_states) + hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1) hidden_states = self.conv_act(hidden_states) hidden_states = self.conv_out(hidden_states) @@ -675,7 +675,7 @@ def __init__( self.up_blocks.append(up_block) # out - self.norm_out = RMSNormNd(dim=out_channels, eps=1e-8, elementwise_affine=False, channel_dim=1) + self.norm_out = RMSNorm(out_channels, eps=1e-8, elementwise_affine=False) self.conv_act = nn.SiLU() self.conv_out = LTXCausalConv3d( in_channels=output_channel, out_channels=self.out_channels, kernel_size=3, stride=1, is_causal=is_causal @@ -704,7 +704,7 @@ def create_forward(*inputs): for up_block in self.up_blocks: hidden_states = up_block(hidden_states) - hidden_states = self.norm_out(hidden_states) + hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1) hidden_states = self.conv_act(hidden_states) hidden_states = self.conv_out(hidden_states) diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py index f13cea59da06..00dd7a144981 100644 --- a/src/diffusers/models/normalization.py +++ b/src/diffusers/models/normalization.py @@ -14,7 +14,7 @@ # limitations under the License. 
import numbers -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, Optional, Tuple import torch import torch.nn as nn @@ -567,54 +567,3 @@ def __init__(self, p: int = 2, dim: int = -1, eps: float = 1e-12): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return F.normalize(hidden_states, p=self.p, dim=self.dim, eps=self.eps) - - -class LayerNormNd(nn.Module): - def __init__( - self, - normalized_shape: Union[int, List[int], Tuple[int], torch.Size], - eps: float = 1e-5, - elementwise_affine: bool = True, - bias: bool = True, - device=None, - dtype=None, - channel_dim: int = -1, - ) -> None: - super().__init__() - - self.norm = nn.LayerNorm(normalized_shape, eps, elementwise_affine, bias, device, dtype) - self.channel_dim = channel_dim - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - if self.channel_dim != -1: - hidden_states = hidden_states.movedim(self.channel_dim, -1) - hidden_states = self.norm(hidden_states) - hidden_states = hidden_states.movedim(-1, self.channel_dim) - else: - hidden_states = self.norm(hidden_states) - - return hidden_states - - -class RMSNormNd(nn.Module): - def __init__( - self, - dim: int, - eps: float, - elementwise_affine: bool = True, - channel_dim: int = -1, - ) -> None: - super().__init__() - - self.norm = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine) - self.channel_dim = channel_dim - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - if self.channel_dim != -1: - hidden_states = hidden_states.movedim(self.channel_dim, -1) - hidden_states = self.norm(hidden_states) - hidden_states = hidden_states.movedim(-1, self.channel_dim) - else: - hidden_states = self.norm(hidden_states) - - return hidden_states diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py index 34f3faf03a83..36d3c1d3a3e4 100644 --- a/src/diffusers/models/transformers/transformer_ltx.py +++ b/src/diffusers/models/transformers/transformer_ltx.py @@ -62,10 +62,8 @@ def __call__( attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) - use_rotary_emb = False if encoder_hidden_states is None: encoder_hidden_states = hidden_states - use_rotary_emb = True query = attn.to_q(hidden_states) key = attn.to_k(encoder_hidden_states) @@ -74,7 +72,7 @@ def __call__( query = attn.norm_q(query) key = attn.norm_k(key) - if image_rotary_emb is not None and use_rotary_emb: + if image_rotary_emb is not None: query = apply_rotary_emb(query, image_rotary_emb) key = apply_rotary_emb(key, image_rotary_emb) @@ -255,7 +253,7 @@ def forward( attn_hidden_states = self.attn2( hidden_states, encoder_hidden_states=encoder_hidden_states, - image_rotary_emb=image_rotary_emb, + image_rotary_emb=None, attention_mask=encoder_attention_mask, ) hidden_states = hidden_states + attn_hidden_states diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py index 2d4567078b36..0955090e4fdf 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py @@ -774,7 +774,6 @@ def __call__( timestep, _ = timestep.chunk(2) # compute the previous noisy sample x_t -> x_t-1 - # ============= TODO(aryan): needs a look by YiYi noise_pred = self._unpack_latents( noise_pred, latent_num_frames, @@ -800,7 +799,6 @@ def __call__( latents 
= self._pack_latents( latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size ) - # ============= if callback_on_step_end is not None: callback_kwargs = {} diff --git a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py index 7b71391b70c4..31fbe087b1e9 100644 --- a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py @@ -183,6 +183,21 @@ def time_shift(self, mu: float, sigma: float, t: torch.Tensor): return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) def stretch_shift_to_terminal(self, t: torch.Tensor) -> torch.Tensor: + r""" + Stretches and shifts the timestep schedule to ensure it terminates at the configured `shift_terminal` config + value. + + Reference: + https://github.com/Lightricks/LTX-Video/blob/a01a171f8fe3d99dce2728d60a73fecf4d4238ae/ltx_video/schedulers/rf.py#L51 + + Args: + t (`torch.Tensor`): + A tensor of timesteps to be stretched and shifted. + + Returns: + `torch.Tensor`: + A tensor of adjusted timesteps such that the final value equals `self.config.shift_terminal`. + """ one_minus_z = 1 - t scale_factor = one_minus_z[-1] / (1 - self.config.shift_terminal) stretched_t = 1 - (one_minus_z / scale_factor) From f18cf1a55d39677ddd11ef3c024f99c5f1b352ac Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 2 Dec 2024 21:13:27 +0100 Subject: [PATCH 41/51] fix comment --- src/diffusers/models/transformers/transformer_ltx.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py index 36d3c1d3a3e4..1aca1a246996 100644 --- a/src/diffusers/models/transformers/transformer_ltx.py +++ b/src/diffusers/models/transformers/transformer_ltx.py @@ -442,7 +442,7 @@ def custom_forward(*inputs): def apply_rotary_emb(x, freqs): cos, sin = freqs - x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1) # [B, S, D // 2, 2] + x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1) # [B, S, H, D // 2] x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(2) out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) return out From 4e8b2a4c18add6d75e33ea004cbc4bc2b67f7891 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 4 Dec 2024 16:20:47 +0100 Subject: [PATCH 42/51] address review comments --- src/diffusers/pipelines/ltx/pipeline_ltx.py | 32 ++++++++++--------- .../pipelines/ltx/pipeline_ltx_image2video.py | 29 ++++++++++------- 2 files changed, 35 insertions(+), 26 deletions(-) diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx.py b/src/diffusers/pipelines/ltx/pipeline_ltx.py index 96c66d4f48e2..ac5c1862e15f 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx.py @@ -197,10 +197,6 @@ def __init__( self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 128 ) - self.default_height = 512 - self.default_width = 704 - self.default_frames = 121 - # Copied from diffusers.pipelines.mochi.pipeline_mochi.MochiPipeline._get_t5_prompt_embeds with 256->128 def _get_t5_prompt_embeds( self, @@ -389,6 +385,10 @@ def check_inputs( @staticmethod def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor: + # Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p]. 
+ # The patch dimensions are then permuted and collapsed into the channel dimension of shape: + # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor). + # dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features batch_size, num_channels, num_frames, height, width = latents.shape post_patch_num_frames = num_frames // patch_size_t post_patch_height = height // patch_size @@ -410,7 +410,10 @@ def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int def _unpack_latents( latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1 ) -> torch.Tensor: - batch_size, num_channels, video_sequence_length = latents.shape + # Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature dimensions) + # are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. This is the inverse operation of + # what happens in the `_pack_latents` method. + batch_size = latents.size(0) latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size) latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3) return latents @@ -419,6 +422,7 @@ def _unpack_latents( def _normalize_latents( latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 ) -> torch.Tensor: + # Normalize latents across the channel dimension [B, C, F, H, W] latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) latents = (latents - latents_mean) * scaling_factor / latents_std @@ -428,6 +432,7 @@ def _normalize_latents( def _denormalize_latents( latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 ) -> torch.Tensor: + # Denormalize latents across the channel dimension [B, C, F, H, W] latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) latents = latents * latents_std / scaling_factor + latents_mean @@ -488,9 +493,9 @@ def __call__( self, prompt: Union[str, List[str]] = None, negative_prompt: Optional[Union[str, List[str]]] = None, - height: Optional[int] = None, - width: Optional[int] = None, - num_frames: int = 81, + height: int = 512, + width: int = 704, + num_frames: int = 161, frame_rate: int = 25, num_inference_steps: int = 50, timesteps: List[int] = None, @@ -515,11 +520,11 @@ def __call__( prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. - height (`int`, *optional*, defaults to `self.default_height`): + height (`int`, defaults to `512`): The height in pixels of the generated image. This is set to 480 by default for the best results. - width (`int`, *optional*, defaults to `self.default_width`): + width (`int`, defaults to `704`): The width in pixels of the generated image. This is set to 848 by default for the best results. - num_frames (`int`, defaults to `81 `): + num_frames (`int`, defaults to `161`): The number of video frames to generate num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. 
More denoising steps usually lead to a higher quality image at the @@ -581,10 +586,6 @@ def __call__( if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs - height = height or self.default_height - width = width or self.default_width - latent_frame_rate = frame_rate / self.vae_temporal_compression_ratio - # 1. Check inputs. Raise error if not correct self.check_inputs( prompt=prompt, @@ -671,6 +672,7 @@ def __call__( self._num_timesteps = len(timesteps) # 6. Prepare micro-conditions + latent_frame_rate = frame_rate / self.vae_temporal_compression_ratio rope_interpolation_scale = ( 1 / latent_frame_rate, self.vae_spatial_compression_ratio, diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py index 0955090e4fdf..943a934989a8 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py @@ -353,6 +353,7 @@ def encode_prompt( return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline.check_inputs def check_inputs( self, prompt, @@ -409,6 +410,10 @@ def check_inputs( @staticmethod # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._pack_latents def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor: + # Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p]. + # The patch dimensions are then permuted and collapsed into the channel dimension of shape: + # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor). + # dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features batch_size, num_channels, num_frames, height, width = latents.shape post_patch_num_frames = num_frames // patch_size_t post_patch_height = height // patch_size @@ -431,7 +436,10 @@ def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int def _unpack_latents( latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1 ) -> torch.Tensor: - batch_size, num_channels, video_sequence_length = latents.shape + # Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature dimensions) + # are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. This is the inverse operation of + # what happens in the `_pack_latents` method. 
+ batch_size = latents.size(0) latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size) latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3) return latents @@ -441,6 +449,7 @@ def _unpack_latents( def _normalize_latents( latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 ) -> torch.Tensor: + # Normalize latents across the channel dimension [B, C, F, H, W] latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) latents = (latents - latents_mean) * scaling_factor / latents_std @@ -451,6 +460,7 @@ def _normalize_latents( def _denormalize_latents( latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 ) -> torch.Tensor: + # Denormalize latents across the channel dimension [B, C, F, H, W] latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) latents = latents * latents_std / scaling_factor + latents_mean @@ -543,9 +553,9 @@ def __call__( image: PipelineImageInput = None, prompt: Union[str, List[str]] = None, negative_prompt: Optional[Union[str, List[str]]] = None, - height: Optional[int] = None, - width: Optional[int] = None, - num_frames: int = 81, + height: int = 512, + width: int = 704, + num_frames: int = 161, frame_rate: int = 25, num_inference_steps: int = 50, timesteps: List[int] = None, @@ -572,11 +582,11 @@ def __call__( prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. - height (`int`, *optional*, defaults to `self.default_height`): + height (`int`, defaults to `512`): The height in pixels of the generated image. This is set to 480 by default for the best results. - width (`int`, *optional*, defaults to `self.default_width`): + width (`int`, defaults to `704`): The width in pixels of the generated image. This is set to 848 by default for the best results. - num_frames (`int`, defaults to `81 `): + num_frames (`int`, defaults to `161`): The number of video frames to generate num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the @@ -638,10 +648,6 @@ def __call__( if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs - height = height or self.default_height - width = width or self.default_width - latent_frame_rate = frame_rate / self.vae_temporal_compression_ratio - # 1. Check inputs. Raise error if not correct self.check_inputs( prompt=prompt, @@ -736,6 +742,7 @@ def __call__( self._num_timesteps = len(timesteps) # 6. 
Prepare micro-conditions + latent_frame_rate = frame_rate / self.vae_temporal_compression_ratio rope_interpolation_scale = ( 1 / latent_frame_rate, self.vae_spatial_compression_ratio, From 9ba6a0636e6245f3bf259879b7bf5716d853b521 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 10 Dec 2024 14:12:33 +0530 Subject: [PATCH 43/51] [Single File] LTX support for loading original weights (#10135) * from original file mixin for ltx * undo config mapping fn changes * update --- src/diffusers/loaders/single_file_model.py | 11 ++ src/diffusers/loaders/single_file_utils.py | 105 ++++++++++++++++++ .../models/autoencoders/autoencoder_kl_ltx.py | 3 +- .../models/transformers/transformer_ltx.py | 3 +- 4 files changed, 120 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py index be3139057078..11ade76f7f70 100644 --- a/src/diffusers/loaders/single_file_model.py +++ b/src/diffusers/loaders/single_file_model.py @@ -27,6 +27,8 @@ convert_flux_transformer_checkpoint_to_diffusers, convert_ldm_unet_checkpoint, convert_ldm_vae_checkpoint, + convert_ltx_transformer_checkpoint_to_diffusers, + convert_ltx_vae_checkpoint_to_diffusers, convert_sd3_transformer_checkpoint_to_diffusers, convert_stable_cascade_unet_single_file_to_diffusers, create_controlnet_diffusers_config_from_ldm, @@ -82,6 +84,14 @@ "checkpoint_mapping_fn": convert_flux_transformer_checkpoint_to_diffusers, "default_subfolder": "transformer", }, + "LTXTransformer3DModel": { + "checkpoint_mapping_fn": convert_ltx_transformer_checkpoint_to_diffusers, + "default_subfolder": "transformer", + }, + "AutoencoderKLLTX": { + "checkpoint_mapping_fn": convert_ltx_vae_checkpoint_to_diffusers, + "default_subfolder": "vae", + }, } @@ -270,6 +280,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = subfolder=subfolder, local_files_only=local_files_only, token=token, + revision=revision, ) expected_kwargs, optional_kwargs = cls._get_signature_keys(cls) diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index 10742873ded1..c6006c959f03 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -92,6 +92,12 @@ "double_blocks.0.img_attn.norm.key_norm.scale", "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale", ], + "ltx-video": [ + ( + "model.diffusion_model.patchify_proj.weight", + "model.diffusion_model.transformer_blocks.27.scale_shift_table", + ), + ], } DIFFUSERS_DEFAULT_PIPELINE_PATHS = { @@ -138,6 +144,7 @@ "animatediff_rgb": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-rgb"}, "flux-dev": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev"}, "flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"}, + "ltx-video": {"pretrained_model_name_or_path": "Lightricks/LTX-Video"}, } # Use to configure model sample size when original config is provided @@ -564,6 +571,10 @@ def infer_diffusers_model_type(checkpoint): model_type = "flux-dev" else: model_type = "flux-schnell" + + elif any(all(key in checkpoint for key in key_list) for key_list in CHECKPOINT_KEY_NAMES["ltx-video"]): + model_type = "ltx-video" + else: model_type = "v1" @@ -2198,3 +2209,97 @@ def swap_scale_shift(weight): ) return converted_state_dict + + +def convert_ltx_transformer_checkpoint_to_diffusers(checkpoint, **kwargs): + converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())} + + def 
remove_keys_(key: str, state_dict): + state_dict.pop(key) + + TRANSFORMER_KEYS_RENAME_DICT = { + "model.diffusion_model.": "", + "patchify_proj": "proj_in", + "adaln_single": "time_embed", + "q_norm": "norm_q", + "k_norm": "norm_k", + } + + TRANSFORMER_SPECIAL_KEYS_REMAP = { + "vae": remove_keys_, + } + + for key in list(converted_state_dict.keys()): + new_key = key + for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items(): + new_key = new_key.replace(replace_key, rename_key) + converted_state_dict[new_key] = converted_state_dict.pop(key) + + for key in list(converted_state_dict.keys()): + for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items(): + if special_key not in key: + continue + handler_fn_inplace(key, converted_state_dict) + + return converted_state_dict + + +def convert_ltx_vae_checkpoint_to_diffusers(checkpoint, **kwargs): + converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())} + + def remove_keys_(key: str, state_dict): + state_dict.pop(key) + + VAE_KEYS_RENAME_DICT = { + # common + "vae.": "", + # decoder + "up_blocks.0": "mid_block", + "up_blocks.1": "up_blocks.0", + "up_blocks.2": "up_blocks.1.upsamplers.0", + "up_blocks.3": "up_blocks.1", + "up_blocks.4": "up_blocks.2.conv_in", + "up_blocks.5": "up_blocks.2.upsamplers.0", + "up_blocks.6": "up_blocks.2", + "up_blocks.7": "up_blocks.3.conv_in", + "up_blocks.8": "up_blocks.3.upsamplers.0", + "up_blocks.9": "up_blocks.3", + # encoder + "down_blocks.0": "down_blocks.0", + "down_blocks.1": "down_blocks.0.downsamplers.0", + "down_blocks.2": "down_blocks.0.conv_out", + "down_blocks.3": "down_blocks.1", + "down_blocks.4": "down_blocks.1.downsamplers.0", + "down_blocks.5": "down_blocks.1.conv_out", + "down_blocks.6": "down_blocks.2", + "down_blocks.7": "down_blocks.2.downsamplers.0", + "down_blocks.8": "down_blocks.3", + "down_blocks.9": "mid_block", + # common + "conv_shortcut": "conv_shortcut.conv", + "res_blocks": "resnets", + "norm3.norm": "norm3", + "per_channel_statistics.mean-of-means": "latents_mean", + "per_channel_statistics.std-of-means": "latents_std", + } + + VAE_SPECIAL_KEYS_REMAP = { + "per_channel_statistics.channel": remove_keys_, + "per_channel_statistics.mean-of-means": remove_keys_, + "per_channel_statistics.mean-of-stds": remove_keys_, + "model.diffusion_model": remove_keys_, + } + + for key in list(converted_state_dict.keys()): + new_key = key + for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items(): + new_key = new_key.replace(replace_key, rename_key) + converted_state_dict[new_key] = converted_state_dict.pop(key) + + for key in list(converted_state_dict.keys()): + for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items(): + if special_key not in key: + continue + handler_fn_inplace(key, converted_state_dict) + + return converted_state_dict diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py index 1afa19091b74..1099a79bb344 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py @@ -19,6 +19,7 @@ import torch.nn as nn from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import FromOriginalModelMixin from ...utils.accelerate_utils import apply_forward_hook from ..activations import get_activation from ..modeling_outputs import AutoencoderKLOutput @@ -718,7 +719,7 @@ def create_forward(*inputs): return hidden_states -class AutoencoderKLLTX(ModelMixin, 
ConfigMixin): +class AutoencoderKLLTX(ModelMixin, ConfigMixin, FromOriginalModelMixin): r""" A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in [LTX](https://huggingface.co/Lightricks/LTX-Video). diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py index 1aca1a246996..b81f2709d1bb 100644 --- a/src/diffusers/models/transformers/transformer_ltx.py +++ b/src/diffusers/models/transformers/transformer_ltx.py @@ -21,6 +21,7 @@ import torch.nn.functional as F from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import FromOriginalModelMixin from ...utils import is_torch_version, logging from ...utils.torch_utils import maybe_allow_in_graph from ..attention import FeedForward @@ -266,7 +267,7 @@ def forward( @maybe_allow_in_graph -class LTXTransformer3DModel(ModelMixin, ConfigMixin): +class LTXTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin): r""" A Transformer model for video-like data used in [LTX](https://huggingface.co/Lightricks/LTX-Video). From db1698331ab1226b1761af989a25826d708aca20 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 10 Dec 2024 10:18:33 +0100 Subject: [PATCH 44/51] add single file to pipelines --- src/diffusers/loaders/single_file_utils.py | 14 +++++--------- src/diffusers/pipelines/ltx/pipeline_ltx.py | 3 ++- .../pipelines/ltx/pipeline_ltx_image2video.py | 3 ++- 3 files changed, 9 insertions(+), 11 deletions(-) diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index c6006c959f03..7faec1161162 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -2212,10 +2212,9 @@ def swap_scale_shift(weight): def convert_ltx_transformer_checkpoint_to_diffusers(checkpoint, **kwargs): - converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())} - - def remove_keys_(key: str, state_dict): - state_dict.pop(key) + converted_state_dict = { + key: checkpoint.pop(key) for key in list(checkpoint.keys()) if "model.diffusion_model." in key + } TRANSFORMER_KEYS_RENAME_DICT = { "model.diffusion_model.": "", @@ -2225,9 +2224,7 @@ def remove_keys_(key: str, state_dict): "k_norm": "norm_k", } - TRANSFORMER_SPECIAL_KEYS_REMAP = { - "vae": remove_keys_, - } + TRANSFORMER_SPECIAL_KEYS_REMAP = {} for key in list(converted_state_dict.keys()): new_key = key @@ -2245,7 +2242,7 @@ def remove_keys_(key: str, state_dict): def convert_ltx_vae_checkpoint_to_diffusers(checkpoint, **kwargs): - converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())} + converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys()) if "vae." 
in key} def remove_keys_(key: str, state_dict): state_dict.pop(key) @@ -2287,7 +2284,6 @@ def remove_keys_(key: str, state_dict): "per_channel_statistics.channel": remove_keys_, "per_channel_statistics.mean-of-means": remove_keys_, "per_channel_statistics.mean-of-stds": remove_keys_, - "model.diffusion_model": remove_keys_, } for key in list(converted_state_dict.keys()): diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx.py b/src/diffusers/pipelines/ltx/pipeline_ltx.py index ac5c1862e15f..545ad2122034 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx.py @@ -20,6 +20,7 @@ from transformers import T5EncoderModel, T5TokenizerFast from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...loaders import FromSingleFileMixin from ...models.autoencoders import AutoencoderKLLTX from ...models.transformers import LTXTransformer3DModel from ...schedulers import FlowMatchEulerDiscreteScheduler @@ -139,7 +140,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class LTXPipeline(DiffusionPipeline): +class LTXPipeline(DiffusionPipeline, FromSingleFileMixin): r""" Pipeline for text-to-video generation. diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py index 943a934989a8..0b43ab1204e7 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py @@ -21,6 +21,7 @@ from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput +from ...loaders import FromSingleFileMixin from ...models.autoencoders import AutoencoderKLLTX from ...models.transformers import LTXTransformer3DModel from ...schedulers import FlowMatchEulerDiscreteScheduler @@ -158,7 +159,7 @@ def retrieve_latents( raise AttributeError("Could not access latents of provided encoder_output") -class LTXImageToVideoPipeline(DiffusionPipeline): +class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin): r""" Pipeline for image-to-video generation. From 69400deecec76e4be8d2cf3b6ced96e89294c364 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 10 Dec 2024 10:20:09 +0100 Subject: [PATCH 45/51] update docs --- docs/source/en/api/pipelines/ltx.md | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/docs/source/en/api/pipelines/ltx.md b/docs/source/en/api/pipelines/ltx.md index 17032fede952..007f43f77b0c 100644 --- a/docs/source/en/api/pipelines/ltx.md +++ b/docs/source/en/api/pipelines/ltx.md @@ -22,6 +22,35 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.m +## Loading Single Files + +Loading the original LTX Video checkpoints is also possible with [`~ModelMixin.from_single_file`]. + +```python +import torch +from diffusers import AutoencoderKLLTX, LTXImageToVideoPipeline, LTXTransformer3DModel + +single_file_url = "https://huggingface.co/Lightricks/LTX-Video/ltx-video-2b-v0.9.safetensors" +transformer = LTXTransformer3DModel.from_single_file(single_file_url, torch_dtype=torch.bfloat16) +vae = AutoencoderKLLTX.from_single_file(single_file_url, torch_dtype=torch.bfloat16) +pipe = LTXImageToVideoPipeline.from_pretrained("Lightricks/LTX-Video", transformer=transformer, vae=vae, torch_dtype=torch.bfloat16) + +# ... inference code ... +``` + +Alternative, the pipeline can be used to load the weights with [~FromSingleFileMixin.from_single_file`]. 
+ +```python +import torch +from diffusers import LTXImageToVideoPipeline +from transformers import T5EncoderModel, T5Tokenizer + +single_file_url = "https://huggingface.co/Lightricks/LTX-Video/ltx-video-2b-v0.9.safetensors" +text_encoder = T5EncoderModel.from_pretrained("Lightricks/LTX-Video", subfolder="text_encoder", torch_dtype=torch.bfloat16) +tokenizer = T5Tokenizer.from_pretrained("Lightricks/LTX-Video", subfolder="tokenizer", torch_dtype=torch.bfloat16) +pipe = LTXImageToVideoPipeline.from_single_file(single_file_url, text_encoder=text_encoder, tokenizer=tokenizer, torch_dtype=torch.bfloat16) +``` + ## LTXPipeline [[autodoc]] LTXPipeline From f5c4815b5d5ba951bf78fe667d082c2be9df5b26 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 11 Dec 2024 14:47:50 +0530 Subject: [PATCH 46/51] Update src/diffusers/models/autoencoders/autoencoder_kl_ltx.py --- src/diffusers/models/autoencoders/autoencoder_kl_ltx.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py index 1099a79bb344..ea5e0f1ec3d6 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py @@ -28,7 +28,6 @@ from .vae import DecoderOutput, DiagonalGaussianDistribution -# Adapted from diffusers.models.autoencoders.autoencoder_kl_cogvideox.CogVideoXCausalConv3d class LTXCausalConv3d(nn.Module): def __init__( self, From 2106441ad633be8c1a112fca96467f35565cae6a Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 11 Dec 2024 14:47:58 +0530 Subject: [PATCH 47/51] Update src/diffusers/models/autoencoders/autoencoder_kl_ltx.py --- src/diffusers/models/autoencoders/autoencoder_kl_ltx.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py index ea5e0f1ec3d6..852ef2da405b 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py @@ -79,7 +79,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return hidden_states -# Adapted from diffusers.models.autoencoders.autoencoder_kl_cogvideox.CogVideoXResnetBlock3d class LTXResnetBlock3d(nn.Module): r""" A 3D ResNet block used in the LTX model. 
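A minimal standalone sketch of the `_pack_latents` / `_unpack_latents` layout described by the comments added earlier in this series; the helper names and tensor shapes here are illustrative, using the default `patch_size = patch_size_t = 1`:

```python
import torch

def pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
    # [B, C, F, H, W] -> [B, C, F//pt, pt, H//p, p, W//p, p] -> [B, F//pt * H//p * W//p, C * pt * p * p]
    batch_size, num_channels, num_frames, height, width = latents.shape
    latents = latents.reshape(
        batch_size, num_channels,
        num_frames // patch_size_t, patch_size_t,
        height // patch_size, patch_size,
        width // patch_size, patch_size,
    )
    return latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)

def unpack_latents(latents: torch.Tensor, num_frames: int, height: int, width: int,
                   patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
    # Inverse of `pack_latents`; num_frames/height/width are the post-patch latent grid sizes.
    batch_size = latents.size(0)
    latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size)
    return latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)

video_latents = torch.randn(2, 128, 8, 16, 24)   # [B, C, F, H, W]
packed = pack_latents(video_latents)             # [2, 8 * 16 * 24, 128]
restored = unpack_latents(packed, 8, 16, 24)     # back to [2, 128, 8, 16, 24]
assert torch.equal(restored, video_latents)
```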
From 4aa78960377d7eb2c54797fc58f7311bed59750d Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 11 Dec 2024 20:33:16 +0100 Subject: [PATCH 48/51] rename classes based on ltx review --- docs/source/en/_toctree.yml | 10 +++++----- ...toencoderkl_ltx.md => autoencoderkl_ltx_video.md} | 10 +++++----- ...x_transformer3d.md => ltx_video_transformer3d.md} | 10 +++++----- .../source/en/api/pipelines/{ltx.md => ltx_video.md} | 8 ++++---- scripts/convert_ltx_to_diffusers.py | 10 +++++----- src/diffusers/__init__.py | 8 ++++---- src/diffusers/loaders/single_file_model.py | 4 ++-- src/diffusers/models/__init__.py | 8 ++++---- src/diffusers/models/autoencoders/__init__.py | 2 +- .../models/autoencoders/autoencoder_kl_ltx.py | 10 +++++----- src/diffusers/models/transformers/__init__.py | 2 +- src/diffusers/models/transformers/transformer_ltx.py | 2 +- src/diffusers/pipelines/ltx/pipeline_ltx.py | 12 ++++++------ .../pipelines/ltx/pipeline_ltx_image2video.py | 12 ++++++------ src/diffusers/utils/dummy_pt_objects.py | 4 ++-- .../transformers/test_models_transformer_ltx.py | 6 +++--- tests/pipelines/ltx/test_ltx.py | 6 +++--- tests/pipelines/ltx/test_ltx_image2video.py | 11 ++++++++--- 18 files changed, 70 insertions(+), 65 deletions(-) rename docs/source/en/api/models/{autoencoderkl_ltx.md => autoencoderkl_ltx_video.md} (80%) rename docs/source/en/api/models/{ltx_transformer3d.md => ltx_video_transformer3d.md} (75%) rename docs/source/en/api/pipelines/{ltx.md => ltx_video.md} (87%) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 87198b344aeb..55524b11172e 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -272,8 +272,8 @@ title: LatteTransformer3DModel - local: api/models/lumina_nextdit2d title: LuminaNextDiT2DModel - - local: api/models/ltx_transformer3d - title: LTXTransformer3DModel + - local: api/models/ltx_video_transformer3d + title: LTXVideoTransformer3DModel - local: api/models/mochi_transformer3d title: MochiTransformer3DModel - local: api/models/pixart_transformer2d @@ -312,8 +312,8 @@ title: AutoencoderKLAllegro - local: api/models/autoencoderkl_cogvideox title: AutoencoderKLCogVideoX - - local: api/models/autoencoderkl_ltx - title: AutoencoderKLLTX + - local: api/models/autoencoderkl_ltx_video + title: AutoencoderKLLTXVideo - local: api/models/autoencoderkl_mochi title: AutoencoderKLMochi - local: api/models/asymmetricautoencoderkl @@ -408,7 +408,7 @@ title: Latte - local: api/pipelines/ledits_pp title: LEDITS++ - - local: api/pipelines/ltx + - local: api/pipelines/ltx_video title: LTX - local: api/pipelines/lumina title: Lumina-T2X diff --git a/docs/source/en/api/models/autoencoderkl_ltx.md b/docs/source/en/api/models/autoencoderkl_ltx_video.md similarity index 80% rename from docs/source/en/api/models/autoencoderkl_ltx.md rename to docs/source/en/api/models/autoencoderkl_ltx_video.md index 7c2519866077..694b5ace6fdf 100644 --- a/docs/source/en/api/models/autoencoderkl_ltx.md +++ b/docs/source/en/api/models/autoencoderkl_ltx_video.md @@ -9,21 +9,21 @@ Unless required by applicable law or agreed to in writing, software distributed an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. --> -# AutoencoderKLLTX +# AutoencoderKLLTXVideo The 3D variational autoencoder (VAE) model with KL loss used in [LTX](https://huggingface.co/Lightricks/LTX-Video) was introduced by Lightricks. 
The model can be loaded with the following code snippet. ```python -from diffusers import AutoencoderKLLTX +from diffusers import AutoencoderKLLTXVideo -vae = AutoencoderKLLTX.from_pretrained("TODO/TODO", subfolder="vae", torch_dtype=torch.float32).to("cuda") +vae = AutoencoderKLLTXVideo.from_pretrained("TODO/TODO", subfolder="vae", torch_dtype=torch.float32).to("cuda") ``` -## AutoencoderKLLTX +## AutoencoderKLLTXVideo -[[autodoc]] AutoencoderKLLTX +[[autodoc]] AutoencoderKLLTXVideo - decode - encode - all diff --git a/docs/source/en/api/models/ltx_transformer3d.md b/docs/source/en/api/models/ltx_video_transformer3d.md similarity index 75% rename from docs/source/en/api/models/ltx_transformer3d.md rename to docs/source/en/api/models/ltx_video_transformer3d.md index e70a03ad7ea7..8a60bc0432c6 100644 --- a/docs/source/en/api/models/ltx_transformer3d.md +++ b/docs/source/en/api/models/ltx_video_transformer3d.md @@ -9,21 +9,21 @@ Unless required by applicable law or agreed to in writing, software distributed an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. --> -# LTXTransformer3DModel +# LTXVideoTransformer3DModel A Diffusion Transformer model for 3D data from [LTX](https://huggingface.co/Lightricks/LTX-Video) was introduced by Lightricks. The model can be loaded with the following code snippet. ```python -from diffusers import LTXTransformer3DModel +from diffusers import LTXVideoTransformer3DModel -transformer = LTXTransformer3DModel.from_pretrained("TODO/TODO", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda") +transformer = LTXVideoTransformer3DModel.from_pretrained("TODO/TODO", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda") ``` -## LTXTransformer3DModel +## LTXVideoTransformer3DModel -[[autodoc]] LTXTransformer3DModel +[[autodoc]] LTXVideoTransformer3DModel ## Transformer2DModelOutput diff --git a/docs/source/en/api/pipelines/ltx.md b/docs/source/en/api/pipelines/ltx_video.md similarity index 87% rename from docs/source/en/api/pipelines/ltx.md rename to docs/source/en/api/pipelines/ltx_video.md index 007f43f77b0c..162e1334ce9a 100644 --- a/docs/source/en/api/pipelines/ltx.md +++ b/docs/source/en/api/pipelines/ltx_video.md @@ -28,17 +28,17 @@ Loading the original LTX Video checkpoints is also possible with [`~ModelMixin.f ```python import torch -from diffusers import AutoencoderKLLTX, LTXImageToVideoPipeline, LTXTransformer3DModel +from diffusers import AutoencoderKLLTXVideo, LTXImageToVideoPipeline, LTXVideoTransformer3DModel single_file_url = "https://huggingface.co/Lightricks/LTX-Video/ltx-video-2b-v0.9.safetensors" -transformer = LTXTransformer3DModel.from_single_file(single_file_url, torch_dtype=torch.bfloat16) -vae = AutoencoderKLLTX.from_single_file(single_file_url, torch_dtype=torch.bfloat16) +transformer = LTXVideoTransformer3DModel.from_single_file(single_file_url, torch_dtype=torch.bfloat16) +vae = AutoencoderKLLTXVideo.from_single_file(single_file_url, torch_dtype=torch.bfloat16) pipe = LTXImageToVideoPipeline.from_pretrained("Lightricks/LTX-Video", transformer=transformer, vae=vae, torch_dtype=torch.bfloat16) # ... inference code ... ``` -Alternative, the pipeline can be used to load the weights with [~FromSingleFileMixin.from_single_file`]. +Alternatively, the pipeline can be used to load the weights with [~FromSingleFileMixin.from_single_file`]. 
```python import torch diff --git a/scripts/convert_ltx_to_diffusers.py b/scripts/convert_ltx_to_diffusers.py index 053f33253751..f4398a2e687c 100644 --- a/scripts/convert_ltx_to_diffusers.py +++ b/scripts/convert_ltx_to_diffusers.py @@ -5,7 +5,7 @@ from safetensors.torch import load_file from transformers import T5EncoderModel, T5Tokenizer -from diffusers import AutoencoderKLLTX, FlowMatchEulerDiscreteScheduler, LTXPipeline, LTXTransformer3DModel +from diffusers import AutoencoderKLLTXVideo, FlowMatchEulerDiscreteScheduler, LTXPipeline, LTXVideoTransformer3DModel def remove_keys_(key: str, state_dict: Dict[str, Any]): @@ -83,7 +83,7 @@ def convert_transformer( PREFIX_KEY = "" original_state_dict = get_state_dict(load_file(ckpt_path)) - transformer = LTXTransformer3DModel().to(dtype=dtype) + transformer = LTXVideoTransformer3DModel().to(dtype=dtype) for key in list(original_state_dict.keys()): new_key = key[len(PREFIX_KEY) :] @@ -103,7 +103,7 @@ def convert_transformer( def convert_vae(ckpt_path: str, dtype: torch.dtype): original_state_dict = get_state_dict(load_file(ckpt_path)) - vae = AutoencoderKLLTX().to(dtype=dtype) + vae = AutoencoderKLLTXVideo().to(dtype=dtype) for key in list(original_state_dict.keys()): new_key = key[:] @@ -166,14 +166,14 @@ def get_args(): assert args.transformer_ckpt_path is not None and args.vae_ckpt_path is not None if args.transformer_ckpt_path is not None: - transformer: LTXTransformer3DModel = convert_transformer(args.transformer_ckpt_path, dtype) + transformer: LTXVideoTransformer3DModel = convert_transformer(args.transformer_ckpt_path, dtype) if not args.save_pipeline: transformer.save_pretrained( args.output_path, safe_serialization=True, max_shard_size="5GB", variant=variant ) if args.vae_ckpt_path is not None: - vae: AutoencoderKLLTX = convert_vae(args.vae_ckpt_path, dtype) + vae: AutoencoderKLLTXVideo = convert_vae(args.vae_ckpt_path, dtype) if not args.save_pipeline: vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", variant=variant) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index e2784d879410..7b08d2fd9125 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -84,7 +84,7 @@ "AutoencoderKL", "AutoencoderKLAllegro", "AutoencoderKLCogVideoX", - "AutoencoderKLLTX", + "AutoencoderKLLTXVideo", "AutoencoderKLMochi", "AutoencoderKLTemporalDecoder", "AutoencoderOobleck", @@ -104,7 +104,7 @@ "I2VGenXLUNet", "Kandinsky3UNet", "LatteTransformer3DModel", - "LTXTransformer3DModel", + "LTXVideoTransformer3DModel", "LuminaNextDiT2DModel", "MochiTransformer3DModel", "ModelMixin", @@ -582,7 +582,7 @@ AutoencoderKL, AutoencoderKLAllegro, AutoencoderKLCogVideoX, - AutoencoderKLLTX, + AutoencoderKLLTXVideo, AutoencoderKLMochi, AutoencoderKLTemporalDecoder, AutoencoderOobleck, @@ -602,7 +602,7 @@ I2VGenXLUNet, Kandinsky3UNet, LatteTransformer3DModel, - LTXTransformer3DModel, + LTXVideoTransformer3DModel, LuminaNextDiT2DModel, MochiTransformer3DModel, ModelMixin, diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py index 37d7fb57e4f3..d3613e5a76db 100644 --- a/src/diffusers/loaders/single_file_model.py +++ b/src/diffusers/loaders/single_file_model.py @@ -84,11 +84,11 @@ "checkpoint_mapping_fn": convert_flux_transformer_checkpoint_to_diffusers, "default_subfolder": "transformer", }, - "LTXTransformer3DModel": { + "LTXVideoTransformer3DModel": { "checkpoint_mapping_fn": convert_ltx_transformer_checkpoint_to_diffusers, "default_subfolder": 
"transformer", }, - "AutoencoderKLLTX": { + "AutoencoderKLLTXVideo": { "checkpoint_mapping_fn": convert_ltx_vae_checkpoint_to_diffusers, "default_subfolder": "vae", }, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 388def849727..987c3781eaf7 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -31,7 +31,7 @@ _import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"] _import_structure["autoencoders.autoencoder_kl_allegro"] = ["AutoencoderKLAllegro"] _import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"] - _import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTX"] + _import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"] _import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"] _import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"] _import_structure["autoencoders.autoencoder_oobleck"] = ["AutoencoderOobleck"] @@ -65,7 +65,7 @@ _import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"] _import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"] _import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"] - _import_structure["transformers.transformer_ltx"] = ["LTXTransformer3DModel"] + _import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"] _import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"] _import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"] _import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"] @@ -95,7 +95,7 @@ AutoencoderKL, AutoencoderKLAllegro, AutoencoderKLCogVideoX, - AutoencoderKLLTX, + AutoencoderKLLTXVideo, AutoencoderKLMochi, AutoencoderKLTemporalDecoder, AutoencoderOobleck, @@ -128,7 +128,7 @@ FluxTransformer2DModel, HunyuanDiT2DModel, LatteTransformer3DModel, - LTXTransformer3DModel, + LTXVideoTransformer3DModel, LuminaNextDiT2DModel, MochiTransformer3DModel, PixArtTransformer2DModel, diff --git a/src/diffusers/models/autoencoders/__init__.py b/src/diffusers/models/autoencoders/__init__.py index 4c448f385d1f..d08e67c40975 100644 --- a/src/diffusers/models/autoencoders/__init__.py +++ b/src/diffusers/models/autoencoders/__init__.py @@ -3,7 +3,7 @@ from .autoencoder_kl import AutoencoderKL from .autoencoder_kl_allegro import AutoencoderKLAllegro from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX -from .autoencoder_kl_ltx import AutoencoderKLLTX +from .autoencoder_kl_ltx import AutoencoderKLLTXVideo from .autoencoder_kl_mochi import AutoencoderKLMochi from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder from .autoencoder_oobleck import AutoencoderOobleck diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py index 852ef2da405b..ff202b980b95 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py @@ -717,7 +717,7 @@ def create_forward(*inputs): return hidden_states -class AutoencoderKLLTX(ModelMixin, ConfigMixin, FromOriginalModelMixin): +class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin): r""" A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in [LTX](https://huggingface.co/Lightricks/LTX-Video). 
@@ -900,7 +900,7 @@ def _encode(self, x: torch.Tensor) -> torch.Tensor: if self.use_framewise_encoding: # TODO(aryan): requires investigation raise NotImplementedError( - "Frame-wise encoding has not been implemented for AutoencoderKLLTX, at the moment, due to " + "Frame-wise encoding has not been implemented for AutoencoderKLLTXVideo, at the moment, due to " "quality issues caused by splitting inference across frame dimension. If you believe this " "should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls." ) @@ -947,7 +947,7 @@ def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOut if self.use_framewise_decoding: # TODO(aryan): requires investigation raise NotImplementedError( - "Frame-wise decoding has not been implemented for AutoencoderKLLTX, at the moment, due to " + "Frame-wise decoding has not been implemented for AutoencoderKLLTXVideo, at the moment, due to " "quality issues caused by splitting inference across frame dimension. If you believe this " "should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls." ) @@ -1032,7 +1032,7 @@ def tiled_encode(self, x: torch.Tensor) -> torch.Tensor: if self.use_framewise_encoding: # TODO(aryan): requires investigation raise NotImplementedError( - "Frame-wise encoding has not been implemented for AutoencoderKLLTX, at the moment, due to " + "Frame-wise encoding has not been implemented for AutoencoderKLLTXVideo, at the moment, due to " "quality issues caused by splitting inference across frame dimension. If you believe this " "should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls." ) @@ -1096,7 +1096,7 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod if self.use_framewise_decoding: # TODO(aryan): requires investigation raise NotImplementedError( - "Frame-wise decoding has not been implemented for AutoencoderKLLTX, at the moment, due to " + "Frame-wise decoding has not been implemented for AutoencoderKLLTXVideo, at the moment, due to " "quality issues caused by splitting inference across frame dimension. If you believe this " "should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls." 
) diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index 895c3aac4dcc..fed64d45fbd0 100644 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -17,7 +17,7 @@ from .transformer_allegro import AllegroTransformer3DModel from .transformer_cogview3plus import CogView3PlusTransformer2DModel from .transformer_flux import FluxTransformer2DModel - from .transformer_ltx import LTXTransformer3DModel + from .transformer_ltx import LTXVideoTransformer3DModel from .transformer_mochi import MochiTransformer3DModel from .transformer_sd3 import SD3Transformer2DModel from .transformer_temporal import TransformerTemporalModel diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py index b81f2709d1bb..8aa3a1590fb9 100644 --- a/src/diffusers/models/transformers/transformer_ltx.py +++ b/src/diffusers/models/transformers/transformer_ltx.py @@ -267,7 +267,7 @@ def forward( @maybe_allow_in_graph -class LTXTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin): +class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin): r""" A Transformer model for video-like data used in [LTX](https://huggingface.co/Lightricks/LTX-Video). diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx.py b/src/diffusers/pipelines/ltx/pipeline_ltx.py index 545ad2122034..36178f7e2ed6 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx.py @@ -21,8 +21,8 @@ from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...loaders import FromSingleFileMixin -from ...models.autoencoders import AutoencoderKLLTX -from ...models.transformers import LTXTransformer3DModel +from ...models.autoencoders import AutoencoderKLLTXVideo +from ...models.transformers import LTXVideoTransformer3DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor @@ -147,11 +147,11 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin): Reference: https://github.com/Lightricks/LTX-Video Args: - transformer ([`LTXTransformer3DModel`]): + transformer ([`LTXVideoTransformer3DModel`]): Conditional Transformer architecture to denoise the encoded video latents. scheduler ([`FlowMatchEulerDiscreteScheduler`]): A scheduler to be used in combination with `transformer` to denoise the encoded image latents. - vae ([`AutoencoderKLLTX`]): + vae ([`AutoencoderKLLTXVideo`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. 
text_encoder ([`T5EncoderModel`]): [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically @@ -171,10 +171,10 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin): def __init__( self, scheduler: FlowMatchEulerDiscreteScheduler, - vae: AutoencoderKLLTX, + vae: AutoencoderKLLTXVideo, text_encoder: T5EncoderModel, tokenizer: T5TokenizerFast, - transformer: LTXTransformer3DModel, + transformer: LTXVideoTransformer3DModel, ): super().__init__() diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py index 0b43ab1204e7..297cebd150d7 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py @@ -22,8 +22,8 @@ from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput from ...loaders import FromSingleFileMixin -from ...models.autoencoders import AutoencoderKLLTX -from ...models.transformers import LTXTransformer3DModel +from ...models.autoencoders import AutoencoderKLLTXVideo +from ...models.transformers import LTXVideoTransformer3DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor @@ -166,11 +166,11 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin): Reference: https://github.com/Lightricks/LTX-Video Args: - transformer ([`LTXTransformer3DModel`]): + transformer ([`LTXVideoTransformer3DModel`]): Conditional Transformer architecture to denoise the encoded video latents. scheduler ([`FlowMatchEulerDiscreteScheduler`]): A scheduler to be used in combination with `transformer` to denoise the encoded image latents. - vae ([`AutoencoderKLLTX`]): + vae ([`AutoencoderKLLTXVideo`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. 
text_encoder ([`T5EncoderModel`]): [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically @@ -190,10 +190,10 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin): def __init__( self, scheduler: FlowMatchEulerDiscreteScheduler, - vae: AutoencoderKLLTX, + vae: AutoencoderKLLTXVideo, text_encoder: T5EncoderModel, tokenizer: T5TokenizerFast, - transformer: LTXTransformer3DModel, + transformer: LTXVideoTransformer3DModel, ): super().__init__() diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 5a67f392d27e..e9e2b4ed2ecd 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -107,7 +107,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class AutoencoderKLLTX(metaclass=DummyObject): +class AutoencoderKLLTXVideo(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -407,7 +407,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class LTXTransformer3DModel(metaclass=DummyObject): +class LTXVideoTransformer3DModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): diff --git a/tests/models/transformers/test_models_transformer_ltx.py b/tests/models/transformers/test_models_transformer_ltx.py index 1b0e4f0e4181..128bf04155e7 100644 --- a/tests/models/transformers/test_models_transformer_ltx.py +++ b/tests/models/transformers/test_models_transformer_ltx.py @@ -17,7 +17,7 @@ import torch -from diffusers import LTXTransformer3DModel +from diffusers import LTXVideoTransformer3DModel from diffusers.utils.testing_utils import enable_full_determinism, torch_device from ..test_modeling_common import ModelTesterMixin @@ -27,7 +27,7 @@ class LTXTransformerTests(ModelTesterMixin, unittest.TestCase): - model_class = LTXTransformer3DModel + model_class = LTXVideoTransformer3DModel main_input_name = "hidden_states" uses_custom_attn_processor = True @@ -79,5 +79,5 @@ def prepare_init_args_and_inputs_for_common(self): return init_dict, inputs_dict def test_gradient_checkpointing_is_applied(self): - expected_set = {"LTXTransformer3DModel"} + expected_set = {"LTXVideoTransformer3DModel"} super().test_gradient_checkpointing_is_applied(expected_set=expected_set) diff --git a/tests/pipelines/ltx/test_ltx.py b/tests/pipelines/ltx/test_ltx.py index 18e1cee38fe8..0f9819bfd6d8 100644 --- a/tests/pipelines/ltx/test_ltx.py +++ b/tests/pipelines/ltx/test_ltx.py @@ -19,7 +19,7 @@ import torch from transformers import AutoTokenizer, T5EncoderModel -from diffusers import AutoencoderKLLTX, FlowMatchEulerDiscreteScheduler, LTXPipeline, LTXTransformer3DModel +from diffusers import AutoencoderKLLTXVideo, FlowMatchEulerDiscreteScheduler, LTXPipeline, LTXVideoTransformer3DModel from diffusers.utils.testing_utils import enable_full_determinism, torch_device from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS @@ -49,7 +49,7 @@ class LTXPipelineFastTests(PipelineTesterMixin, unittest.TestCase): def get_dummy_components(self): torch.manual_seed(0) - transformer = LTXTransformer3DModel( + transformer = LTXVideoTransformer3DModel( in_channels=8, out_channels=8, patch_size=1, @@ -62,7 +62,7 @@ def get_dummy_components(self): ) torch.manual_seed(0) - vae = AutoencoderKLLTX( + vae = AutoencoderKLLTXVideo( latent_channels=8, block_out_channels=(8, 8, 8, 8), spatio_temporal_scaling=(True, True, False, 
False), diff --git a/tests/pipelines/ltx/test_ltx_image2video.py b/tests/pipelines/ltx/test_ltx_image2video.py index c5f4009fcf54..40397e4c3619 100644 --- a/tests/pipelines/ltx/test_ltx_image2video.py +++ b/tests/pipelines/ltx/test_ltx_image2video.py @@ -19,7 +19,12 @@ import torch from transformers import AutoTokenizer, T5EncoderModel -from diffusers import AutoencoderKLLTX, FlowMatchEulerDiscreteScheduler, LTXImageToVideoPipeline, LTXTransformer3DModel +from diffusers import ( + AutoencoderKLLTXVideo, + FlowMatchEulerDiscreteScheduler, + LTXImageToVideoPipeline, + LTXVideoTransformer3DModel, +) from diffusers.utils.testing_utils import enable_full_determinism, torch_device from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS @@ -49,7 +54,7 @@ class LTXImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase): def get_dummy_components(self): torch.manual_seed(0) - transformer = LTXTransformer3DModel( + transformer = LTXVideoTransformer3DModel( in_channels=8, out_channels=8, patch_size=1, @@ -62,7 +67,7 @@ def get_dummy_components(self): ) torch.manual_seed(0) - vae = AutoencoderKLLTX( + vae = AutoencoderKLLTXVideo( latent_channels=8, block_out_channels=(8, 8, 8, 8), spatio_temporal_scaling=(True, True, False, False), From 93d93b1f9e45bc0232ac0239e1d4d25efeef9388 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 11 Dec 2024 20:43:36 +0100 Subject: [PATCH 49/51] point to original repository for inference --- src/diffusers/pipelines/ltx/pipeline_ltx.py | 2 +- src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx.py b/src/diffusers/pipelines/ltx/pipeline_ltx.py index 36178f7e2ed6..72b95fea1ce1 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx.py @@ -47,7 +47,7 @@ >>> from diffusers import LTXPipeline >>> from diffusers.utils import export_to_video - >>> pipe = LTXPipeline.from_pretrained("a-r-r-o-w/LTX-Video-diffusers", torch_dtype=torch.bfloat16) + >>> pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16) >>> pipe.to("cuda") >>> prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. 
The scene appears to be real-life footage" diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py index 297cebd150d7..25ed635a3d17 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py @@ -48,7 +48,7 @@ >>> from diffusers import LTXImageToVideoPipeline >>> from diffusers.utils import export_to_video, load_image - >>> pipe = LTXImageToVideoPipeline.from_pretrained("a-r-r-o-w/LTX-Video-diffusers", torch_dtype=torch.bfloat16) + >>> pipe = LTXImageToVideoPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16) >>> pipe.to("cuda") >>> image = load_image( From c9a9ab51c05d0e58c607be8996e51f278c07853b Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 11 Dec 2024 20:52:18 +0100 Subject: [PATCH 50/51] make style --- src/diffusers/loaders/single_file_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index a941c02d5a80..376a263d7b10 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -2313,7 +2313,7 @@ def remove_keys_(key: str, state_dict): for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items(): new_key = new_key.replace(replace_key, rename_key) converted_state_dict[new_key] = converted_state_dict.pop(key) - + for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items(): if special_key not in key: continue From 1e6796892b9692b4b9b1fe36841a2173d7272bc1 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 11 Dec 2024 20:59:01 +0100 Subject: [PATCH 51/51] resolve conflicts correctly --- src/diffusers/loaders/single_file_utils.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index 376a263d7b10..21ff2841700d 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -2314,10 +2314,13 @@ def remove_keys_(key: str, state_dict): new_key = new_key.replace(replace_key, rename_key) converted_state_dict[new_key] = converted_state_dict.pop(key) - for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items(): - if special_key not in key: - continue - handler_fn_inplace(key, converted_state_dict) + for key in list(converted_state_dict.keys()): + for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items(): + if special_key not in key: + continue + handler_fn_inplace(key, converted_state_dict) + + return converted_state_dict def convert_autoencoder_dc_checkpoint_to_diffusers(checkpoint, **kwargs):
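
Editorial note on the last patch: before the conflict-resolution fix above, the `VAE_SPECIAL_KEYS_REMAP` handlers only ever saw whichever `key` happened to be left bound by the preceding rename loop, and the converter fell through without returning `converted_state_dict`. The corrected pattern runs the handlers over a snapshot of every key and returns the dict. A minimal standalone sketch of that pattern follows; the handler and the remap table here are stand-ins for illustration, not the real `VAE_SPECIAL_KEYS_REMAP` contents.

```python
from typing import Callable, Dict

import torch

# Stand-in handler (hypothetical entry); the real table lives in single_file_utils.py.
def remove_keys_(key: str, state_dict: Dict[str, torch.Tensor]) -> None:
    state_dict.pop(key)


SPECIAL_KEYS_REMAP: Dict[str, Callable] = {"unwanted_prefix": remove_keys_}


def apply_special_keys(converted_state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    # Iterate over a snapshot of the keys so handlers may mutate the dict safely,
    # and run the handlers for every key rather than only the last renamed one.
    for key in list(converted_state_dict.keys()):
        for special_key, handler_fn_inplace in SPECIAL_KEYS_REMAP.items():
            if special_key not in key:
                continue
            handler_fn_inplace(key, converted_state_dict)
    return converted_state_dict
```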
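
Taken together, the patches above rename the public classes to `LTXVideoTransformer3DModel` and `AutoencoderKLLTXVideo` and point the examples at `Lightricks/LTX-Video`. A minimal end-to-end usage sketch of the renamed API, adapted from the updated example docstrings: it assumes a CUDA device, that `Lightricks/LTX-Video` hosts diffusers-format weights as patch 49 indicates, and the prompt and output path are illustrative only.

```python
import torch

from diffusers import AutoencoderKLLTXVideo, LTXPipeline, LTXVideoTransformer3DModel
from diffusers.utils import export_to_video

# Load the text-to-video pipeline; the components are instances of the renamed classes.
pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)
pipe.to("cuda")

assert isinstance(pipe.transformer, LTXVideoTransformer3DModel)
assert isinstance(pipe.vae, AutoencoderKLLTXVideo)

# Prompt shortened from the docstring example; generation settings are defaults.
prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair."
video = pipe(prompt=prompt, num_inference_steps=50).frames[0]
export_to_video(video, "output.mp4", fps=24)
```

Since both classes are also registered in the `single_file_model.py` mapping above, loading them directly from the original Lightricks checkpoints via `from_single_file` should work as well, though that path is not exercised in this sketch.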