Skip to content

Commit ca3ac4d

Browse files
committed
replace vae with ae
1 parent 65edfa5 commit ca3ac4d

File tree

1 file changed

+13
-13
lines changed

1 file changed

+13
-13
lines changed

scripts/convert_dcae_to_diffusers.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def remap_qkv_(key: str, state_dict: Dict[str, Any]):
     state_dict[key.replace("qkv.conv", "to_qkv")] = state_dict.pop(key)


-VAE_KEYS_RENAME_DICT = {
+AE_KEYS_RENAME_DICT = {
    # common
    "main.": "",
    "op_list.": "",
@@ -51,7 +51,7 @@ def remap_qkv_(key: str, state_dict: Dict[str, Any]):
    "decoder.project_out.2.conv": "decoder.conv_out",
}

-VAE_SPECIAL_KEYS_REMAP = {
+AE_SPECIAL_KEYS_REMAP = {
    "qkv.conv.weight": remap_qkv_,
}


@@ -71,9 +71,9 @@ def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -
    state_dict[new_key] = state_dict.pop(old_key)


-def convert_vae(ckpt_path: str, dtype: torch.dtype):
+def convert_ae(ckpt_path: str, dtype: torch.dtype):
    original_state_dict = get_state_dict(load_file(ckpt_path))
-    vae = AutoencoderDC(
+    ae = AutoencoderDC(
        in_channels=3,
        latent_channels=32,
        encoder_block_types=(
@@ -106,21 +106,21 @@ def convert_vae(ckpt_path: str, dtype: torch.dtype):

    for key in list(original_state_dict.keys()):
        new_key = key[:]
-        for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
+        for replace_key, rename_key in AE_KEYS_RENAME_DICT.items():
            new_key = new_key.replace(replace_key, rename_key)
        update_state_dict_(original_state_dict, key, new_key)

    for key in list(original_state_dict.keys()):
-        for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
+        for special_key, handler_fn_inplace in AE_SPECIAL_KEYS_REMAP.items():
            if special_key not in key:
                continue
            handler_fn_inplace(key, original_state_dict)

-    vae.load_state_dict(original_state_dict, strict=True)
-    return vae
+    ae.load_state_dict(original_state_dict, strict=True)
+    return ae


-def get_vae_config(name: str):
+def get_ae_config(name: str):
    if name in ["dc-ae-f32c32-sana-1.0"]:
        config = {
            "latent_channels": 32,
@@ -245,7 +245,7 @@ def get_vae_config(name: str):

def get_args():
    parser = argparse.ArgumentParser()
-    parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint")
+    parser.add_argument("--ae_ckpt_path", type=str, default=None, help="Path to original ae checkpoint")
    parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
    parser.add_argument("--dtype", default="fp32", help="Torch dtype to save the model in.")
    return parser.parse_args()
@@ -270,6 +270,6 @@ def get_args():
    dtype = DTYPE_MAPPING[args.dtype]
    variant = VARIANT_MAPPING[args.dtype]

-    if args.vae_ckpt_path is not None:
-        vae = convert_vae(args.vae_ckpt_path, dtype)
-        vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", variant=variant)
+    if args.ae_ckpt_path is not None:
+        ae = convert_ae(args.ae_ckpt_path, dtype)
+        ae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", variant=variant)

0 commit comments

Comments
 (0)