 import torch
 from accelerate import init_empty_weights
 
-from diffusers import HunyuanVideoTransformer3DModel
+from diffusers import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel
 
 
 def remap_norm_scale_shift_(key, state_dict):
@@ -109,7 +109,9 @@ def remap_single_transformer_blocks_(key, state_dict):
     "single_blocks": remap_single_transformer_blocks_,
 }
 
-VAE_KEYS_RENAME_DICT = {}
+VAE_KEYS_RENAME_DICT = {
+
+}
 
 VAE_SPECIAL_KEYS_REMAP = {}
 
@@ -151,14 +153,37 @@ def convert_transformer(ckpt_path: str):
     return transformer
 
 
+def convert_vae(ckpt_path: str):
+    original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=True))
+
+    with init_empty_weights():
+        vae = AutoencoderKLHunyuanVideo()
+
+    for key in list(original_state_dict.keys()):
+        new_key = key[:]
+        for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
+            new_key = new_key.replace(replace_key, rename_key)
+        update_state_dict_(original_state_dict, key, new_key)
+
+    for key in list(original_state_dict.keys()):
+        for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
+            if special_key not in key:
+                continue
+            handler_fn_inplace(key, original_state_dict)
+
+    vae.load_state_dict(original_state_dict, strict=True, assign=True)
+    return vae
+
+
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
     )
+    parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original VAE checkpoint")
     parser.add_argument("--save_pipeline", action="store_true")
     parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
-    parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the model in.")
+    parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the transformer in.")
     return parser.parse_args()
 
 
@@ -180,5 +205,11 @@ def get_args():
 
     if args.transformer_ckpt_path is not None:
         transformer = convert_transformer(args.transformer_ckpt_path)
+        transformer = transformer.to(dtype=dtype)
         if not args.save_pipeline:
             transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
+
+    if args.vae_ckpt_path is not None:
+        vae = convert_vae(args.vae_ckpt_path)
+        if not args.save_pipeline:
+            vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
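
As a quick sanity check after running the script with --vae_ckpt_path, the converted VAE can be loaded back through the standard diffusers from_pretrained API. This is a minimal sketch, not part of the change itself; the "./hunyuan-vae" directory and the float32 dtype are illustrative placeholders standing in for whatever --output_path and precision were actually used.

import torch
from diffusers import AutoencoderKLHunyuanVideo

# "./hunyuan-vae" is an assumed --output_path, used only for illustration.
vae = AutoencoderKLHunyuanVideo.from_pretrained("./hunyuan-vae", torch_dtype=torch.float32)
vae.eval()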