Skip to content

Commit 6915d62

Browse files
committed
remove helper functions in vae
1 parent e258480 commit 6915d62

File tree

4 files changed

+116
-214
lines changed

4 files changed

+116
-214
lines changed

scripts/convert_hunyuan_video_to_diffusers.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import torch
55
from accelerate import init_empty_weights
66

7-
from diffusers import HunyuanVideoTransformer3DModel
7+
from diffusers import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel
88

99

1010
def remap_norm_scale_shift_(key, state_dict):
@@ -109,7 +109,9 @@ def remap_single_transformer_blocks_(key, state_dict):
109109
"single_blocks": remap_single_transformer_blocks_,
110110
}
111111

# Substring renames applied to every key of the original VAE state dict
# (empty: the original checkpoint keys already match diffusers naming).
VAE_KEYS_RENAME_DICT = {}

# Special-case handlers keyed by key substring; each handler mutates the
# state dict in place (empty: no special handling required yet).
VAE_SPECIAL_KEYS_REMAP = {}
115117

@@ -151,14 +153,37 @@ def convert_transformer(ckpt_path: str):
151153
return transformer
152154

153155

156+
def convert_vae(ckpt_path: str):
    """Convert an original HunyuanVideo VAE checkpoint to a diffusers model.

    Args:
        ckpt_path: Filesystem path to the original VAE checkpoint.

    Returns:
        An ``AutoencoderKLHunyuanVideo`` instance with the converted weights
        assigned.
    """
    state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=True))

    # Build the model on the meta device; real tensors are attached below
    # via load_state_dict(..., assign=True).
    with init_empty_weights():
        vae = AutoencoderKLHunyuanVideo()

    # Pass 1: plain substring renames from the rename table.
    for old_key in list(state_dict.keys()):
        renamed_key = old_key
        for source, target in VAE_KEYS_RENAME_DICT.items():
            renamed_key = renamed_key.replace(source, target)
        update_state_dict_(state_dict, old_key, renamed_key)

    # Pass 2: special-case handlers that rewrite matching entries in place.
    for current_key in list(state_dict.keys()):
        for needle, handler_fn in VAE_SPECIAL_KEYS_REMAP.items():
            if needle in current_key:
                handler_fn(current_key, state_dict)

    vae.load_state_dict(state_dict, strict=True, assign=True)
    return vae
154178
def get_args():
    """Build and parse the command-line arguments for the conversion script."""
    arg_parser = argparse.ArgumentParser()
    # Checkpoint inputs: each component is optional so either (or both) can be converted.
    arg_parser.add_argument("--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint")
    arg_parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original VAE checkpoint")
    arg_parser.add_argument("--save_pipeline", action="store_true")
    arg_parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
    arg_parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the transformer in.")
    return arg_parser.parse_args()
163188

164189

@@ -180,5 +205,11 @@ def get_args():
180205

181206
# Convert whichever components were requested on the command line.
if args.transformer_ckpt_path is not None:
    # Cast the converted transformer to the requested dtype before saving.
    transformer = convert_transformer(args.transformer_ckpt_path).to(dtype=dtype)
    if not args.save_pipeline:
        transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")

if args.vae_ckpt_path is not None:
    vae = convert_vae(args.vae_ckpt_path)
    if not args.save_pipeline:
        vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")

0 commit comments

Comments
 (0)