+import argparse
+import json
+import os
+import pathlib
+
+import torch
+from accelerate import init_empty_weights
+from huggingface_hub import hf_hub_download, snapshot_download
+from safetensors.torch import load_file
+from transformers import (
+    AutoModel,
+    AutoTokenizer,
+    SiglipImageProcessor,
+    SiglipVisionModel,
+    T5EncoderModel,
+)
+
+from diffusers import (
+    AutoencoderKLHunyuanVideo15,
+    ClassifierFreeGuidance,
+    FlowMatchEulerDiscreteScheduler,
+    HunyuanVideo15ImageToVideoPipeline,
+    HunyuanVideo15Pipeline,
+    HunyuanVideo15Transformer3DModel,
+)
+
+
 # to convert only transformer
 """
 python scripts/convert_hunyuan_video1_5_to_diffusers.py \
@@ -16,21 +43,6 @@
     --transformer_type 480p_t2v
 """

-import argparse
-from typing import Any, Dict
-
-import torch
-from accelerate import init_empty_weights
-from safetensors.torch import load_file
-from huggingface_hub import snapshot_download, hf_hub_download
-
-import pathlib
-from diffusers import HunyuanVideo15Transformer3DModel, AutoencoderKLHunyuanVideo15, FlowMatchEulerDiscreteScheduler, ClassifierFreeGuidance, HunyuanVideo15Pipeline, HunyuanVideo15ImageToVideoPipeline
-from transformers import AutoModel, AutoTokenizer, T5EncoderModel, ByT5Tokenizer, SiglipVisionModel, SiglipImageProcessor
-
-import json
-import argparse
-import os

 TRANSFORMER_CONFIGS = {
     "480p_t2v": {
@@ -107,6 +119,7 @@
     },
 }

+
 def swap_scale_shift(weight):
     shift, scale = weight.chunk(2, dim=0)
     new_weight = torch.cat([scale, shift], dim=0)
@@ -123,48 +136,42 @@ def convert_hyvideo15_transformer_to_diffusers(original_state_dict):
     converted_state_dict["time_embed.timestep_embedder.linear_1.weight"] = original_state_dict.pop(
         "time_in.mlp.0.weight"
     )
-    converted_state_dict["time_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop(
-        "time_in.mlp.0.bias"
-    )
+    converted_state_dict["time_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop("time_in.mlp.0.bias")
     converted_state_dict["time_embed.timestep_embedder.linear_2.weight"] = original_state_dict.pop(
         "time_in.mlp.2.weight"
     )
-    converted_state_dict["time_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop(
-        "time_in.mlp.2.bias"
-    )
+    converted_state_dict["time_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop("time_in.mlp.2.bias")

     # 2. context_embedder.time_text_embed.timestep_embedder <- txt_in.t_embedder
     converted_state_dict["context_embedder.time_text_embed.timestep_embedder.linear_1.weight"] = (
         original_state_dict.pop("txt_in.t_embedder.mlp.0.weight")
     )
-    converted_state_dict["context_embedder.time_text_embed.timestep_embedder.linear_1.bias"] = (
-        original_state_dict.pop("txt_in.t_embedder.mlp.0.bias")
+    converted_state_dict["context_embedder.time_text_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop(
+        "txt_in.t_embedder.mlp.0.bias"
     )
     converted_state_dict["context_embedder.time_text_embed.timestep_embedder.linear_2.weight"] = (
         original_state_dict.pop("txt_in.t_embedder.mlp.2.weight")
     )
-    converted_state_dict["context_embedder.time_text_embed.timestep_embedder.linear_2.bias"] = (
-        original_state_dict.pop("txt_in.t_embedder.mlp.2.bias")
+    converted_state_dict["context_embedder.time_text_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop(
+        "txt_in.t_embedder.mlp.2.bias"
     )

     # 3. context_embedder.time_text_embed.text_embedder <- txt_in.c_embedder
-    converted_state_dict["context_embedder.time_text_embed.text_embedder.linear_1.weight"] = (
-        original_state_dict.pop("txt_in.c_embedder.linear_1.weight")
+    converted_state_dict["context_embedder.time_text_embed.text_embedder.linear_1.weight"] = original_state_dict.pop(
+        "txt_in.c_embedder.linear_1.weight"
     )
-    converted_state_dict["context_embedder.time_text_embed.text_embedder.linear_1.bias"] = (
-        original_state_dict.pop("txt_in.c_embedder.linear_1.bias")
+    converted_state_dict["context_embedder.time_text_embed.text_embedder.linear_1.bias"] = original_state_dict.pop(
+        "txt_in.c_embedder.linear_1.bias"
     )
-    converted_state_dict["context_embedder.time_text_embed.text_embedder.linear_2.weight"] = (
-        original_state_dict.pop("txt_in.c_embedder.linear_2.weight")
+    converted_state_dict["context_embedder.time_text_embed.text_embedder.linear_2.weight"] = original_state_dict.pop(
+        "txt_in.c_embedder.linear_2.weight"
     )
-    converted_state_dict["context_embedder.time_text_embed.text_embedder.linear_2.bias"] = (
-        original_state_dict.pop("txt_in.c_embedder.linear_2.bias")
+    converted_state_dict["context_embedder.time_text_embed.text_embedder.linear_2.bias"] = original_state_dict.pop(
+        "txt_in.c_embedder.linear_2.bias"
     )

     # 4. context_embedder.proj_in <- txt_in.input_embedder
-    converted_state_dict["context_embedder.proj_in.weight"] = original_state_dict.pop(
-        "txt_in.input_embedder.weight"
-    )
+    converted_state_dict["context_embedder.proj_in.weight"] = original_state_dict.pop("txt_in.input_embedder.weight")
     converted_state_dict["context_embedder.proj_in.bias"] = original_state_dict.pop("txt_in.input_embedder.bias")

     # 5. context_embedder.token_refiner <- txt_in.individual_token_refiner
@@ -375,10 +382,12 @@ def convert_hyvideo15_transformer_to_diffusers(original_state_dict):
     )

     # 11. norm_out and proj_out <- final_layer
-    converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(original_state_dict.pop(
-        "final_layer.adaLN_modulation.1.weight"
-    ))
-    converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(original_state_dict.pop("final_layer.adaLN_modulation.1.bias"))
+    converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(
+        original_state_dict.pop("final_layer.adaLN_modulation.1.weight")
+    )
+    converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(
+        original_state_dict.pop("final_layer.adaLN_modulation.1.bias")
+    )
     converted_state_dict["proj_out.weight"] = original_state_dict.pop("final_layer.linear.weight")
     converted_state_dict["proj_out.bias"] = original_state_dict.pop("final_layer.linear.bias")

@@ -572,6 +581,7 @@ def convert_hunyuan_video_15_vae_checkpoint_to_diffusers(

     return converted

+
 def load_sharded_safetensors(dir: pathlib.Path):
     file_paths = list(dir.glob("diffusion_pytorch_model*.safetensors"))
     state_dict = {}
@@ -583,9 +593,9 @@ def load_sharded_safetensors(dir: pathlib.Path):
 def load_original_transformer_state_dict(args):
     if args.original_state_dict_repo_id is not None:
         model_dir = snapshot_download(
-            args.original_state_dict_repo_id,
+            args.original_state_dict_repo_id,
             repo_type="model",
-            allow_patterns="transformer/" + args.transformer_type + "/*"
+            allow_patterns="transformer/" + args.transformer_type + "/*",
         )
     elif args.original_state_dict_folder is not None:
         model_dir = pathlib.Path(args.original_state_dict_folder)
@@ -599,8 +609,7 @@ def load_original_transformer_state_dict(args):
 def load_original_vae_state_dict(args):
     if args.original_state_dict_repo_id is not None:
         ckpt_path = hf_hub_download(
-            repo_id=args.original_state_dict_repo_id,
-            filename="vae/diffusion_pytorch_model.safetensors"
+            repo_id=args.original_state_dict_repo_id, filename="vae/diffusion_pytorch_model.safetensors"
         )
     elif args.original_state_dict_folder is not None:
         model_dir = pathlib.Path(args.original_state_dict_folder)
@@ -632,24 +641,27 @@ def convert_vae(args):
     vae.load_state_dict(state_dict, strict=True, assign=True)
     return vae

+
 def load_mllm():
-    print(f" loading from Qwen/Qwen2.5-VL-7B-Instruct")
-    text_encoder = AutoModel.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", torch_dtype=torch.bfloat16,low_cpu_mem_usage=True)
-    if hasattr(text_encoder, 'language_model'):
+    print(" loading from Qwen/Qwen2.5-VL-7B-Instruct")
+    text_encoder = AutoModel.from_pretrained(
+        "Qwen/Qwen2.5-VL-7B-Instruct", torch_dtype=torch.bfloat16, low_cpu_mem_usage=True
+    )
+    if hasattr(text_encoder, "language_model"):
         text_encoder = text_encoder.language_model
     tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", padding_side="right")
     return text_encoder, tokenizer


-#copied from https://github.com/Tencent-Hunyuan/HunyuanVideo-1.5/blob/910da2a829c484ea28982e8cff3bbc2cacdf1681/hyvideo/models/text_encoders/byT5/__init__.py#L89
+# copied from https://github.com/Tencent-Hunyuan/HunyuanVideo-1.5/blob/910da2a829c484ea28982e8cff3bbc2cacdf1681/hyvideo/models/text_encoders/byT5/__init__.py#L89
 def add_special_token(
     tokenizer,
     text_encoder,
     add_color=True,
     add_font=True,
     multilingual=True,
-    color_ann_path='assets/color_idx.json',
-    font_ann_path='assets/multilingual_10-lang_idx.json',
+    color_ann_path="assets/color_idx.json",
+    font_ann_path="assets/multilingual_10-lang_idx.json",
 ):
     """
     Add special tokens for color and font to tokenizer and text encoder.
@@ -663,16 +675,16 @@ def add_special_token(
         font_ann_path (str): Path to font annotation JSON.
         multilingual (bool): Whether to use multilingual font tokens.
     """
-    with open(font_ann_path, 'r') as f:
+    with open(font_ann_path, "r") as f:
         idx_font_dict = json.load(f)
-    with open(color_ann_path, 'r') as f:
+    with open(color_ann_path, "r") as f:
         idx_color_dict = json.load(f)

     if multilingual:
-        font_token = [f'<{font_code[:2]}-font-{idx_font_dict[font_code]}>' for font_code in idx_font_dict]
+        font_token = [f"<{font_code[:2]}-font-{idx_font_dict[font_code]}>" for font_code in idx_font_dict]
     else:
-        font_token = [f'<font-{i}>' for i in range(len(idx_font_dict))]
-    color_token = [f'<color-{i}>' for i in range(len(idx_color_dict))]
+        font_token = [f"<font-{i}>" for i in range(len(idx_font_dict))]
+    color_token = [f"<color-{i}>" for i in range(len(idx_color_dict))]
     additional_special_tokens = []
     if add_color:
         additional_special_tokens += color_token
@@ -688,14 +700,13 @@ def load_byt5(args):
     """
     Load ByT5 encoder with Glyph-SDXL-v2 weights and save in HuggingFace format.
     """
-

     # 1. Load base tokenizer and encoder
     tokenizer = AutoTokenizer.from_pretrained("google/byt5-small")
-
+
     # Load as T5EncoderModel
     encoder = T5EncoderModel.from_pretrained("google/byt5-small")
-
+
     byt5_checkpoint_path = os.path.join(args.byt5_path, "checkpoints/byt5_model.pt")
     color_ann_path = os.path.join(args.byt5_path, "assets/color_idx.json")
     font_ann_path = os.path.join(args.byt5_path, "assets/multilingual_10-lang_idx.json")
@@ -710,48 +721,45 @@ def load_byt5(args):
         font_ann_path=font_ann_path,
         multilingual=True,
     )
-
-
+
     # 3. Load Glyph-SDXL-v2 checkpoint
     print(f"\n3. Loading Glyph-SDXL-v2 checkpoint: {byt5_checkpoint_path}")
-    checkpoint = torch.load(byt5_checkpoint_path, map_location='cpu')
-
+    checkpoint = torch.load(byt5_checkpoint_path, map_location="cpu")
+
     # Handle different checkpoint formats
-    if 'state_dict' in checkpoint:
-        state_dict = checkpoint['state_dict']
+    if "state_dict" in checkpoint:
+        state_dict = checkpoint["state_dict"]
     else:
         state_dict = checkpoint
-
-    # add 'encoder.' prefix to the keys
+
+    # add 'encoder.' prefix to the keys
     # Remove 'module.text_tower.encoder.' prefix if present
     cleaned_state_dict = {}
     for key, value in state_dict.items():
-        if key.startswith('module.text_tower.encoder.'):
-            new_key = 'encoder.' + key[len('module.text_tower.encoder.') :]
+        if key.startswith("module.text_tower.encoder."):
+            new_key = "encoder." + key[len("module.text_tower.encoder.") :]
             cleaned_state_dict[new_key] = value
         else:
-            new_key = 'encoder.' + key
+            new_key = "encoder." + key
             cleaned_state_dict[new_key] = value
-
-
+
     # 4. Load weights
     missing_keys, unexpected_keys = encoder.load_state_dict(cleaned_state_dict, strict=False)
     if unexpected_keys:
         raise ValueError(f"Unexpected keys: {unexpected_keys}")
     if "shared.weight" in missing_keys:
-        print(f" Missing shared.weight as expected")
+        print(" Missing shared.weight as expected")
         missing_keys.remove("shared.weight")
     if missing_keys:
         raise ValueError(f"Missing keys: {missing_keys}")
-
-
+
     return encoder, tokenizer


 def load_siglip():
     image_encoder = SiglipVisionModel.from_pretrained(
         "black-forest-labs/FLUX.1-Redux-dev", subfolder="image_encoder", torch_dtype=torch.bfloat16
-        )
+    )
     feature_extractor = SiglipImageProcessor.from_pretrained(
         "black-forest-labs/FLUX.1-Redux-dev", subfolder="feature_extractor"
     )
@@ -763,11 +771,11 @@ def get_args():
     parser.add_argument(
         "--original_state_dict_repo_id", type=str, default=None, help="Path to original hub_id for the model"
     )
-    parser.add_argument("--original_state_dict_folder", type=str, default=None, help="Local folder name of the original state dict")
-    parser.add_argument("--output_path", type=str, required=True, help="Path where converted model(s) should be saved")
     parser.add_argument(
-        "--transformer_type", type=str, default="480p_i2v", choices=list(TRANSFORMER_CONFIGS.keys())
+        "--original_state_dict_folder", type=str, default=None, help="Local folder name of the original state dict"
     )
+    parser.add_argument("--output_path", type=str, required=True, help="Path where converted model(s) should be saved")
+    parser.add_argument("--transformer_type", type=str, default="480p_i2v", choices=list(TRANSFORMER_CONFIGS.keys()))
     parser.add_argument(
         "--byt5_path",
         type=str,
@@ -826,7 +834,7 @@ def get_args():
             feature_extractor=feature_extractor,
         )
     elif task_type == "t2v":
-        pipeline = HunyuanVideo15Text2VideoPipeline(
+        pipeline = HunyuanVideo15Pipeline(
             vae=vae,
             text_encoder=text_encoder,
             text_encoder_2=text_encoder_2,
@@ -840,6 +848,3 @@ def get_args():
         raise ValueError(f"Task type {task_type} is not supported")

     pipeline.save_pretrained(args.output_path, safe_serialization=True)
-
-
-
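For reference (not part of the diff): a minimal sketch of how this converter might be invoked and how its output could then be loaded. The repo id, local paths, and dtype below are illustrative assumptions; only the flag names are taken from get_args() above.

# Hypothetical invocation; substitute a real original-checkpoint repo or local folder.
# python scripts/convert_hunyuan_video1_5_to_diffusers.py \
#     --original_state_dict_repo_id <original-hunyuanvideo-1.5-repo> \
#     --transformer_type 480p_t2v \
#     --byt5_path <path-to-Glyph-SDXL-v2-assets> \
#     --output_path ./hunyuanvideo-1.5-480p-t2v-diffusers

import torch

from diffusers import HunyuanVideo15Pipeline

# Load the pipeline folder written by pipeline.save_pretrained(...) at the end of the script.
pipe = HunyuanVideo15Pipeline.from_pretrained(
    "./hunyuanvideo-1.5-480p-t2v-diffusers", torch_dtype=torch.bfloat16
)
pipe.to("cuda")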