@@ -973,3 +973,178 @@ def swap_scale_shift(weight):
         converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)

     return converted_state_dict
+
+
+def _convert_hunyuan_video_lora_to_diffusers(original_state_dict):
+    converted_state_dict = {k: original_state_dict.pop(k) for k in list(original_state_dict.keys())}
+
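+    # The original checkpoint stores the final adaLN modulation as [shift, scale], while the
+    # diffusers norm_out.linear layer expects [scale, shift]; swap the two halves when remapping.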
+    def remap_norm_scale_shift_(key, state_dict):
+        weight = state_dict.pop(key)
+        shift, scale = weight.chunk(2, dim=0)
+        new_weight = torch.cat([scale, shift], dim=0)
+        state_dict[key.replace("final_layer.adaLN_modulation.1", "norm_out.linear")] = new_weight
+
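+    # txt_in is the token-refiner branch: rename its blocks/embedders to the diffusers layout and
+    # split the fused self_attn_qkv projection into separate attn.to_q / to_k / to_v entries.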
+    def remap_txt_in_(key, state_dict):
+        def rename_key(key):
+            new_key = key.replace("individual_token_refiner.blocks", "token_refiner.refiner_blocks")
+            new_key = new_key.replace("adaLN_modulation.1", "norm_out.linear")
+            new_key = new_key.replace("txt_in", "context_embedder")
+            new_key = new_key.replace("t_embedder.mlp.0", "time_text_embed.timestep_embedder.linear_1")
+            new_key = new_key.replace("t_embedder.mlp.2", "time_text_embed.timestep_embedder.linear_2")
+            new_key = new_key.replace("c_embedder", "time_text_embed.text_embedder")
+            new_key = new_key.replace("mlp", "ff")
+            return new_key
+
+        if "self_attn_qkv" in key:
+            weight = state_dict.pop(key)
+            to_q, to_k, to_v = weight.chunk(3, dim=0)
+            state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_q"))] = to_q
+            state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_k"))] = to_k
+            state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_v"))] = to_v
+        else:
+            state_dict[rename_key(key)] = state_dict.pop(key)
+
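+    # Double-stream blocks fuse q/k/v into a single img_attn_qkv / txt_attn_qkv projection. The
+    # LoRA down matrix (lora_A) maps into the shared rank space, so it is copied to each of
+    # to_q / to_k / to_v; the up matrix (lora_B) is chunked along dim 0 into per-projection parts.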
+    def remap_img_attn_qkv_(key, state_dict):
+        weight = state_dict.pop(key)
+        if "lora_A" in key:
+            state_dict[key.replace("img_attn_qkv", "attn.to_q")] = weight
+            state_dict[key.replace("img_attn_qkv", "attn.to_k")] = weight
+            state_dict[key.replace("img_attn_qkv", "attn.to_v")] = weight
+        else:
+            to_q, to_k, to_v = weight.chunk(3, dim=0)
+            state_dict[key.replace("img_attn_qkv", "attn.to_q")] = to_q
+            state_dict[key.replace("img_attn_qkv", "attn.to_k")] = to_k
+            state_dict[key.replace("img_attn_qkv", "attn.to_v")] = to_v
+
+    def remap_txt_attn_qkv_(key, state_dict):
+        weight = state_dict.pop(key)
+        if "lora_A" in key:
+            state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = weight
+            state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = weight
+            state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = weight
+        else:
+            to_q, to_k, to_v = weight.chunk(3, dim=0)
+            state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = to_q
+            state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = to_k
+            state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = to_v
+
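+    # Single-stream blocks fuse attention q/k/v and the MLP input projection into one linear1
+    # (3 * hidden_size rows for q/k/v plus the MLP rows). The lora_B outputs are split accordingly,
+    # while the shared lora_A inputs are duplicated onto attn.to_q / to_k / to_v and proj_mlp.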
+    def remap_single_transformer_blocks_(key, state_dict):
+        hidden_size = 3072
+
+        if "linear1.lora_A.weight" in key or "linear1.lora_B.weight" in key:
+            linear1_weight = state_dict.pop(key)
+            if "lora_A" in key:
+                new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(
+                    ".linear1.lora_A.weight"
+                )
+                state_dict[f"{new_key}.attn.to_q.lora_A.weight"] = linear1_weight
+                state_dict[f"{new_key}.attn.to_k.lora_A.weight"] = linear1_weight
+                state_dict[f"{new_key}.attn.to_v.lora_A.weight"] = linear1_weight
+                state_dict[f"{new_key}.proj_mlp.lora_A.weight"] = linear1_weight
+            else:
+                split_size = (hidden_size, hidden_size, hidden_size, linear1_weight.size(0) - 3 * hidden_size)
+                q, k, v, mlp = torch.split(linear1_weight, split_size, dim=0)
+                new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(
+                    ".linear1.lora_B.weight"
+                )
+                state_dict[f"{new_key}.attn.to_q.lora_B.weight"] = q
+                state_dict[f"{new_key}.attn.to_k.lora_B.weight"] = k
+                state_dict[f"{new_key}.attn.to_v.lora_B.weight"] = v
+                state_dict[f"{new_key}.proj_mlp.lora_B.weight"] = mlp
+
+        elif "linear1.lora_A.bias" in key or "linear1.lora_B.bias" in key:
+            linear1_bias = state_dict.pop(key)
+            if "lora_A" in key:
+                new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(
+                    ".linear1.lora_A.bias"
+                )
+                state_dict[f"{new_key}.attn.to_q.lora_A.bias"] = linear1_bias
+                state_dict[f"{new_key}.attn.to_k.lora_A.bias"] = linear1_bias
+                state_dict[f"{new_key}.attn.to_v.lora_A.bias"] = linear1_bias
+                state_dict[f"{new_key}.proj_mlp.lora_A.bias"] = linear1_bias
+            else:
+                split_size = (hidden_size, hidden_size, hidden_size, linear1_bias.size(0) - 3 * hidden_size)
+                q_bias, k_bias, v_bias, mlp_bias = torch.split(linear1_bias, split_size, dim=0)
+                new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(
+                    ".linear1.lora_B.bias"
+                )
+                state_dict[f"{new_key}.attn.to_q.lora_B.bias"] = q_bias
+                state_dict[f"{new_key}.attn.to_k.lora_B.bias"] = k_bias
+                state_dict[f"{new_key}.attn.to_v.lora_B.bias"] = v_bias
+                state_dict[f"{new_key}.proj_mlp.lora_B.bias"] = mlp_bias
+
+        else:
+            new_key = key.replace("single_blocks", "single_transformer_blocks")
+            new_key = new_key.replace("linear2", "proj_out")
+            new_key = new_key.replace("q_norm", "attn.norm_q")
+            new_key = new_key.replace("k_norm", "attn.norm_k")
+            state_dict[new_key] = state_dict.pop(key)
+
+    TRANSFORMER_KEYS_RENAME_DICT = {
+        "img_in": "x_embedder",
+        "time_in.mlp.0": "time_text_embed.timestep_embedder.linear_1",
+        "time_in.mlp.2": "time_text_embed.timestep_embedder.linear_2",
+        "guidance_in.mlp.0": "time_text_embed.guidance_embedder.linear_1",
+        "guidance_in.mlp.2": "time_text_embed.guidance_embedder.linear_2",
+        "vector_in.in_layer": "time_text_embed.text_embedder.linear_1",
+        "vector_in.out_layer": "time_text_embed.text_embedder.linear_2",
+        "double_blocks": "transformer_blocks",
+        "img_attn_q_norm": "attn.norm_q",
+        "img_attn_k_norm": "attn.norm_k",
+        "img_attn_proj": "attn.to_out.0",
+        "txt_attn_q_norm": "attn.norm_added_q",
+        "txt_attn_k_norm": "attn.norm_added_k",
+        "txt_attn_proj": "attn.to_add_out",
+        "img_mod.linear": "norm1.linear",
+        "img_norm1": "norm1.norm",
+        "img_norm2": "norm2",
+        "img_mlp": "ff",
+        "txt_mod.linear": "norm1_context.linear",
+        "txt_norm1": "norm1.norm",
+        "txt_norm2": "norm2_context",
+        "txt_mlp": "ff_context",
+        "self_attn_proj": "attn.to_out.0",
+        "modulation.linear": "norm.linear",
+        "pre_norm": "norm.norm",
+        "final_layer.norm_final": "norm_out.norm",
+        "final_layer.linear": "proj_out",
+        "fc1": "net.0.proj",
+        "fc2": "net.2",
+        "input_embedder": "proj_in",
+    }
+
+    TRANSFORMER_SPECIAL_KEYS_REMAP = {
+        "txt_in": remap_txt_in_,
+        "img_attn_qkv": remap_img_attn_qkv_,
+        "txt_attn_qkv": remap_txt_attn_qkv_,
+        "single_blocks": remap_single_transformer_blocks_,
+        "final_layer.adaLN_modulation.1": remap_norm_scale_shift_,
+    }
+
+    # Some people try to make their state dict compatible with diffusers by adding a "transformer." prefix to all keys
+    # and using their own custom code. To make sure both "original" and "attempted diffusers" LoRAs work as expected,
+    # we normalize both to the same initial format by stripping off the "transformer." prefix.
+    for key in list(converted_state_dict.keys()):
+        if key.startswith("transformer."):
+            converted_state_dict[key[len("transformer.") :]] = converted_state_dict.pop(key)
+        if key.startswith("diffusion_model."):
+            converted_state_dict[key[len("diffusion_model.") :]] = converted_state_dict.pop(key)
+
+    # Rename and remap the state dict keys
+    for key in list(converted_state_dict.keys()):
+        new_key = key[:]
+        for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
+            new_key = new_key.replace(replace_key, rename_key)
+        converted_state_dict[new_key] = converted_state_dict.pop(key)
+
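+    # Apply the special handlers for fused/packed projections that a plain string rename cannot express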
+    for key in list(converted_state_dict.keys()):
+        for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
+            if special_key not in key:
+                continue
+            handler_fn_inplace(key, converted_state_dict)
+
+    # Add back the "transformer." prefix
+    for key in list(converted_state_dict.keys()):
+        converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
+
+    return converted_state_dict