
Commit 6b6b4bc

gnobitab, xingchaoliu, and yiyixuxu authored
[Tencent Hunyuan Team] Add checkpoint conversion scripts and changed controlnet (#8783)
* add conversion files; changed controlnet for hunyuandit

* style

---------

Co-authored-by: xingchaoliu <[email protected]>
Co-authored-by: yiyixuxu <[email protected]>
1 parent beb1c01 commit 6b6b4bc

File tree: 3 files changed, +510 −0 lines changed
Lines changed: 241 additions & 0 deletions

@@ -0,0 +1,241 @@
import argparse

import torch

from diffusers import HunyuanDiT2DControlNetModel


def main(args):
    state_dict = torch.load(args.pt_checkpoint_path, map_location="cpu")

    if args.load_key != "none":
        try:
            state_dict = state_dict[args.load_key]
        except KeyError:
            raise KeyError(
                f"{args.load_key} not found in the checkpoint. "
                f"Please load from the following keys: {state_dict.keys()}"
            )
    device = "cuda"

    model_config = HunyuanDiT2DControlNetModel.load_config(
        "Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers", subfolder="transformer"
    )
    model_config[
        "use_style_cond_and_image_meta_size"
    ] = args.use_style_cond_and_image_meta_size  ### version <= v1.1: True; version >= v1.2: False
    print(model_config)

    for key in state_dict:
        print("local:", key)

    model = HunyuanDiT2DControlNetModel.from_config(model_config).to(device)

    for key in model.state_dict():
        print("diffusers:", key)

    num_layers = 19
    for i in range(num_layers):
        # attn1
        # Wqkv -> to_q, to_k, to_v
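        # The fused Wqkv projection is assumed to stack the query, key and value weights
        # along dim 0 in that order, so chunking into three equal parts recovers the
        # separate to_q / to_k / to_v projections expected by diffusers.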
        q, k, v = torch.chunk(state_dict[f"blocks.{i}.attn1.Wqkv.weight"], 3, dim=0)
        q_bias, k_bias, v_bias = torch.chunk(state_dict[f"blocks.{i}.attn1.Wqkv.bias"], 3, dim=0)
        state_dict[f"blocks.{i}.attn1.to_q.weight"] = q
        state_dict[f"blocks.{i}.attn1.to_q.bias"] = q_bias
        state_dict[f"blocks.{i}.attn1.to_k.weight"] = k
        state_dict[f"blocks.{i}.attn1.to_k.bias"] = k_bias
        state_dict[f"blocks.{i}.attn1.to_v.weight"] = v
        state_dict[f"blocks.{i}.attn1.to_v.bias"] = v_bias
        state_dict.pop(f"blocks.{i}.attn1.Wqkv.weight")
        state_dict.pop(f"blocks.{i}.attn1.Wqkv.bias")

        # q_norm, k_norm -> norm_q, norm_k
        state_dict[f"blocks.{i}.attn1.norm_q.weight"] = state_dict[f"blocks.{i}.attn1.q_norm.weight"]
        state_dict[f"blocks.{i}.attn1.norm_q.bias"] = state_dict[f"blocks.{i}.attn1.q_norm.bias"]
        state_dict[f"blocks.{i}.attn1.norm_k.weight"] = state_dict[f"blocks.{i}.attn1.k_norm.weight"]
        state_dict[f"blocks.{i}.attn1.norm_k.bias"] = state_dict[f"blocks.{i}.attn1.k_norm.bias"]

        state_dict.pop(f"blocks.{i}.attn1.q_norm.weight")
        state_dict.pop(f"blocks.{i}.attn1.q_norm.bias")
        state_dict.pop(f"blocks.{i}.attn1.k_norm.weight")
        state_dict.pop(f"blocks.{i}.attn1.k_norm.bias")

        # out_proj -> to_out
        state_dict[f"blocks.{i}.attn1.to_out.0.weight"] = state_dict[f"blocks.{i}.attn1.out_proj.weight"]
        state_dict[f"blocks.{i}.attn1.to_out.0.bias"] = state_dict[f"blocks.{i}.attn1.out_proj.bias"]
        state_dict.pop(f"blocks.{i}.attn1.out_proj.weight")
        state_dict.pop(f"blocks.{i}.attn1.out_proj.bias")

        # attn2
        # kv_proj -> to_k, to_v
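        # The fused kv_proj of the cross-attention layer is assumed to stack the key and
        # value weights along dim 0; the query projection (q_proj) is stored separately
        # and is copied over as-is below.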
        k, v = torch.chunk(state_dict[f"blocks.{i}.attn2.kv_proj.weight"], 2, dim=0)
        k_bias, v_bias = torch.chunk(state_dict[f"blocks.{i}.attn2.kv_proj.bias"], 2, dim=0)
        state_dict[f"blocks.{i}.attn2.to_k.weight"] = k
        state_dict[f"blocks.{i}.attn2.to_k.bias"] = k_bias
        state_dict[f"blocks.{i}.attn2.to_v.weight"] = v
        state_dict[f"blocks.{i}.attn2.to_v.bias"] = v_bias
        state_dict.pop(f"blocks.{i}.attn2.kv_proj.weight")
        state_dict.pop(f"blocks.{i}.attn2.kv_proj.bias")

        # q_proj -> to_q
        state_dict[f"blocks.{i}.attn2.to_q.weight"] = state_dict[f"blocks.{i}.attn2.q_proj.weight"]
        state_dict[f"blocks.{i}.attn2.to_q.bias"] = state_dict[f"blocks.{i}.attn2.q_proj.bias"]
        state_dict.pop(f"blocks.{i}.attn2.q_proj.weight")
        state_dict.pop(f"blocks.{i}.attn2.q_proj.bias")

        # q_norm, k_norm -> norm_q, norm_k
        state_dict[f"blocks.{i}.attn2.norm_q.weight"] = state_dict[f"blocks.{i}.attn2.q_norm.weight"]
        state_dict[f"blocks.{i}.attn2.norm_q.bias"] = state_dict[f"blocks.{i}.attn2.q_norm.bias"]
        state_dict[f"blocks.{i}.attn2.norm_k.weight"] = state_dict[f"blocks.{i}.attn2.k_norm.weight"]
        state_dict[f"blocks.{i}.attn2.norm_k.bias"] = state_dict[f"blocks.{i}.attn2.k_norm.bias"]

        state_dict.pop(f"blocks.{i}.attn2.q_norm.weight")
        state_dict.pop(f"blocks.{i}.attn2.q_norm.bias")
        state_dict.pop(f"blocks.{i}.attn2.k_norm.weight")
        state_dict.pop(f"blocks.{i}.attn2.k_norm.bias")

        # out_proj -> to_out
        state_dict[f"blocks.{i}.attn2.to_out.0.weight"] = state_dict[f"blocks.{i}.attn2.out_proj.weight"]
        state_dict[f"blocks.{i}.attn2.to_out.0.bias"] = state_dict[f"blocks.{i}.attn2.out_proj.bias"]
        state_dict.pop(f"blocks.{i}.attn2.out_proj.weight")
        state_dict.pop(f"blocks.{i}.attn2.out_proj.bias")

        # switch norm 2 and norm 3
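        # The diffusers HunyuanDiT block applies norm2 before cross-attention and norm3
        # before the feed-forward, the reverse of the original checkpoint's layout, hence
        # the swap below.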
        norm2_weight = state_dict[f"blocks.{i}.norm2.weight"]
        norm2_bias = state_dict[f"blocks.{i}.norm2.bias"]
        state_dict[f"blocks.{i}.norm2.weight"] = state_dict[f"blocks.{i}.norm3.weight"]
        state_dict[f"blocks.{i}.norm2.bias"] = state_dict[f"blocks.{i}.norm3.bias"]
        state_dict[f"blocks.{i}.norm3.weight"] = norm2_weight
        state_dict[f"blocks.{i}.norm3.bias"] = norm2_bias

        # norm1 -> norm1.norm
        # default_modulation.1 -> norm1.linear
        state_dict[f"blocks.{i}.norm1.norm.weight"] = state_dict[f"blocks.{i}.norm1.weight"]
        state_dict[f"blocks.{i}.norm1.norm.bias"] = state_dict[f"blocks.{i}.norm1.bias"]
        state_dict[f"blocks.{i}.norm1.linear.weight"] = state_dict[f"blocks.{i}.default_modulation.1.weight"]
        state_dict[f"blocks.{i}.norm1.linear.bias"] = state_dict[f"blocks.{i}.default_modulation.1.bias"]
        state_dict.pop(f"blocks.{i}.norm1.weight")
        state_dict.pop(f"blocks.{i}.norm1.bias")
        state_dict.pop(f"blocks.{i}.default_modulation.1.weight")
        state_dict.pop(f"blocks.{i}.default_modulation.1.bias")

        # mlp.fc1 -> ff.net.0, mlp.fc2 -> ff.net.2
        state_dict[f"blocks.{i}.ff.net.0.proj.weight"] = state_dict[f"blocks.{i}.mlp.fc1.weight"]
        state_dict[f"blocks.{i}.ff.net.0.proj.bias"] = state_dict[f"blocks.{i}.mlp.fc1.bias"]
        state_dict[f"blocks.{i}.ff.net.2.weight"] = state_dict[f"blocks.{i}.mlp.fc2.weight"]
        state_dict[f"blocks.{i}.ff.net.2.bias"] = state_dict[f"blocks.{i}.mlp.fc2.bias"]
        state_dict.pop(f"blocks.{i}.mlp.fc1.weight")
        state_dict.pop(f"blocks.{i}.mlp.fc1.bias")
        state_dict.pop(f"blocks.{i}.mlp.fc2.weight")
        state_dict.pop(f"blocks.{i}.mlp.fc2.bias")

        # after_proj_list -> controlnet_blocks
        state_dict[f"controlnet_blocks.{i}.weight"] = state_dict[f"after_proj_list.{i}.weight"]
        state_dict[f"controlnet_blocks.{i}.bias"] = state_dict[f"after_proj_list.{i}.bias"]
        state_dict.pop(f"after_proj_list.{i}.weight")
        state_dict.pop(f"after_proj_list.{i}.bias")

    # before_proj -> input_block
    state_dict["input_block.weight"] = state_dict["before_proj.weight"]
    state_dict["input_block.bias"] = state_dict["before_proj.bias"]
    state_dict.pop("before_proj.weight")
    state_dict.pop("before_proj.bias")

    # pooler -> time_extra_emb.pooler
    state_dict["time_extra_emb.pooler.positional_embedding"] = state_dict["pooler.positional_embedding"]
    state_dict["time_extra_emb.pooler.k_proj.weight"] = state_dict["pooler.k_proj.weight"]
    state_dict["time_extra_emb.pooler.k_proj.bias"] = state_dict["pooler.k_proj.bias"]
    state_dict["time_extra_emb.pooler.q_proj.weight"] = state_dict["pooler.q_proj.weight"]
    state_dict["time_extra_emb.pooler.q_proj.bias"] = state_dict["pooler.q_proj.bias"]
    state_dict["time_extra_emb.pooler.v_proj.weight"] = state_dict["pooler.v_proj.weight"]
    state_dict["time_extra_emb.pooler.v_proj.bias"] = state_dict["pooler.v_proj.bias"]
    state_dict["time_extra_emb.pooler.c_proj.weight"] = state_dict["pooler.c_proj.weight"]
    state_dict["time_extra_emb.pooler.c_proj.bias"] = state_dict["pooler.c_proj.bias"]
    state_dict.pop("pooler.k_proj.weight")
    state_dict.pop("pooler.k_proj.bias")
    state_dict.pop("pooler.q_proj.weight")
    state_dict.pop("pooler.q_proj.bias")
    state_dict.pop("pooler.v_proj.weight")
    state_dict.pop("pooler.v_proj.bias")
    state_dict.pop("pooler.c_proj.weight")
    state_dict.pop("pooler.c_proj.bias")
    state_dict.pop("pooler.positional_embedding")

    # t_embedder -> time_extra_emb.timestep_embedder (`TimestepEmbedding`)
    state_dict["time_extra_emb.timestep_embedder.linear_1.bias"] = state_dict["t_embedder.mlp.0.bias"]
    state_dict["time_extra_emb.timestep_embedder.linear_1.weight"] = state_dict["t_embedder.mlp.0.weight"]
    state_dict["time_extra_emb.timestep_embedder.linear_2.bias"] = state_dict["t_embedder.mlp.2.bias"]
    state_dict["time_extra_emb.timestep_embedder.linear_2.weight"] = state_dict["t_embedder.mlp.2.weight"]

    state_dict.pop("t_embedder.mlp.0.bias")
    state_dict.pop("t_embedder.mlp.0.weight")
    state_dict.pop("t_embedder.mlp.2.bias")
    state_dict.pop("t_embedder.mlp.2.weight")

    # x_embedder -> pos_embed (`PatchEmbed`)
    state_dict["pos_embed.proj.weight"] = state_dict["x_embedder.proj.weight"]
    state_dict["pos_embed.proj.bias"] = state_dict["x_embedder.proj.bias"]
    state_dict.pop("x_embedder.proj.weight")
    state_dict.pop("x_embedder.proj.bias")

    # mlp_t5 -> text_embedder
    state_dict["text_embedder.linear_1.bias"] = state_dict["mlp_t5.0.bias"]
    state_dict["text_embedder.linear_1.weight"] = state_dict["mlp_t5.0.weight"]
    state_dict["text_embedder.linear_2.bias"] = state_dict["mlp_t5.2.bias"]
    state_dict["text_embedder.linear_2.weight"] = state_dict["mlp_t5.2.weight"]
    state_dict.pop("mlp_t5.0.bias")
    state_dict.pop("mlp_t5.0.weight")
    state_dict.pop("mlp_t5.2.bias")
    state_dict.pop("mlp_t5.2.weight")

    # extra_embedder -> time_extra_emb.extra_embedder
    state_dict["time_extra_emb.extra_embedder.linear_1.bias"] = state_dict["extra_embedder.0.bias"]
    state_dict["time_extra_emb.extra_embedder.linear_1.weight"] = state_dict["extra_embedder.0.weight"]
    state_dict["time_extra_emb.extra_embedder.linear_2.bias"] = state_dict["extra_embedder.2.bias"]
    state_dict["time_extra_emb.extra_embedder.linear_2.weight"] = state_dict["extra_embedder.2.weight"]
    state_dict.pop("extra_embedder.0.bias")
    state_dict.pop("extra_embedder.0.weight")
    state_dict.pop("extra_embedder.2.bias")
    state_dict.pop("extra_embedder.2.weight")

    # style_embedder
    if model_config["use_style_cond_and_image_meta_size"]:
        print(state_dict["style_embedder.weight"])
        print(state_dict["style_embedder.weight"].shape)
        state_dict["time_extra_emb.style_embedder.weight"] = state_dict["style_embedder.weight"][0:1]
        state_dict.pop("style_embedder.weight")

    model.load_state_dict(state_dict)

    if args.save:
        model.save_pretrained(args.output_checkpoint_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--save", default=True, type=bool, required=False, help="Whether to save the converted model or not."
    )
    parser.add_argument(
        "--pt_checkpoint_path", default=None, type=str, required=True, help="Path to the .pt pretrained model."
    )
    parser.add_argument(
        "--output_checkpoint_path",
        default=None,
        type=str,
        required=False,
        help="Path to the output converted diffusers model.",
    )
    parser.add_argument(
        "--load_key", default="none", type=str, required=False, help="The key to load from the pretrained .pt file."
    )
    parser.add_argument(
        "--use_style_cond_and_image_meta_size",
        type=bool,
        default=False,
        help="version <= v1.1: True; version >= v1.2: False",
    )
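    # Note: argparse's type=bool does not parse the string "False"; any non-empty value
    # passed to --save or --use_style_cond_and_image_meta_size evaluates to True, so omit
    # the flag (or pass an empty string) to get False.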

    args = parser.parse_args()
    main(args)
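For reference, a typical invocation and a quick round-trip check of the converted weights (a minimal sketch; the script filename, checkpoint path and output directory below are placeholders, not names taken from this commit):

# Convert the original .pt checkpoint to the diffusers layout (shell):
#   python convert_hunyuandit_controlnet_to_diffusers.py \
#       --pt_checkpoint_path ./pytorch_model_module.pt \
#       --output_checkpoint_path ./hunyuandit-controlnet-diffusers \
#       --load_key none

# Load the converted ControlNet back with the same class the script uses:
from diffusers import HunyuanDiT2DControlNetModel

controlnet = HunyuanDiT2DControlNetModel.from_pretrained("./hunyuandit-controlnet-diffusers")
print(sum(p.numel() for p in controlnet.parameters()), "parameters loaded")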
