@@ -21,7 +21,7 @@ def remap_qkv_(key: str, state_dict: Dict[str, Any]):
     state_dict[key.replace("qkv.conv", "to_qkv")] = state_dict.pop(key)
 
 
-VAE_KEYS_RENAME_DICT = {
+AE_KEYS_RENAME_DICT = {
     # common
     "main.": "",
     "op_list.": "",
@@ -51,7 +51,7 @@ def remap_qkv_(key: str, state_dict: Dict[str, Any]):
5151 "decoder.project_out.2.conv" : "decoder.conv_out" ,
5252}
5353
54- VAE_SPECIAL_KEYS_REMAP = {
54+ AE_SPECIAL_KEYS_REMAP = {
5555 "qkv.conv.weight" : remap_qkv_ ,
5656}
5757
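The renamed `AE_KEYS_RENAME_DICT` / `AE_SPECIAL_KEYS_REMAP` tables drive two passes over the original checkpoint: plain substring renames, then targeted in-place handlers such as `remap_qkv_`. Below is a minimal sketch of the handler pass; the checkpoint key is made up for illustration, and only the `remap_qkv_` body is taken from the script.

```python
# Sketch of the special-key pass, using a hypothetical checkpoint key.
from typing import Any, Dict

import torch


def remap_qkv_(key: str, state_dict: Dict[str, Any]) -> None:
    # Same in-place rename as the script: pop the old key and reinsert the
    # tensor under the "to_qkv" name expected by the diffusers module.
    state_dict[key.replace("qkv.conv", "to_qkv")] = state_dict.pop(key)


state_dict = {"encoder.stages.0.attn.qkv.conv.weight": torch.zeros(4, 4)}  # hypothetical key
for key in list(state_dict):
    if "qkv.conv.weight" in key:  # mirrors the AE_SPECIAL_KEYS_REMAP lookup
        remap_qkv_(key, state_dict)
print(sorted(state_dict))  # ['encoder.stages.0.attn.to_qkv.weight']
```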
@@ -71,9 +71,9 @@ def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -
     state_dict[new_key] = state_dict.pop(old_key)
 
 
-def convert_vae(ckpt_path: str, dtype: torch.dtype):
+def convert_ae(ckpt_path: str, dtype: torch.dtype):
     original_state_dict = get_state_dict(load_file(ckpt_path))
-    vae = AutoencoderDC(
+    ae = AutoencoderDC(
         in_channels=3,
         latent_channels=32,
         encoder_block_types=(
@@ -106,21 +106,21 @@ def convert_vae(ckpt_path: str, dtype: torch.dtype):
 
     for key in list(original_state_dict.keys()):
         new_key = key[:]
-        for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
+        for replace_key, rename_key in AE_KEYS_RENAME_DICT.items():
             new_key = new_key.replace(replace_key, rename_key)
         update_state_dict_(original_state_dict, key, new_key)
 
     for key in list(original_state_dict.keys()):
-        for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
+        for special_key, handler_fn_inplace in AE_SPECIAL_KEYS_REMAP.items():
             if special_key not in key:
                 continue
             handler_fn_inplace(key, original_state_dict)
 
-    vae.load_state_dict(original_state_dict, strict=True)
-    return vae
+    ae.load_state_dict(original_state_dict, strict=True)
+    return ae
 
 
-def get_vae_config(name: str):
+def get_ae_config(name: str):
     if name in ["dc-ae-f32c32-sana-1.0"]:
         config = {
             "latent_channels": 32,
@@ -245,7 +245,7 @@ def get_vae_config(name: str):
 
 def get_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint")
+    parser.add_argument("--ae_ckpt_path", type=str, default=None, help="Path to original ae checkpoint")
     parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
     parser.add_argument("--dtype", default="fp32", help="Torch dtype to save the model in.")
     return parser.parse_args()
@@ -270,6 +270,6 @@ def get_args():
     dtype = DTYPE_MAPPING[args.dtype]
     variant = VARIANT_MAPPING[args.dtype]
 
-    if args.vae_ckpt_path is not None:
-        vae = convert_vae(args.vae_ckpt_path, dtype)
-        vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", variant=variant)
+    if args.ae_ckpt_path is not None:
+        ae = convert_ae(args.ae_ckpt_path, dtype)
+        ae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", variant=variant)
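With the flag renamed to `--ae_ckpt_path`, the converter writes out a regular diffusers `AutoencoderDC` checkpoint. A hedged sketch of reloading the saved model; the script filename and output directory below are placeholders, not taken from this diff:

```python
# After running the converter, e.g.
#   python convert_dc_ae.py --ae_ckpt_path <original.safetensors> --output_path <out_dir>
# (script name assumed), the result can be reloaded as a standard diffusers model.
import torch
from diffusers import AutoencoderDC

ae = AutoencoderDC.from_pretrained("<out_dir>", torch_dtype=torch.float32)  # match --dtype
ae.eval()
```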