33
44import torch
55from accelerate import init_empty_weights
6+ from transformers import AutoModel , AutoTokenizer , CLIPTextModel , CLIPTokenizer
67
7- from diffusers import AutoencoderKLHunyuanVideo , HunyuanVideoTransformer3DModel
8+ from diffusers import AutoencoderKLHunyuanVideo , HunyuanVideoTransformer3DModel , HunyuanVideoPipeline
89
910
1011def remap_norm_scale_shift_ (key , state_dict ):
@@ -76,6 +77,8 @@ def remap_single_transformer_blocks_(key, state_dict):
7677 # "guidance_in.mlp.2": "time_text_embed.guidance_embedder.linear_2",
7778 # "vector_in.in_layer": "time_text_embed.text_embedder.linear_1",
7879 # "vector_in.out_layer": "time_text_embed.text_embedder.linear_2",
80+ "txt_in.t_embedder" : "txt_in.time_embed" ,
81+ "txt_in.c_embedder" : "txt_in.context_embed" ,
7982 "double_blocks" : "transformer_blocks" ,
8083 "individual_token_refiner.blocks" : "token_refiner.refiner_blocks" ,
8184 "img_attn_q_norm" : "attn.norm_q" ,
@@ -179,6 +182,8 @@ def get_args():
179182 "--transformer_ckpt_path" , type = str , default = None , help = "Path to original transformer checkpoint"
180183 )
181184 parser .add_argument ("--vae_ckpt_path" , type = str , default = None , help = "Path to original VAE checkpoint" )
185+ parser .add_argument ("--text_encoder_path" , type = str , default = None , help = "Path to original llama checkpoint" )
186+ parser .add_argument ("--text_encoder_2_path" , type = str , default = None , help = "Path to original clip checkpoint" )
182187 parser .add_argument ("--save_pipeline" , action = "store_true" )
183188 parser .add_argument ("--output_path" , type = str , required = True , help = "Path where converted model should be saved" )
184189 parser .add_argument ("--dtype" , default = "bf16" , help = "Torch dtype to save the transformer in." )
@@ -200,6 +205,8 @@ def get_args():
200205
201206 if args .save_pipeline :
202207 assert args .transformer_ckpt_path is not None and args .vae_ckpt_path is not None
208+ assert args .text_encoder_path is not None
209+ assert args .text_encoder_2_path is not None
203210
204211 if args .transformer_ckpt_path is not None :
205212 transformer = convert_transformer (args .transformer_ckpt_path )
@@ -211,3 +218,19 @@ def get_args():
211218 vae = convert_vae (args .vae_ckpt_path )
212219 if not args .save_pipeline :
213220 vae .save_pretrained (args .output_path , safe_serialization = True , max_shard_size = "5GB" )
221+
if args.save_pipeline:
    # Assemble and save the full HunyuanVideo pipeline: the converted
    # transformer and VAE (created above) plus the two text encoders.
    # NOTE(review): fp16 for both text encoders matches the original
    # HunyuanVideo release — confirm against the upstream checkpoints.
    text_encoder = AutoModel.from_pretrained(args.text_encoder_path, torch_dtype=torch.float16)
    # Right padding is required so prompt-template token positions line up
    # with the llama encoder's attention mask — presumably; verify against
    # the pipeline's prompt-encoding code.
    tokenizer = AutoTokenizer.from_pretrained(args.text_encoder_path, padding_side="right")
    text_encoder_2 = CLIPTextModel.from_pretrained(args.text_encoder_2_path, torch_dtype=torch.float16)
    tokenizer_2 = CLIPTokenizer.from_pretrained(args.text_encoder_2_path)

    pipe = HunyuanVideoPipeline(
        transformer=transformer,
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        text_encoder_2=text_encoder_2,
        tokenizer_2=tokenizer_2,
    )
    # Shard weights at 5 GB so no single safetensors file becomes unwieldy.
    pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
0 commit comments