From 752ab0b4942249a8dfa6ff469dcfc7fe9de7d2a9 Mon Sep 17 00:00:00 2001
From: Dhruv Nair
Date: Fri, 20 Dec 2024 13:32:14 +0100
Subject: [PATCH 1/2] update

---
 src/diffusers/loaders/single_file_model.py    |   8 +-
 src/diffusers/loaders/single_file_utils.py    | 135 ++++++++++++++++++
 .../transformers/transformer_hunyuan_video.py |   4 +-
 3 files changed, 145 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py
index d102282025c7..79dc2691b9e4 100644
--- a/src/diffusers/loaders/single_file_model.py
+++ b/src/diffusers/loaders/single_file_model.py
@@ -28,6 +28,7 @@
     convert_autoencoder_dc_checkpoint_to_diffusers,
     convert_controlnet_checkpoint,
     convert_flux_transformer_checkpoint_to_diffusers,
+    convert_hunyuan_video_transformer_to_diffusers,
     convert_ldm_unet_checkpoint,
     convert_ldm_vae_checkpoint,
     convert_ltx_transformer_checkpoint_to_diffusers,
@@ -101,6 +102,10 @@
         "checkpoint_mapping_fn": convert_mochi_transformer_checkpoint_to_diffusers,
         "default_subfolder": "transformer",
     },
+    "HunyuanVideoTransformer3DModel": {
+        "checkpoint_mapping_fn": convert_hunyuan_video_transformer_to_diffusers,
+        "default_subfolder": "transformer",
+    },
 }


@@ -220,6 +225,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
         local_files_only = kwargs.pop("local_files_only", None)
         subfolder = kwargs.pop("subfolder", None)
         revision = kwargs.pop("revision", None)
+        config_revision = kwargs.pop("config_revision", None)
         torch_dtype = kwargs.pop("torch_dtype", None)
         quantization_config = kwargs.pop("quantization_config", None)
         device = kwargs.pop("device", None)
@@ -297,7 +303,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
             subfolder=subfolder,
             local_files_only=local_files_only,
             token=token,
-            revision=revision,
+            revision=config_revision,
         )

         expected_kwargs, optional_kwargs = cls._get_signature_keys(cls)
diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py
index f1408c2c409b..3f474b8b984d 100644
--- a/src/diffusers/loaders/single_file_utils.py
+++ b/src/diffusers/loaders/single_file_utils.py
@@ -108,6 +108,7 @@
     "autoencoder-dc": "decoder.stages.1.op_list.0.main.conv.conv.bias",
     "autoencoder-dc-sana": "encoder.project_in.conv.bias",
     "mochi-1-preview": ["model.diffusion_model.blocks.0.attn.qkv_x.weight", "blocks.0.attn.qkv_x.weight"],
+    "hunyuan-video": "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias",
 }

 DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -162,6 +163,7 @@
     "autoencoder-dc-f32c32": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers"},
     "autoencoder-dc-f32c32-sana": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers"},
     "mochi-1-preview": {"pretrained_model_name_or_path": "genmo/mochi-1-preview"},
+    "hunyuan-video": {"pretrained_model_name_or_path": "tencent/HunyuanVideo"},
 }

 # Use to configure model sample size when original config is provided
@@ -624,6 +626,9 @@ def infer_diffusers_model_type(checkpoint):
     elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["mochi-1-preview"]):
         model_type = "mochi-1-preview"

+    elif CHECKPOINT_KEY_NAMES["hunyuan-video"] in checkpoint:
+        model_type = "hunyuan-video"
+
     else:
         model_type = "v1"

@@ -2522,3 +2527,133 @@ def convert_mochi_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
     new_state_dict["pos_frequencies"] = checkpoint.pop("pos_frequencies")

     return new_state_dict
+
+
+def convert_hunyuan_video_transformer_to_diffusers(checkpoint, **kwargs):
+    def remap_norm_scale_shift_(key, state_dict):
+        weight = state_dict.pop(key)
+        shift, scale = weight.chunk(2, dim=0)
+        new_weight = torch.cat([scale, shift], dim=0)
+        state_dict[key.replace("final_layer.adaLN_modulation.1", "norm_out.linear")] = new_weight
+
+    def remap_txt_in_(key, state_dict):
+        def rename_key(key):
+            new_key = key.replace("individual_token_refiner.blocks", "token_refiner.refiner_blocks")
+            new_key = new_key.replace("adaLN_modulation.1", "norm_out.linear")
+            new_key = new_key.replace("txt_in", "context_embedder")
+            new_key = new_key.replace("t_embedder.mlp.0", "time_text_embed.timestep_embedder.linear_1")
+            new_key = new_key.replace("t_embedder.mlp.2", "time_text_embed.timestep_embedder.linear_2")
+            new_key = new_key.replace("c_embedder", "time_text_embed.text_embedder")
+            new_key = new_key.replace("mlp", "ff")
+            return new_key
+
+        if "self_attn_qkv" in key:
+            weight = state_dict.pop(key)
+            to_q, to_k, to_v = weight.chunk(3, dim=0)
+            state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_q"))] = to_q
+            state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_k"))] = to_k
+            state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_v"))] = to_v
+        else:
+            state_dict[rename_key(key)] = state_dict.pop(key)
+
+    def remap_img_attn_qkv_(key, state_dict):
+        weight = state_dict.pop(key)
+        to_q, to_k, to_v = weight.chunk(3, dim=0)
+        state_dict[key.replace("img_attn_qkv", "attn.to_q")] = to_q
+        state_dict[key.replace("img_attn_qkv", "attn.to_k")] = to_k
+        state_dict[key.replace("img_attn_qkv", "attn.to_v")] = to_v
+
+    def remap_txt_attn_qkv_(key, state_dict):
+        weight = state_dict.pop(key)
+        to_q, to_k, to_v = weight.chunk(3, dim=0)
+        state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = to_q
+        state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = to_k
+        state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = to_v
+
+    def remap_single_transformer_blocks_(key, state_dict):
+        hidden_size = 3072
+
+        if "linear1.weight" in key:
+            linear1_weight = state_dict.pop(key)
+            split_size = (hidden_size, hidden_size, hidden_size, linear1_weight.size(0) - 3 * hidden_size)
+            q, k, v, mlp = torch.split(linear1_weight, split_size, dim=0)
+            new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.weight")
+            state_dict[f"{new_key}.attn.to_q.weight"] = q
+            state_dict[f"{new_key}.attn.to_k.weight"] = k
+            state_dict[f"{new_key}.attn.to_v.weight"] = v
+            state_dict[f"{new_key}.proj_mlp.weight"] = mlp
+
+        elif "linear1.bias" in key:
+            linear1_bias = state_dict.pop(key)
+            split_size = (hidden_size, hidden_size, hidden_size, linear1_bias.size(0) - 3 * hidden_size)
+            q_bias, k_bias, v_bias, mlp_bias = torch.split(linear1_bias, split_size, dim=0)
+            new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.bias")
+            state_dict[f"{new_key}.attn.to_q.bias"] = q_bias
+            state_dict[f"{new_key}.attn.to_k.bias"] = k_bias
+            state_dict[f"{new_key}.attn.to_v.bias"] = v_bias
+            state_dict[f"{new_key}.proj_mlp.bias"] = mlp_bias
+
+        else:
+            new_key = key.replace("single_blocks", "single_transformer_blocks")
+            new_key = new_key.replace("linear2", "proj_out")
+            new_key = new_key.replace("q_norm", "attn.norm_q")
+            new_key = new_key.replace("k_norm", "attn.norm_k")
+            state_dict[new_key] = state_dict.pop(key)
+
+    TRANSFORMER_KEYS_RENAME_DICT = {
+        "img_in": "x_embedder",
+        "time_in.mlp.0": "time_text_embed.timestep_embedder.linear_1",
+        "time_in.mlp.2": "time_text_embed.timestep_embedder.linear_2",
+        "guidance_in.mlp.0": "time_text_embed.guidance_embedder.linear_1",
+        "guidance_in.mlp.2": "time_text_embed.guidance_embedder.linear_2",
+        "vector_in.in_layer": "time_text_embed.text_embedder.linear_1",
+        "vector_in.out_layer": "time_text_embed.text_embedder.linear_2",
+        "double_blocks": "transformer_blocks",
+        "img_attn_q_norm": "attn.norm_q",
+        "img_attn_k_norm": "attn.norm_k",
+        "img_attn_proj": "attn.to_out.0",
+        "txt_attn_q_norm": "attn.norm_added_q",
+        "txt_attn_k_norm": "attn.norm_added_k",
+        "txt_attn_proj": "attn.to_add_out",
+        "img_mod.linear": "norm1.linear",
+        "img_norm1": "norm1.norm",
+        "img_norm2": "norm2",
+        "img_mlp": "ff",
+        "txt_mod.linear": "norm1_context.linear",
+        "txt_norm1": "norm1.norm",
+        "txt_norm2": "norm2_context",
+        "txt_mlp": "ff_context",
+        "self_attn_proj": "attn.to_out.0",
+        "modulation.linear": "norm.linear",
+        "pre_norm": "norm.norm",
+        "final_layer.norm_final": "norm_out.norm",
+        "final_layer.linear": "proj_out",
+        "fc1": "net.0.proj",
+        "fc2": "net.2",
+        "input_embedder": "proj_in",
+    }
+
+    TRANSFORMER_SPECIAL_KEYS_REMAP = {
+        "txt_in": remap_txt_in_,
+        "img_attn_qkv": remap_img_attn_qkv_,
+        "txt_attn_qkv": remap_txt_attn_qkv_,
+        "single_blocks": remap_single_transformer_blocks_,
+        "final_layer.adaLN_modulation.1": remap_norm_scale_shift_,
+    }
+
+    def update_state_dict_(state_dict, old_key, new_key):
+        state_dict[new_key] = state_dict.pop(old_key)
+
+    for key in list(checkpoint.keys()):
+        new_key = key[:]
+        for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
+            new_key = new_key.replace(replace_key, rename_key)
+        update_state_dict_(checkpoint, key, new_key)
+
+    for key in list(checkpoint.keys()):
+        for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
+            if special_key not in key:
+                continue
+            handler_fn_inplace(key, checkpoint)
+
+    return checkpoint
diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py
index 089389b5f9ad..e3f24d97f3fa 100644
--- a/src/diffusers/models/transformers/transformer_hunyuan_video.py
+++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py
@@ -18,6 +18,8 @@
 import torch.nn as nn
 import torch.nn.functional as F

+from diffusers.loaders import FromOriginalModelMixin
+
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin
 from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
@@ -500,7 +502,7 @@ def forward(
         return hidden_states, encoder_hidden_states


-class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
+class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
     r"""
     A Transformer model for video-like data used in [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo).

From 898d7e57c0b0f098a0acb7c912a6c3264b158529 Mon Sep 17 00:00:00 2001
From: Dhruv Nair
Date: Mon, 23 Dec 2024 08:11:05 +0530
Subject: [PATCH 2/2] Update src/diffusers/loaders/single_file_utils.py

Co-authored-by: Aryan
---
 src/diffusers/loaders/single_file_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py
index 3f474b8b984d..5933c634f4cc 100644
--- a/src/diffusers/loaders/single_file_utils.py
+++ b/src/diffusers/loaders/single_file_utils.py
@@ -163,7 +163,7 @@
     "autoencoder-dc-f32c32": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers"},
     "autoencoder-dc-f32c32-sana": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers"},
     "mochi-1-preview": {"pretrained_model_name_or_path": "genmo/mochi-1-preview"},
-    "hunyuan-video": {"pretrained_model_name_or_path": "tencent/HunyuanVideo"},
+    "hunyuan-video": {"pretrained_model_name_or_path": "hunyuanvideo-community/HunyuanVideo"},
 }

 # Use to configure model sample size when original config is provided
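
Note: below is a minimal usage sketch of what this series enables, i.e. loading the
HunyuanVideo transformer from an original single-file checkpoint. It is not part of the
patch. The checkpoint filename and dtype are illustrative assumptions; only the
HunyuanVideoTransformer3DModel.from_single_file entry point and the new config_revision
kwarg come from the diffs above.

    import torch
    from diffusers import HunyuanVideoTransformer3DModel

    # Placeholder path to an original-format HunyuanVideo transformer checkpoint
    # (assumption for illustration only).
    ckpt_path = "hunyuan_video_720_cfgdistill_bf16.safetensors"

    transformer = HunyuanVideoTransformer3DModel.from_single_file(
        ckpt_path,
        torch_dtype=torch.bfloat16,
        # config_revision (added in PATCH 1/2) selects the revision of the
        # diffusers-format config repo used to instantiate the model.
        config_revision="main",
    )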