
Commit 66fc85e — update

1 parent 825979d
File tree: 2 files changed (+180, −0 lines)

src/diffusers/loaders/lora_conversion_utils.py

Lines changed: 175 additions & 0 deletions
```diff
@@ -973,3 +973,178 @@ def swap_scale_shift(weight):
         converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
 
     return converted_state_dict
+
+
+def _convert_hunyuan_video_lora_to_diffusers(original_state_dict):
+    converted_state_dict = {k: original_state_dict.pop(k) for k in list(original_state_dict.keys())}
+
+    def remap_norm_scale_shift_(key, state_dict):
+        weight = state_dict.pop(key)
+        shift, scale = weight.chunk(2, dim=0)
+        new_weight = torch.cat([scale, shift], dim=0)
+        state_dict[key.replace("final_layer.adaLN_modulation.1", "norm_out.linear")] = new_weight
+
+    def remap_txt_in_(key, state_dict):
+        def rename_key(key):
+            new_key = key.replace("individual_token_refiner.blocks", "token_refiner.refiner_blocks")
+            new_key = new_key.replace("adaLN_modulation.1", "norm_out.linear")
+            new_key = new_key.replace("txt_in", "context_embedder")
+            new_key = new_key.replace("t_embedder.mlp.0", "time_text_embed.timestep_embedder.linear_1")
+            new_key = new_key.replace("t_embedder.mlp.2", "time_text_embed.timestep_embedder.linear_2")
+            new_key = new_key.replace("c_embedder", "time_text_embed.text_embedder")
+            new_key = new_key.replace("mlp", "ff")
+            return new_key
+
+        if "self_attn_qkv" in key:
+            weight = state_dict.pop(key)
+            to_q, to_k, to_v = weight.chunk(3, dim=0)
+            state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_q"))] = to_q
+            state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_k"))] = to_k
+            state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_v"))] = to_v
+        else:
+            state_dict[rename_key(key)] = state_dict.pop(key)
+
+    def remap_img_attn_qkv_(key, state_dict):
+        weight = state_dict.pop(key)
+        if "lora_A" in key:
+            state_dict[key.replace("img_attn_qkv", "attn.to_q")] = weight
+            state_dict[key.replace("img_attn_qkv", "attn.to_k")] = weight
+            state_dict[key.replace("img_attn_qkv", "attn.to_v")] = weight
+        else:
+            to_q, to_k, to_v = weight.chunk(3, dim=0)
+            state_dict[key.replace("img_attn_qkv", "attn.to_q")] = to_q
+            state_dict[key.replace("img_attn_qkv", "attn.to_k")] = to_k
+            state_dict[key.replace("img_attn_qkv", "attn.to_v")] = to_v
+
+    def remap_txt_attn_qkv_(key, state_dict):
+        weight = state_dict.pop(key)
+        if "lora_A" in key:
+            state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = weight
+            state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = weight
+            state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = weight
+        else:
+            to_q, to_k, to_v = weight.chunk(3, dim=0)
+            state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = to_q
+            state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = to_k
+            state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = to_v
+
+    def remap_single_transformer_blocks_(key, state_dict):
+        hidden_size = 3072
+
+        if "linear1.lora_A.weight" in key or "linear1.lora_B.weight" in key:
+            linear1_weight = state_dict.pop(key)
+            if "lora_A" in key:
+                new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(
+                    ".linear1.lora_A.weight"
+                )
+                state_dict[f"{new_key}.attn.to_q.lora_A.weight"] = linear1_weight
+                state_dict[f"{new_key}.attn.to_k.lora_A.weight"] = linear1_weight
+                state_dict[f"{new_key}.attn.to_v.lora_A.weight"] = linear1_weight
+                state_dict[f"{new_key}.proj_mlp.lora_A.weight"] = linear1_weight
+            else:
+                split_size = (hidden_size, hidden_size, hidden_size, linear1_weight.size(0) - 3 * hidden_size)
+                q, k, v, mlp = torch.split(linear1_weight, split_size, dim=0)
+                new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(
+                    ".linear1.lora_B.weight"
+                )
+                state_dict[f"{new_key}.attn.to_q.lora_B.weight"] = q
+                state_dict[f"{new_key}.attn.to_k.lora_B.weight"] = k
+                state_dict[f"{new_key}.attn.to_v.lora_B.weight"] = v
+                state_dict[f"{new_key}.proj_mlp.lora_B.weight"] = mlp
+
+        elif "linear1.lora_A.bias" in key or "linear1.lora_B.bias" in key:
+            linear1_bias = state_dict.pop(key)
+            if "lora_A" in key:
+                new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(
+                    ".linear1.lora_A.bias"
+                )
+                state_dict[f"{new_key}.attn.to_q.lora_A.bias"] = linear1_bias
+                state_dict[f"{new_key}.attn.to_k.lora_A.bias"] = linear1_bias
+                state_dict[f"{new_key}.attn.to_v.lora_A.bias"] = linear1_bias
+                state_dict[f"{new_key}.proj_mlp.lora_A.bias"] = linear1_bias
+            else:
+                split_size = (hidden_size, hidden_size, hidden_size, linear1_bias.size(0) - 3 * hidden_size)
+                q_bias, k_bias, v_bias, mlp_bias = torch.split(linear1_bias, split_size, dim=0)
+                new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(
+                    ".linear1.lora_B.bias"
+                )
+                state_dict[f"{new_key}.attn.to_q.lora_B.bias"] = q_bias
+                state_dict[f"{new_key}.attn.to_k.lora_B.bias"] = k_bias
+                state_dict[f"{new_key}.attn.to_v.lora_B.bias"] = v_bias
+                state_dict[f"{new_key}.proj_mlp.lora_B.bias"] = mlp_bias
+
+        else:
+            new_key = key.replace("single_blocks", "single_transformer_blocks")
+            new_key = new_key.replace("linear2", "proj_out")
+            new_key = new_key.replace("q_norm", "attn.norm_q")
+            new_key = new_key.replace("k_norm", "attn.norm_k")
+            state_dict[new_key] = state_dict.pop(key)
+
+    TRANSFORMER_KEYS_RENAME_DICT = {
+        "img_in": "x_embedder",
+        "time_in.mlp.0": "time_text_embed.timestep_embedder.linear_1",
+        "time_in.mlp.2": "time_text_embed.timestep_embedder.linear_2",
+        "guidance_in.mlp.0": "time_text_embed.guidance_embedder.linear_1",
+        "guidance_in.mlp.2": "time_text_embed.guidance_embedder.linear_2",
+        "vector_in.in_layer": "time_text_embed.text_embedder.linear_1",
+        "vector_in.out_layer": "time_text_embed.text_embedder.linear_2",
+        "double_blocks": "transformer_blocks",
+        "img_attn_q_norm": "attn.norm_q",
+        "img_attn_k_norm": "attn.norm_k",
+        "img_attn_proj": "attn.to_out.0",
+        "txt_attn_q_norm": "attn.norm_added_q",
+        "txt_attn_k_norm": "attn.norm_added_k",
+        "txt_attn_proj": "attn.to_add_out",
+        "img_mod.linear": "norm1.linear",
+        "img_norm1": "norm1.norm",
+        "img_norm2": "norm2",
+        "img_mlp": "ff",
+        "txt_mod.linear": "norm1_context.linear",
+        "txt_norm1": "norm1.norm",
+        "txt_norm2": "norm2_context",
+        "txt_mlp": "ff_context",
+        "self_attn_proj": "attn.to_out.0",
+        "modulation.linear": "norm.linear",
+        "pre_norm": "norm.norm",
+        "final_layer.norm_final": "norm_out.norm",
+        "final_layer.linear": "proj_out",
+        "fc1": "net.0.proj",
+        "fc2": "net.2",
+        "input_embedder": "proj_in",
+    }
+
+    TRANSFORMER_SPECIAL_KEYS_REMAP = {
+        "txt_in": remap_txt_in_,
+        "img_attn_qkv": remap_img_attn_qkv_,
+        "txt_attn_qkv": remap_txt_attn_qkv_,
+        "single_blocks": remap_single_transformer_blocks_,
+        "final_layer.adaLN_modulation.1": remap_norm_scale_shift_,
+    }
+
+    # Some folks attempt to make their state dict compatible with diffusers by adding "transformer." prefix to all keys
+    # and use their custom code. To make sure both "original" and "attempted diffusers" loras work as expected, we make
+    # sure that both follow the same initial format by stripping off the "transformer." prefix.
+    for key in list(converted_state_dict.keys()):
+        if key.startswith("transformer."):
+            converted_state_dict[key[len("transformer.") :]] = converted_state_dict.pop(key)
+        if key.startswith("diffusion_model."):
+            converted_state_dict[key[len("diffusion_model.") :]] = converted_state_dict.pop(key)
+
+    # Rename and remap the state dict keys
+    for key in list(converted_state_dict.keys()):
+        new_key = key[:]
+        for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
+            new_key = new_key.replace(replace_key, rename_key)
+        converted_state_dict[new_key] = converted_state_dict.pop(key)
+
+    for key in list(converted_state_dict.keys()):
+        for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
+            if special_key not in key:
+                continue
+            handler_fn_inplace(key, converted_state_dict)
+
+    # Add back the "transformer." prefix
+    for key in list(converted_state_dict.keys()):
+        converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
+
+    return converted_state_dict
```
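To see what the converter produces, here is a minimal sketch (not part of the commit) that runs two representative fused keys through it. It assumes a diffusers checkout containing this commit; the rank and the random tensors are stand-ins. Note that `lora_A` maps the layer input to the rank dimension and is shared across the fused projections, so it is duplicated rather than split, while `lora_B` carries the fused output dimension and is chunked.

```python
import torch

# Private helper added by this commit; importable from a checkout that includes it.
from diffusers.loaders.lora_conversion_utils import _convert_hunyuan_video_lora_to_diffusers

rank, hidden_size = 4, 3072  # rank is arbitrary here; 3072 is the hidden size hardcoded above
state_dict = {
    # Double-block fused QKV: lora_B (3 * 3072, rank) is chunked into to_q/to_k/to_v.
    "double_blocks.0.img_attn_qkv.lora_A.weight": torch.randn(rank, hidden_size),
    "double_blocks.0.img_attn_qkv.lora_B.weight": torch.randn(3 * hidden_size, rank),
    # Single-block linear1 fuses QKV with the MLP input projection; we assume a
    # 4 * hidden_size MLP part here — the converter simply assigns whatever remains
    # after the three 3072-wide chunks to proj_mlp.
    "single_blocks.0.linear1.lora_A.weight": torch.randn(rank, hidden_size),
    "single_blocks.0.linear1.lora_B.weight": torch.randn(3 * hidden_size + 4 * hidden_size, rank),
}

converted = _convert_hunyuan_video_lora_to_diffusers(state_dict)
for key in sorted(converted):
    print(key, tuple(converted[key].shape))
# Produces per-projection keys such as
#   transformer.transformer_blocks.0.attn.to_q.lora_B.weight        (3072, 4)
#   transformer.single_transformer_blocks.0.proj_mlp.lora_B.weight  (12288, 4)
```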

src/diffusers/loaders/lora_pipeline.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -36,6 +36,7 @@
 from .lora_base import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, LoraBaseMixin, _fetch_state_dict  # noqa
 from .lora_conversion_utils import (
     _convert_bfl_flux_control_lora_to_diffusers,
+    _convert_hunyuan_video_lora_to_diffusers,
     _convert_kohya_flux_lora_to_diffusers,
     _convert_non_diffusers_lora_to_diffusers,
     _convert_xlabs_flux_lora_to_diffusers,
```
```diff
@@ -4027,6 +4028,10 @@ def lora_state_dict(
             logger.warning(warn_msg)
             state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
 
+        is_original_hunyuan_video = any("img_attn_qkv" in k for k in state_dict)
+        if is_original_hunyuan_video:
+            state_dict = _convert_hunyuan_video_lora_to_diffusers(state_dict)
+
         return state_dict
 
     # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
```
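End to end, this means an original-format HunyuanVideo LoRA can be loaded directly through the pipeline's LoRA loader, which goes through `lora_state_dict` above. A minimal sketch, assuming `HunyuanVideoPipeline` wires into this mixin; the model and LoRA repo ids and the weight file name are placeholders, not verified artifacts:

```python
import torch

from diffusers import HunyuanVideoPipeline

pipe = HunyuanVideoPipeline.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo",  # placeholder repo id
    torch_dtype=torch.bfloat16,
)

# A checkpoint with original-format keys (e.g. "double_blocks.0.img_attn_qkv.lora_A.weight")
# matches the "img_attn_qkv" check in lora_state_dict and is converted on the fly.
pipe.load_lora_weights("<user>/<hunyuan-video-lora>", weight_name="pytorch_lora_weights.safetensors")
```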
