diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py
index be3139057078..11ade76f7f70 100644
--- a/src/diffusers/loaders/single_file_model.py
+++ b/src/diffusers/loaders/single_file_model.py
@@ -27,6 +27,8 @@
     convert_flux_transformer_checkpoint_to_diffusers,
     convert_ldm_unet_checkpoint,
     convert_ldm_vae_checkpoint,
+    convert_ltx_transformer_checkpoint_to_diffusers,
+    convert_ltx_vae_checkpoint_to_diffusers,
     convert_sd3_transformer_checkpoint_to_diffusers,
     convert_stable_cascade_unet_single_file_to_diffusers,
     create_controlnet_diffusers_config_from_ldm,
@@ -82,6 +84,14 @@
         "checkpoint_mapping_fn": convert_flux_transformer_checkpoint_to_diffusers,
         "default_subfolder": "transformer",
     },
+    "LTXTransformer3DModel": {
+        "checkpoint_mapping_fn": convert_ltx_transformer_checkpoint_to_diffusers,
+        "default_subfolder": "transformer",
+    },
+    "AutoencoderKLLTX": {
+        "checkpoint_mapping_fn": convert_ltx_vae_checkpoint_to_diffusers,
+        "default_subfolder": "vae",
+    },
 }
 
 
@@ -270,6 +280,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
                 subfolder=subfolder,
                 local_files_only=local_files_only,
                 token=token,
+                revision=revision,
             )
 
         expected_kwargs, optional_kwargs = cls._get_signature_keys(cls)
diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py
index 10742873ded1..c6006c959f03 100644
--- a/src/diffusers/loaders/single_file_utils.py
+++ b/src/diffusers/loaders/single_file_utils.py
@@ -92,6 +92,12 @@
         "double_blocks.0.img_attn.norm.key_norm.scale",
         "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale",
     ],
+    "ltx-video": [
+        (
+            "model.diffusion_model.patchify_proj.weight",
+            "model.diffusion_model.transformer_blocks.27.scale_shift_table",
+        ),
+    ],
 }
 
 DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -138,6 +144,7 @@
     "animatediff_rgb": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-rgb"},
     "flux-dev": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev"},
     "flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"},
+    "ltx-video": {"pretrained_model_name_or_path": "Lightricks/LTX-Video"},
 }
 
 # Use to configure model sample size when original config is provided
@@ -564,6 +571,10 @@ def infer_diffusers_model_type(checkpoint):
             model_type = "flux-dev"
         else:
             model_type = "flux-schnell"
+
+    elif any(all(key in checkpoint for key in key_list) for key_list in CHECKPOINT_KEY_NAMES["ltx-video"]):
+        model_type = "ltx-video"
+
     else:
         model_type = "v1"
 
@@ -2198,3 +2209,97 @@ def swap_scale_shift(weight):
     )
 
     return converted_state_dict
+
+
+def convert_ltx_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
+    converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}
+
+    def remove_keys_(key: str, state_dict):
+        state_dict.pop(key)
+
+    TRANSFORMER_KEYS_RENAME_DICT = {
+        "model.diffusion_model.": "",
+        "patchify_proj": "proj_in",
+        "adaln_single": "time_embed",
+        "q_norm": "norm_q",
+        "k_norm": "norm_k",
+    }
+
+    TRANSFORMER_SPECIAL_KEYS_REMAP = {
+        "vae": remove_keys_,
+    }
+
+    for key in list(converted_state_dict.keys()):
+        new_key = key
+        for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
+            new_key = new_key.replace(replace_key, rename_key)
+        converted_state_dict[new_key] = converted_state_dict.pop(key)
+
+    for key in list(converted_state_dict.keys()):
+        for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
+            if special_key not in key:
+                continue
+            handler_fn_inplace(key, converted_state_dict)
+
+    return converted_state_dict
+
+
+def convert_ltx_vae_checkpoint_to_diffusers(checkpoint, **kwargs):
+    converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}
+
+    def remove_keys_(key: str, state_dict):
+        state_dict.pop(key)
+
+    VAE_KEYS_RENAME_DICT = {
+        # common
+        "vae.": "",
+        # decoder
+        "up_blocks.0": "mid_block",
+        "up_blocks.1": "up_blocks.0",
+        "up_blocks.2": "up_blocks.1.upsamplers.0",
+        "up_blocks.3": "up_blocks.1",
+        "up_blocks.4": "up_blocks.2.conv_in",
+        "up_blocks.5": "up_blocks.2.upsamplers.0",
+        "up_blocks.6": "up_blocks.2",
+        "up_blocks.7": "up_blocks.3.conv_in",
+        "up_blocks.8": "up_blocks.3.upsamplers.0",
+        "up_blocks.9": "up_blocks.3",
+        # encoder
+        "down_blocks.0": "down_blocks.0",
+        "down_blocks.1": "down_blocks.0.downsamplers.0",
+        "down_blocks.2": "down_blocks.0.conv_out",
+        "down_blocks.3": "down_blocks.1",
+        "down_blocks.4": "down_blocks.1.downsamplers.0",
+        "down_blocks.5": "down_blocks.1.conv_out",
+        "down_blocks.6": "down_blocks.2",
+        "down_blocks.7": "down_blocks.2.downsamplers.0",
+        "down_blocks.8": "down_blocks.3",
+        "down_blocks.9": "mid_block",
+        # common
+        "conv_shortcut": "conv_shortcut.conv",
+        "res_blocks": "resnets",
+        "norm3.norm": "norm3",
+        "per_channel_statistics.mean-of-means": "latents_mean",
+        "per_channel_statistics.std-of-means": "latents_std",
+    }
+
+    VAE_SPECIAL_KEYS_REMAP = {
+        "per_channel_statistics.channel": remove_keys_,
+        "per_channel_statistics.mean-of-means": remove_keys_,
+        "per_channel_statistics.mean-of-stds": remove_keys_,
+        "model.diffusion_model": remove_keys_,
+    }
+
+    for key in list(converted_state_dict.keys()):
+        new_key = key
+        for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
+            new_key = new_key.replace(replace_key, rename_key)
+        converted_state_dict[new_key] = converted_state_dict.pop(key)
+
+    for key in list(converted_state_dict.keys()):
+        for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
+            if special_key not in key:
+                continue
+            handler_fn_inplace(key, converted_state_dict)
+
+    return converted_state_dict
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py
index 1afa19091b74..1099a79bb344 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py
@@ -19,6 +19,7 @@
 import torch.nn as nn
 
 from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin
 from ...utils.accelerate_utils import apply_forward_hook
 from ..activations import get_activation
 from ..modeling_outputs import AutoencoderKLOutput
@@ -718,7 +719,7 @@ def create_forward(*inputs):
         return hidden_states
 
 
-class AutoencoderKLLTX(ModelMixin, ConfigMixin):
+class AutoencoderKLLTX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
     r"""
     A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used
     in [LTX](https://huggingface.co/Lightricks/LTX-Video).
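Both converters share the same two-pass shape: a rename pass that pushes every key through the `*_KEYS_RENAME_DICT` substring substitutions, then a cleanup pass that drops any key matched by a `*_SPECIAL_KEYS_REMAP` handler. Below is a minimal, self-contained sketch of the rename pass on a toy state dict; the keys and tensor shapes are illustrative placeholders, not taken from a real LTX checkpoint.

```python
import torch

# Substring substitutions copied from TRANSFORMER_KEYS_RENAME_DICT above.
RENAME = {
    "model.diffusion_model.": "",
    "patchify_proj": "proj_in",
    "adaln_single": "time_embed",
    "q_norm": "norm_q",
    "k_norm": "norm_k",
}

# Toy original-format keys; shapes are placeholders.
checkpoint = {
    "model.diffusion_model.patchify_proj.weight": torch.zeros(4, 4),
    "model.diffusion_model.transformer_blocks.0.attn1.q_norm.weight": torch.zeros(4),
}

converted = {}
for key, value in checkpoint.items():
    new_key = key
    for old, new in RENAME.items():
        new_key = new_key.replace(old, new)  # plain substring replace, applied in dict insertion order
    converted[new_key] = value

print(sorted(converted))
# ['proj_in.weight', 'transformer_blocks.0.attn1.norm_q.weight']
```

Because `str.replace` is a plain substring substitution applied in insertion order, the ordering of `VAE_KEYS_RENAME_DICT` is load-bearing: `"up_blocks.0": "mid_block"` must run before `"up_blocks.1": "up_blocks.0"`, otherwise a key renamed to `up_blocks.0` would be rewritten a second time by a later entry.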
diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py
index 1aca1a246996..b81f2709d1bb 100644
--- a/src/diffusers/models/transformers/transformer_ltx.py
+++ b/src/diffusers/models/transformers/transformer_ltx.py
@@ -21,6 +21,7 @@
 import torch.nn.functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin
 from ...utils import is_torch_version, logging
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import FeedForward
@@ -266,7 +267,7 @@ def forward(
 
 
 @maybe_allow_in_graph
-class LTXTransformer3DModel(ModelMixin, ConfigMixin):
+class LTXTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
     r"""
     A Transformer model for video-like data used in [LTX](https://huggingface.co/Lightricks/LTX-Video).
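Taken together: `FromOriginalModelMixin` exposes `from_single_file` on both classes, the `SINGLE_FILE_LOADABLE_CLASSES` entries route it to the converters above, and `infer_diffusers_model_type` resolves the new `ltx-video` key pair to the `Lightricks/LTX-Video` repo for the default config. A minimal usage sketch under those assumptions; the checkpoint filename below is hypothetical, and `torch_dtype` is the standard `from_single_file` kwarg:

```python
import torch
from diffusers import AutoencoderKLLTX, LTXTransformer3DModel

ckpt_path = "ltx-video.safetensors"  # hypothetical local single-file checkpoint

# Both loads route through from_single_file, which infers the "ltx-video"
# model type from the checkpoint keys and applies the converters above.
transformer = LTXTransformer3DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)
vae = AutoencoderKLLTX.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)
```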