@@ -588,18 +588,23 @@ def _convert(original_key, diffusers_key, state_dict, new_state_dict):
             new_state_dict[diffusers_down_key.replace(".lora_A.", ".lora_B.")] = up_weight

         all_unique_keys = {
-            k.replace(".lora_down.weight", "").replace(".lora_up.weight", "").replace(".alpha", "") for k in state_dict
+            k.replace(".lora_down.weight", "").replace(".lora_up.weight", "").replace(".alpha", "")
+            for k in state_dict
+            if not k.startswith("lora_unet_")
         }
-        all_unique_keys = sorted(all_unique_keys)
-        assert all("lora_transformer_" in k for k in all_unique_keys), f"{all_unique_keys=}"
+        assert all(k.startswith(("lora_transformer_", "lora_te1_")) for k in all_unique_keys), f"{all_unique_keys=}"

+        has_te_keys = False
         for k in all_unique_keys:
             if k.startswith("lora_transformer_single_transformer_blocks_"):
                 i = int(k.split("lora_transformer_single_transformer_blocks_")[-1].split("_")[0])
                 diffusers_key = f"single_transformer_blocks.{i}"
             elif k.startswith("lora_transformer_transformer_blocks_"):
                 i = int(k.split("lora_transformer_transformer_blocks_")[-1].split("_")[0])
                 diffusers_key = f"transformer_blocks.{i}"
+            elif k.startswith("lora_te1_"):
+                has_te_keys = True
+                continue
             else:
                 raise NotImplementedError

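The hunk above widens the unique-key filter so `lora_unet_*` entries are excluded up front, and the assertion now also tolerates `lora_te1_*` (text encoder) keys. A minimal standalone sketch, with hypothetical key names, of what the new comprehension yields:

    # Hypothetical keys, for illustration only.
    state_dict = {
        "lora_transformer_transformer_blocks_0_attn_to_q.lora_down.weight": 0,
        "lora_te1_text_model_encoder_layers_0_self_attn_q_proj.lora_up.weight": 0,
        "lora_unet_down_blocks_0.alpha": 0,  # dropped by the new filter
    }
    all_unique_keys = {
        k.replace(".lora_down.weight", "").replace(".lora_up.weight", "").replace(".alpha", "")
        for k in state_dict
        if not k.startswith("lora_unet_")
    }
    assert all(k.startswith(("lora_transformer_", "lora_te1_")) for k in all_unique_keys)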
@@ -615,17 +620,58 @@ def _convert(original_key, diffusers_key, state_dict, new_state_dict):
                     remaining = k.split("attn_")[-1]
                     diffusers_key += f".attn.{remaining}"

-            if diffusers_key == f"transformer_blocks.{i}":
-                print(k, diffusers_key)
             _convert(k, diffusers_key, state_dict, new_state_dict)

+        if has_te_keys:
+            layer_pattern = re.compile(r"lora_te1_text_model_encoder_layers_(\d+)")
+            attn_mapping = {
+                "q_proj": ".self_attn.q_proj",
+                "k_proj": ".self_attn.k_proj",
+                "v_proj": ".self_attn.v_proj",
+                "out_proj": ".self_attn.out_proj",
+            }
+            mlp_mapping = {"fc1": ".mlp.fc1", "fc2": ".mlp.fc2"}
+            for k in all_unique_keys:
+                if not k.startswith("lora_te1_"):
+                    continue
+
+                match = layer_pattern.search(k)
+                if not match:
+                    continue
+                i = int(match.group(1))
+                diffusers_key = f"text_model.encoder.layers.{i}"
+
+                if "attn" in k:
+                    for key_fragment, suffix in attn_mapping.items():
+                        if key_fragment in k:
+                            diffusers_key += suffix
+                            break
+                elif "mlp" in k:
+                    for key_fragment, suffix in mlp_mapping.items():
+                        if key_fragment in k:
+                            diffusers_key += suffix
+                            break
+
+                _convert(k, diffusers_key, state_dict, new_state_dict)
+
+        remaining_all_unet = False
+        if state_dict:
+            remaining_all_unet = all(k.startswith("lora_unet_") for k in state_dict)
+        if remaining_all_unet:
+            keys = list(state_dict.keys())
+            for k in keys:
+                state_dict.pop(k)
+
         if len(state_dict) > 0:
             raise ValueError(
                 f"Expected an empty state dict at this point but it has these keys which couldn't be parsed: {list(state_dict.keys())}."
             )

-        new_state_dict = {f"transformer.{k}": v for k, v in new_state_dict.items()}
-        return new_state_dict
+        transformer_state_dict = {
+            f"transformer.{k}": v for k, v in new_state_dict.items() if not k.startswith("text_model.")
+        }
+        te_state_dict = {f"text_encoder.{k}": v for k, v in new_state_dict.items() if k.startswith("text_model.")}
+        return {**transformer_state_dict, **te_state_dict}

     # This is weird.
     # https://huggingface.co/sayakpaul/different-lora-from-civitai/tree/main?show_file_info=sharp_detailed_foot.safetensors
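The new `lora_te1_` branch maps sd-scripts text-encoder keys onto CLIP-style module paths via the layer regex plus the attn/mlp suffix tables. A quick sketch, on a hypothetical key, of how one key resolves:

    import re

    # Hypothetical te1 key, resolved with the same regex and suffix table as above.
    attn_mapping = {
        "q_proj": ".self_attn.q_proj",
        "k_proj": ".self_attn.k_proj",
        "v_proj": ".self_attn.v_proj",
        "out_proj": ".self_attn.out_proj",
    }
    k = "lora_te1_text_model_encoder_layers_7_self_attn_q_proj"
    i = int(re.search(r"lora_te1_text_model_encoder_layers_(\d+)", k).group(1))
    diffusers_key = f"text_model.encoder.layers.{i}"
    for key_fragment, suffix in attn_mapping.items():
        if key_fragment in k:
            diffusers_key += suffix
            break
    assert diffusers_key == "text_model.encoder.layers.7.self_attn.q_proj"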
@@ -640,6 +685,7 @@ def _convert(original_key, diffusers_key, state_dict, new_state_dict):
     )
     if has_mixture:
         return _convert_mixture_state_dict_to_diffusers(state_dict)
+
     return _convert_sd_scripts_to_ai_toolkit(state_dict)

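With the mixture path now emitting both prefixes, a hypothetical end-to-end sketch (the file name comes from the repo linked above; loading via safetensors and calling the module's converter directly are assumptions for illustration):

    from safetensors.torch import load_file

    # Assumed usage: convert a mixture-style LoRA checkpoint; the converter
    # now returns transformer.* and text_encoder.* keys in one dict.
    state_dict = load_file("sharp_detailed_foot.safetensors")
    converted = _convert_mixture_state_dict_to_diffusers(state_dict)
    transformer_keys = [k for k in converted if k.startswith("transformer.")]
    text_encoder_keys = [k for k in converted if k.startswith("text_encoder.")]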
645691