Commit 5029dbf

style
1 parent 3980f97

File tree: 11 files changed, +233 −260 lines changed

scripts/convert_hunyuan_video1_5_to_diffusers.py

Lines changed: 87 additions & 82 deletions
@@ -1,3 +1,30 @@
+import argparse
+import json
+import os
+import pathlib
+
+import torch
+from accelerate import init_empty_weights
+from huggingface_hub import hf_hub_download, snapshot_download
+from safetensors.torch import load_file
+from transformers import (
+    AutoModel,
+    AutoTokenizer,
+    SiglipImageProcessor,
+    SiglipVisionModel,
+    T5EncoderModel,
+)
+
+from diffusers import (
+    AutoencoderKLHunyuanVideo15,
+    ClassifierFreeGuidance,
+    FlowMatchEulerDiscreteScheduler,
+    HunyuanVideo15ImageToVideoPipeline,
+    HunyuanVideo15Pipeline,
+    HunyuanVideo15Transformer3DModel,
+)
+
+
 # to convert only transformer
 """
 python scripts/convert_hunyuan_video1_5_to_diffusers.py \
@@ -16,21 +43,6 @@
     --transformer_type 480p_t2v
 """

-import argparse
-from typing import Any, Dict
-
-import torch
-from accelerate import init_empty_weights
-from safetensors.torch import load_file
-from huggingface_hub import snapshot_download, hf_hub_download
-
-import pathlib
-from diffusers import HunyuanVideo15Transformer3DModel, AutoencoderKLHunyuanVideo15, FlowMatchEulerDiscreteScheduler, ClassifierFreeGuidance, HunyuanVideo15Pipeline, HunyuanVideo15ImageToVideoPipeline
-from transformers import AutoModel, AutoTokenizer, T5EncoderModel, ByT5Tokenizer, SiglipVisionModel, SiglipImageProcessor
-
-import json
-import argparse
-import os

 TRANSFORMER_CONFIGS = {
     "480p_t2v": {
@@ -107,6 +119,7 @@
     },
 }

+
 def swap_scale_shift(weight):
     shift, scale = weight.chunk(2, dim=0)
     new_weight = torch.cat([scale, shift], dim=0)
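
For intuition (a minimal sketch, not part of the commit): swap_scale_shift splits a packed adaLN modulation parameter into its shift and scale halves and re-concatenates them in the opposite order, mapping the original checkpoint layout onto the diffusers convention.

import torch

def swap_scale_shift(weight):
    # Split the packed [shift, scale] parameter and repack it as [scale, shift].
    shift, scale = weight.chunk(2, dim=0)
    return torch.cat([scale, shift], dim=0)

# Toy tensor packed as [shift (2 values), scale (2 values)].
w = torch.tensor([0.0, 1.0, 2.0, 3.0])
print(swap_scale_shift(w))  # tensor([2., 3., 0., 1.])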
@@ -123,48 +136,42 @@ def convert_hyvideo15_transformer_to_diffusers(original_state_dict):
     converted_state_dict["time_embed.timestep_embedder.linear_1.weight"] = original_state_dict.pop(
         "time_in.mlp.0.weight"
     )
-    converted_state_dict["time_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop(
-        "time_in.mlp.0.bias"
-    )
+    converted_state_dict["time_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop("time_in.mlp.0.bias")
     converted_state_dict["time_embed.timestep_embedder.linear_2.weight"] = original_state_dict.pop(
         "time_in.mlp.2.weight"
     )
-    converted_state_dict["time_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop(
-        "time_in.mlp.2.bias"
-    )
+    converted_state_dict["time_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop("time_in.mlp.2.bias")

     # 2. context_embedder.time_text_embed.timestep_embedder <- txt_in.t_embedder
     converted_state_dict["context_embedder.time_text_embed.timestep_embedder.linear_1.weight"] = (
         original_state_dict.pop("txt_in.t_embedder.mlp.0.weight")
     )
-    converted_state_dict["context_embedder.time_text_embed.timestep_embedder.linear_1.bias"] = (
-        original_state_dict.pop("txt_in.t_embedder.mlp.0.bias")
+    converted_state_dict["context_embedder.time_text_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop(
+        "txt_in.t_embedder.mlp.0.bias"
     )
     converted_state_dict["context_embedder.time_text_embed.timestep_embedder.linear_2.weight"] = (
         original_state_dict.pop("txt_in.t_embedder.mlp.2.weight")
     )
-    converted_state_dict["context_embedder.time_text_embed.timestep_embedder.linear_2.bias"] = (
-        original_state_dict.pop("txt_in.t_embedder.mlp.2.bias")
+    converted_state_dict["context_embedder.time_text_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop(
+        "txt_in.t_embedder.mlp.2.bias"
     )

     # 3. context_embedder.time_text_embed.text_embedder <- txt_in.c_embedder
-    converted_state_dict["context_embedder.time_text_embed.text_embedder.linear_1.weight"] = (
-        original_state_dict.pop("txt_in.c_embedder.linear_1.weight")
+    converted_state_dict["context_embedder.time_text_embed.text_embedder.linear_1.weight"] = original_state_dict.pop(
+        "txt_in.c_embedder.linear_1.weight"
     )
-    converted_state_dict["context_embedder.time_text_embed.text_embedder.linear_1.bias"] = (
-        original_state_dict.pop("txt_in.c_embedder.linear_1.bias")
+    converted_state_dict["context_embedder.time_text_embed.text_embedder.linear_1.bias"] = original_state_dict.pop(
+        "txt_in.c_embedder.linear_1.bias"
    )
-    converted_state_dict["context_embedder.time_text_embed.text_embedder.linear_2.weight"] = (
-        original_state_dict.pop("txt_in.c_embedder.linear_2.weight")
+    converted_state_dict["context_embedder.time_text_embed.text_embedder.linear_2.weight"] = original_state_dict.pop(
+        "txt_in.c_embedder.linear_2.weight"
     )
-    converted_state_dict["context_embedder.time_text_embed.text_embedder.linear_2.bias"] = (
-        original_state_dict.pop("txt_in.c_embedder.linear_2.bias")
+    converted_state_dict["context_embedder.time_text_embed.text_embedder.linear_2.bias"] = original_state_dict.pop(
+        "txt_in.c_embedder.linear_2.bias"
     )

     # 4. context_embedder.proj_in <- txt_in.input_embedder
-    converted_state_dict["context_embedder.proj_in.weight"] = original_state_dict.pop(
-        "txt_in.input_embedder.weight"
-    )
+    converted_state_dict["context_embedder.proj_in.weight"] = original_state_dict.pop("txt_in.input_embedder.weight")
     converted_state_dict["context_embedder.proj_in.bias"] = original_state_dict.pop("txt_in.input_embedder.bias")

     # 5. context_embedder.token_refiner <- txt_in.individual_token_refiner
@@ -375,10 +382,12 @@ def convert_hyvideo15_transformer_to_diffusers(original_state_dict):
     )

     # 11. norm_out and proj_out <- final_layer
-    converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(original_state_dict.pop(
-        "final_layer.adaLN_modulation.1.weight"
-    ))
-    converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(original_state_dict.pop("final_layer.adaLN_modulation.1.bias"))
+    converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(
+        original_state_dict.pop("final_layer.adaLN_modulation.1.weight")
+    )
+    converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(
+        original_state_dict.pop("final_layer.adaLN_modulation.1.bias")
+    )
     converted_state_dict["proj_out.weight"] = original_state_dict.pop("final_layer.linear.weight")
     converted_state_dict["proj_out.bias"] = original_state_dict.pop("final_layer.linear.bias")

@@ -572,6 +581,7 @@ def convert_hunyuan_video_15_vae_checkpoint_to_diffusers(

     return converted

+
 def load_sharded_safetensors(dir: pathlib.Path):
     file_paths = list(dir.glob("diffusion_pytorch_model*.safetensors"))
     state_dict = {}
@@ -583,9 +593,9 @@ def load_sharded_safetensors(dir: pathlib.Path):
 def load_original_transformer_state_dict(args):
     if args.original_state_dict_repo_id is not None:
         model_dir = snapshot_download(
-            args.original_state_dict_repo_id, 
+            args.original_state_dict_repo_id,
             repo_type="model",
-            allow_patterns="transformer/" + args.transformer_type + "/*"
+            allow_patterns="transformer/" + args.transformer_type + "/*",
         )
     elif args.original_state_dict_folder is not None:
         model_dir = pathlib.Path(args.original_state_dict_folder)
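
For reference, a small sketch of the download behavior this call relies on: snapshot_download with allow_patterns fetches only the matching subfolder of the repo. The repo id below is a placeholder, not taken from this commit.

from huggingface_hub import snapshot_download

repo_id = "some-org/hunyuan-video-1.5-original"  # placeholder; the script reads this from --original_state_dict_repo_id
transformer_type = "480p_t2v"

# Downloads only files under transformer/480p_t2v/ and returns the local snapshot directory.
model_dir = snapshot_download(
    repo_id,
    repo_type="model",
    allow_patterns="transformer/" + transformer_type + "/*",
)
print(model_dir)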
@@ -599,8 +609,7 @@ def load_original_transformer_state_dict(args):
 def load_original_vae_state_dict(args):
     if args.original_state_dict_repo_id is not None:
         ckpt_path = hf_hub_download(
-            repo_id=args.original_state_dict_repo_id,
-            filename= "vae/diffusion_pytorch_model.safetensors"
+            repo_id=args.original_state_dict_repo_id, filename="vae/diffusion_pytorch_model.safetensors"
         )
     elif args.original_state_dict_folder is not None:
         model_dir = pathlib.Path(args.original_state_dict_folder)
@@ -632,24 +641,27 @@ def convert_vae(args):
     vae.load_state_dict(state_dict, strict=True, assign=True)
     return vae

+
 def load_mllm():
-    print(f" loading from Qwen/Qwen2.5-VL-7B-Instruct")
-    text_encoder = AutoModel.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", torch_dtype=torch.bfloat16,low_cpu_mem_usage=True)
-    if hasattr(text_encoder, 'language_model'):
+    print(" loading from Qwen/Qwen2.5-VL-7B-Instruct")
+    text_encoder = AutoModel.from_pretrained(
+        "Qwen/Qwen2.5-VL-7B-Instruct", torch_dtype=torch.bfloat16, low_cpu_mem_usage=True
+    )
+    if hasattr(text_encoder, "language_model"):
         text_encoder = text_encoder.language_model
     tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", padding_side="right")
     return text_encoder, tokenizer


-#copied from https://github.com/Tencent-Hunyuan/HunyuanVideo-1.5/blob/910da2a829c484ea28982e8cff3bbc2cacdf1681/hyvideo/models/text_encoders/byT5/__init__.py#L89
+# copied from https://github.com/Tencent-Hunyuan/HunyuanVideo-1.5/blob/910da2a829c484ea28982e8cff3bbc2cacdf1681/hyvideo/models/text_encoders/byT5/__init__.py#L89
 def add_special_token(
     tokenizer,
     text_encoder,
     add_color=True,
     add_font=True,
     multilingual=True,
-    color_ann_path='assets/color_idx.json',
-    font_ann_path='assets/multilingual_10-lang_idx.json',
+    color_ann_path="assets/color_idx.json",
+    font_ann_path="assets/multilingual_10-lang_idx.json",
 ):
     """
     Add special tokens for color and font to tokenizer and text encoder.
@@ -663,16 +675,16 @@ def add_special_token(
         font_ann_path (str): Path to font annotation JSON.
         multilingual (bool): Whether to use multilingual font tokens.
     """
-    with open(font_ann_path, 'r') as f:
+    with open(font_ann_path, "r") as f:
         idx_font_dict = json.load(f)
-    with open(color_ann_path, 'r') as f:
+    with open(color_ann_path, "r") as f:
         idx_color_dict = json.load(f)

     if multilingual:
-        font_token = [f'<{font_code[:2]}-font-{idx_font_dict[font_code]}>' for font_code in idx_font_dict]
+        font_token = [f"<{font_code[:2]}-font-{idx_font_dict[font_code]}>" for font_code in idx_font_dict]
     else:
-        font_token = [f'<font-{i}>' for i in range(len(idx_font_dict))]
-    color_token = [f'<color-{i}>' for i in range(len(idx_color_dict))]
+        font_token = [f"<font-{i}>" for i in range(len(idx_font_dict))]
+    color_token = [f"<color-{i}>" for i in range(len(idx_color_dict))]
     additional_special_tokens = []
     if add_color:
         additional_special_tokens += color_token
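
To make the token naming concrete, a short sketch with made-up annotation dicts (the real ones come from the Glyph-SDXL-v2 JSON files) showing the strings the two branches above produce.

idx_font_dict = {"en-font-a": 0, "zh-font-b": 1}    # hypothetical font annotation
idx_color_dict = {"red": 0, "blue": 1, "green": 2}  # hypothetical color annotation

multilingual = True
if multilingual:
    font_token = [f"<{font_code[:2]}-font-{idx_font_dict[font_code]}>" for font_code in idx_font_dict]
else:
    font_token = [f"<font-{i}>" for i in range(len(idx_font_dict))]
color_token = [f"<color-{i}>" for i in range(len(idx_color_dict))]

print(font_token)   # ['<en-font-0>', '<zh-font-1>']
print(color_token)  # ['<color-0>', '<color-1>', '<color-2>']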
@@ -688,14 +700,13 @@ def load_byt5(args):
     """
     Load ByT5 encoder with Glyph-SDXL-v2 weights and save in HuggingFace format.
     """
-

     # 1. Load base tokenizer and encoder
     tokenizer = AutoTokenizer.from_pretrained("google/byt5-small")
-    
+
     # Load as T5EncoderModel
     encoder = T5EncoderModel.from_pretrained("google/byt5-small")
-    
+
     byt5_checkpoint_path = os.path.join(args.byt5_path, "checkpoints/byt5_model.pt")
     color_ann_path = os.path.join(args.byt5_path, "assets/color_idx.json")
     font_ann_path = os.path.join(args.byt5_path, "assets/multilingual_10-lang_idx.json")
@@ -710,48 +721,45 @@ def load_byt5(args):
         font_ann_path=font_ann_path,
         multilingual=True,
     )
-
-
+
     # 3. Load Glyph-SDXL-v2 checkpoint
     print(f"\n3. Loading Glyph-SDXL-v2 checkpoint: {byt5_checkpoint_path}")
-    checkpoint = torch.load(byt5_checkpoint_path, map_location='cpu')
-
+    checkpoint = torch.load(byt5_checkpoint_path, map_location="cpu")
+
     # Handle different checkpoint formats
-    if 'state_dict' in checkpoint:
-        state_dict = checkpoint['state_dict']
+    if "state_dict" in checkpoint:
+        state_dict = checkpoint["state_dict"]
     else:
         state_dict = checkpoint
-
-    # add 'encoder.' prefix to the keys 
+
+    # add 'encoder.' prefix to the keys
     # Remove 'module.text_tower.encoder.' prefix if present
     cleaned_state_dict = {}
     for key, value in state_dict.items():
-        if key.startswith('module.text_tower.encoder.'):
-            new_key = 'encoder.' + key[len('module.text_tower.encoder.'):]
+        if key.startswith("module.text_tower.encoder."):
+            new_key = "encoder." + key[len("module.text_tower.encoder.") :]
             cleaned_state_dict[new_key] = value
         else:
-            new_key = 'encoder.' + key
+            new_key = "encoder." + key
             cleaned_state_dict[new_key] = value
-
-
+
     # 4. Load weights
     missing_keys, unexpected_keys = encoder.load_state_dict(cleaned_state_dict, strict=False)
     if unexpected_keys:
         raise ValueError(f"Unexpected keys: {unexpected_keys}")
     if "shared.weight" in missing_keys:
-        print(f" Missing shared.weight as expected")
+        print(" Missing shared.weight as expected")
         missing_keys.remove("shared.weight")
     if missing_keys:
         raise ValueError(f"Missing keys: {missing_keys}")
-
-
+
     return encoder, tokenizer


 def load_siglip():
     image_encoder = SiglipVisionModel.from_pretrained(
         "black-forest-labs/FLUX.1-Redux-dev", subfolder="image_encoder", torch_dtype=torch.bfloat16
-    )    
+    )
     feature_extractor = SiglipImageProcessor.from_pretrained(
         "black-forest-labs/FLUX.1-Redux-dev", subfolder="feature_extractor"
     )
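
For clarity, a minimal sketch (with dummy keys) of the prefix rewrite the cleaning loop above performs on the Glyph-SDXL-v2 checkpoint before it is loaded into the ByT5 encoder.

# Dummy keys; the real checkpoint holds ByT5 encoder weights.
state_dict = {
    "module.text_tower.encoder.block.0.layer.0.SelfAttention.q.weight": 1,
    "final_layer_norm.weight": 2,
}

cleaned_state_dict = {}
for key, value in state_dict.items():
    if key.startswith("module.text_tower.encoder."):
        # Drop the training-wrapper prefix but keep the T5 "encoder." namespace.
        new_key = "encoder." + key[len("module.text_tower.encoder.") :]
    else:
        new_key = "encoder." + key
    cleaned_state_dict[new_key] = value

print(sorted(cleaned_state_dict))
# ['encoder.block.0.layer.0.SelfAttention.q.weight', 'encoder.final_layer_norm.weight']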
@@ -763,11 +771,11 @@ def get_args():
     parser.add_argument(
         "--original_state_dict_repo_id", type=str, default=None, help="Path to original hub_id for the model"
     )
-    parser.add_argument("--original_state_dict_folder", type=str, default=None, help="Local folder name of the original state dict")
-    parser.add_argument("--output_path", type=str, required=True, help="Path where converted model(s) should be saved")
     parser.add_argument(
-        "--transformer_type", type=str, default="480p_i2v", choices=list(TRANSFORMER_CONFIGS.keys())
+        "--original_state_dict_folder", type=str, default=None, help="Local folder name of the original state dict"
     )
+    parser.add_argument("--output_path", type=str, required=True, help="Path where converted model(s) should be saved")
+    parser.add_argument("--transformer_type", type=str, default="480p_i2v", choices=list(TRANSFORMER_CONFIGS.keys()))
     parser.add_argument(
         "--byt5_path",
         type=str,
@@ -826,7 +834,7 @@ def get_args():
             feature_extractor=feature_extractor,
         )
     elif task_type == "t2v":
-        pipeline = HunyuanVideo15Text2VideoPipeline(
+        pipeline = HunyuanVideo15Pipeline(
             vae=vae,
             text_encoder=text_encoder,
             text_encoder_2=text_encoder_2,
@@ -840,6 +848,3 @@ def get_args():
         raise ValueError(f"Task type {task_type} is not supported")

     pipeline.save_pretrained(args.output_path, safe_serialization=True)
-
-
-

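Putting it together, a hedged sketch of loading the converted output (the output path is a placeholder; from_pretrained and save_pretrained are the standard diffusers flow, not something this commit changes):

import torch
from diffusers import HunyuanVideo15Pipeline

# Assumes the script above was run with, e.g., --transformer_type 480p_t2v --output_path ./hyv15-diffusers
pipe = HunyuanVideo15Pipeline.from_pretrained("./hyv15-diffusers", torch_dtype=torch.bfloat16)
pipe.to("cuda")
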
src/diffusers/__init__.py

Lines changed: 5 additions & 5 deletions
@@ -483,11 +483,11 @@
         "HunyuanImagePipeline",
         "HunyuanImageRefinerPipeline",
         "HunyuanSkyreelsImageToVideoPipeline",
+        "HunyuanVideo15ImageToVideoPipeline",
+        "HunyuanVideo15Pipeline",
         "HunyuanVideoFramepackPipeline",
         "HunyuanVideoImageToVideoPipeline",
         "HunyuanVideoPipeline",
-        "HunyuanVideo15Pipeline",
-        "HunyuanVideo15ImageToVideoPipeline",
         "I2VGenXLPipeline",
         "IFImg2ImgPipeline",
         "IFImg2ImgSuperResolutionPipeline",
@@ -949,9 +949,9 @@
         HunyuanDiT2DModel,
         HunyuanDiT2DMultiControlNetModel,
         HunyuanImageTransformer2DModel,
+        HunyuanVideo15Transformer3DModel,
         HunyuanVideoFramepackTransformer3DModel,
         HunyuanVideoTransformer3DModel,
-        HunyuanVideo15Transformer3DModel,
         I2VGenXLUNet,
         Kandinsky3UNet,
         Kandinsky5Transformer3DModel,
@@ -1176,11 +1176,11 @@
         HunyuanImagePipeline,
         HunyuanImageRefinerPipeline,
         HunyuanSkyreelsImageToVideoPipeline,
+        HunyuanVideo15ImageToVideoPipeline,
+        HunyuanVideo15Pipeline,
         HunyuanVideoFramepackPipeline,
         HunyuanVideoImageToVideoPipeline,
         HunyuanVideoPipeline,
-        HunyuanVideo15Pipeline,
-        HunyuanVideo15ImageToVideoPipeline,
         I2VGenXLPipeline,
         IFImg2ImgPipeline,
         IFImg2ImgSuperResolutionPipeline,
