From 428db9f9a4384735c020b0e4a549acb9c8e80208 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 5 Dec 2024 23:03:05 +0100 Subject: [PATCH 1/3] from original file mixin for ltx --- src/diffusers/loaders/single_file_model.py | 12 ++- src/diffusers/loaders/single_file_utils.py | 93 +++++++++++++++++++ .../models/autoencoders/autoencoder_kl_ltx.py | 3 +- .../models/transformers/transformer_ltx.py | 3 +- 4 files changed, 108 insertions(+), 3 deletions(-) diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py index be3139057078..9a0c36b80853 100644 --- a/src/diffusers/loaders/single_file_model.py +++ b/src/diffusers/loaders/single_file_model.py @@ -27,6 +27,8 @@ convert_flux_transformer_checkpoint_to_diffusers, convert_ldm_unet_checkpoint, convert_ldm_vae_checkpoint, + convert_ltx_transformer_checkpoint_to_diffusers, + convert_ltx_vae_checkpoint_to_diffusers, convert_sd3_transformer_checkpoint_to_diffusers, convert_stable_cascade_unet_single_file_to_diffusers, create_controlnet_diffusers_config_from_ldm, @@ -82,6 +84,14 @@ "checkpoint_mapping_fn": convert_flux_transformer_checkpoint_to_diffusers, "default_subfolder": "transformer", }, + "LTXTransformer3DModel": { + "checkpoint_mapping_fn": convert_ltx_transformer_checkpoint_to_diffusers, + "config_mapping_fn": lambda *_, **__: {}, + }, + "AutoencoderKLLTX": { + "checkpoint_mapping_fn": convert_ltx_vae_checkpoint_to_diffusers, + "config_mapping_fn": lambda *_, **__: {}, + }, } @@ -219,7 +229,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = mapping_functions = SINGLE_FILE_LOADABLE_CLASSES[mapping_class_name] checkpoint_mapping_fn = mapping_functions["checkpoint_mapping_fn"] - if original_config: + if original_config is not None: if "config_mapping_fn" in mapping_functions: config_mapping_fn = mapping_functions["config_mapping_fn"] else: diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index 10742873ded1..d7906a77806a 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -92,6 +92,9 @@ "double_blocks.0.img_attn.norm.key_norm.scale", "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale", ], + "ltx-video": [ + ("patchify_proj.weight", "transformer_blocks.27.scale_shift_table"), + ], } DIFFUSERS_DEFAULT_PIPELINE_PATHS = { @@ -138,6 +141,7 @@ "animatediff_rgb": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-rgb"}, "flux-dev": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev"}, "flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"}, + "ltx-video": {"pretrained_model_name_or_path": "Lightricks/LTX-Video"}, } # Use to configure model sample size when original config is provided @@ -564,6 +568,10 @@ def infer_diffusers_model_type(checkpoint): model_type = "flux-dev" else: model_type = "flux-schnell" + + elif any(all(key in checkpoint for key in key_list) for key_list in CHECKPOINT_KEY_NAMES["ltx-video"]): + model_type = "ltx-video" + else: model_type = "v1" @@ -2198,3 +2206,88 @@ def swap_scale_shift(weight): ) return converted_state_dict + + +def convert_ltx_transformer_checkpoint_to_diffusers(checkpoint, **kwargs): + converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())} + + TRANSFORMER_KEYS_RENAME_DICT = { + "patchify_proj": "proj_in", + "adaln_single": "time_embed", + "q_norm": "norm_q", + "k_norm": "norm_k", + } + + TRANSFORMER_SPECIAL_KEYS_REMAP = {} + + for key in list(converted_state_dict.keys()): + new_key = key + for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items(): + new_key = new_key.replace(replace_key, rename_key) + converted_state_dict[new_key] = converted_state_dict.pop(key) + + for key in list(converted_state_dict.keys()): + for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items(): + if special_key not in key: + continue + handler_fn_inplace(key, converted_state_dict) + + return converted_state_dict + + +def convert_ltx_vae_checkpoint_to_diffusers(checkpoint, **kwargs): + converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())} + + def remove_keys_(key: str, state_dict): + state_dict.pop(key) + + VAE_KEYS_RENAME_DICT = { + # decoder + "up_blocks.0": "mid_block", + "up_blocks.1": "up_blocks.0", + "up_blocks.2": "up_blocks.1.upsamplers.0", + "up_blocks.3": "up_blocks.1", + "up_blocks.4": "up_blocks.2.conv_in", + "up_blocks.5": "up_blocks.2.upsamplers.0", + "up_blocks.6": "up_blocks.2", + "up_blocks.7": "up_blocks.3.conv_in", + "up_blocks.8": "up_blocks.3.upsamplers.0", + "up_blocks.9": "up_blocks.3", + # encoder + "down_blocks.0": "down_blocks.0", + "down_blocks.1": "down_blocks.0.downsamplers.0", + "down_blocks.2": "down_blocks.0.conv_out", + "down_blocks.3": "down_blocks.1", + "down_blocks.4": "down_blocks.1.downsamplers.0", + "down_blocks.5": "down_blocks.1.conv_out", + "down_blocks.6": "down_blocks.2", + "down_blocks.7": "down_blocks.2.downsamplers.0", + "down_blocks.8": "down_blocks.3", + "down_blocks.9": "mid_block", + # common + "conv_shortcut": "conv_shortcut.conv", + "res_blocks": "resnets", + "norm3.norm": "norm3", + "per_channel_statistics.mean-of-means": "latents_mean", + "per_channel_statistics.std-of-means": "latents_std", + } + + VAE_SPECIAL_KEYS_REMAP = { + "per_channel_statistics.channel": remove_keys_, + "per_channel_statistics.mean-of-means": remove_keys_, + "per_channel_statistics.mean-of-stds": remove_keys_, + } + + for key in list(converted_state_dict.keys()): + new_key = key + for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items(): + new_key = new_key.replace(replace_key, rename_key) + converted_state_dict[new_key] = converted_state_dict.pop(key) + + for key in list(converted_state_dict.keys()): + for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items(): + if special_key not in key: + continue + handler_fn_inplace(key, converted_state_dict) + + return converted_state_dict diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py index 1afa19091b74..1099a79bb344 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py @@ -19,6 +19,7 @@ import torch.nn as nn from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import FromOriginalModelMixin from ...utils.accelerate_utils import apply_forward_hook from ..activations import get_activation from ..modeling_outputs import AutoencoderKLOutput @@ -718,7 +719,7 @@ def create_forward(*inputs): return hidden_states -class AutoencoderKLLTX(ModelMixin, ConfigMixin): +class AutoencoderKLLTX(ModelMixin, ConfigMixin, FromOriginalModelMixin): r""" A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in [LTX](https://huggingface.co/Lightricks/LTX-Video). diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py index 1aca1a246996..b81f2709d1bb 100644 --- a/src/diffusers/models/transformers/transformer_ltx.py +++ b/src/diffusers/models/transformers/transformer_ltx.py @@ -21,6 +21,7 @@ import torch.nn.functional as F from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import FromOriginalModelMixin from ...utils import is_torch_version, logging from ...utils.torch_utils import maybe_allow_in_graph from ..attention import FeedForward @@ -266,7 +267,7 @@ def forward( @maybe_allow_in_graph -class LTXTransformer3DModel(ModelMixin, ConfigMixin): +class LTXTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin): r""" A Transformer model for video-like data used in [LTX](https://huggingface.co/Lightricks/LTX-Video). From f09f51c55a02cc98492e4ec20864365a651c3195 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 10 Dec 2024 08:01:27 +0100 Subject: [PATCH 2/3] undo config mapping fn changes --- src/diffusers/loaders/single_file_model.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py index 9a0c36b80853..8290d1771a19 100644 --- a/src/diffusers/loaders/single_file_model.py +++ b/src/diffusers/loaders/single_file_model.py @@ -86,11 +86,9 @@ }, "LTXTransformer3DModel": { "checkpoint_mapping_fn": convert_ltx_transformer_checkpoint_to_diffusers, - "config_mapping_fn": lambda *_, **__: {}, }, "AutoencoderKLLTX": { "checkpoint_mapping_fn": convert_ltx_vae_checkpoint_to_diffusers, - "config_mapping_fn": lambda *_, **__: {}, }, } @@ -229,7 +227,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = mapping_functions = SINGLE_FILE_LOADABLE_CLASSES[mapping_class_name] checkpoint_mapping_fn = mapping_functions["checkpoint_mapping_fn"] - if original_config is not None: + if original_config: if "config_mapping_fn" in mapping_functions: config_mapping_fn = mapping_functions["config_mapping_fn"] else: From ca4b38c1b7ece44650db156d5746288a9994a05b Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 10 Dec 2024 08:22:05 +0100 Subject: [PATCH 3/3] update --- src/diffusers/loaders/single_file_model.py | 3 +++ src/diffusers/loaders/single_file_utils.py | 16 ++++++++++++++-- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py index 8290d1771a19..11ade76f7f70 100644 --- a/src/diffusers/loaders/single_file_model.py +++ b/src/diffusers/loaders/single_file_model.py @@ -86,9 +86,11 @@ }, "LTXTransformer3DModel": { "checkpoint_mapping_fn": convert_ltx_transformer_checkpoint_to_diffusers, + "default_subfolder": "transformer", }, "AutoencoderKLLTX": { "checkpoint_mapping_fn": convert_ltx_vae_checkpoint_to_diffusers, + "default_subfolder": "vae", }, } @@ -278,6 +280,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = subfolder=subfolder, local_files_only=local_files_only, token=token, + revision=revision, ) expected_kwargs, optional_kwargs = cls._get_signature_keys(cls) diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index d7906a77806a..c6006c959f03 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -93,7 +93,10 @@ "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale", ], "ltx-video": [ - ("patchify_proj.weight", "transformer_blocks.27.scale_shift_table"), + ( + "model.diffusion_model.patchify_proj.weight", + "model.diffusion_model.transformer_blocks.27.scale_shift_table", + ), ], } @@ -2211,14 +2214,20 @@ def swap_scale_shift(weight): def convert_ltx_transformer_checkpoint_to_diffusers(checkpoint, **kwargs): converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())} + def remove_keys_(key: str, state_dict): + state_dict.pop(key) + TRANSFORMER_KEYS_RENAME_DICT = { + "model.diffusion_model.": "", "patchify_proj": "proj_in", "adaln_single": "time_embed", "q_norm": "norm_q", "k_norm": "norm_k", } - TRANSFORMER_SPECIAL_KEYS_REMAP = {} + TRANSFORMER_SPECIAL_KEYS_REMAP = { + "vae": remove_keys_, + } for key in list(converted_state_dict.keys()): new_key = key @@ -2242,6 +2251,8 @@ def remove_keys_(key: str, state_dict): state_dict.pop(key) VAE_KEYS_RENAME_DICT = { + # common + "vae.": "", # decoder "up_blocks.0": "mid_block", "up_blocks.1": "up_blocks.0", @@ -2276,6 +2287,7 @@ def remove_keys_(key: str, state_dict): "per_channel_statistics.channel": remove_keys_, "per_channel_statistics.mean-of-means": remove_keys_, "per_channel_statistics.mean-of-stds": remove_keys_, + "model.diffusion_model": remove_keys_, } for key in list(converted_state_dict.keys()):