@@ -588,18 +588,23 @@ def _convert(original_key, diffusers_key, state_dict, new_state_dict):
         new_state_dict[diffusers_down_key.replace(".lora_A.", ".lora_B.")] = up_weight
 
     all_unique_keys = {
-        k.replace(".lora_down.weight", "").replace(".lora_up.weight", "").replace(".alpha", "") for k in state_dict
+        k.replace(".lora_down.weight", "").replace(".lora_up.weight", "").replace(".alpha", "")
+        for k in state_dict
+        if not k.startswith(("lora_unet_"))
     }
-    all_unique_keys = sorted(all_unique_keys)
-    assert all("lora_transformer_" in k for k in all_unique_keys), f"{all_unique_keys=}"
+    assert all(k.startswith(("lora_transformer_", "lora_te1_")) for k in all_unique_keys), f"{all_unique_keys=}"
 
+    has_te_keys = False
     for k in all_unique_keys:
         if k.startswith("lora_transformer_single_transformer_blocks_"):
             i = int(k.split("lora_transformer_single_transformer_blocks_")[-1].split("_")[0])
             diffusers_key = f"single_transformer_blocks.{i}"
         elif k.startswith("lora_transformer_transformer_blocks_"):
             i = int(k.split("lora_transformer_transformer_blocks_")[-1].split("_")[0])
             diffusers_key = f"transformer_blocks.{i}"
+        elif k.startswith("lora_te1_"):
+            has_te_keys = True
+            continue
         else:
             raise NotImplementedError
 
@@ -615,17 +620,57 @@ def _convert(original_key, diffusers_key, state_dict, new_state_dict):
             remaining = k.split("attn_")[-1]
             diffusers_key += f".attn.{remaining}"
 
-        if diffusers_key == f"transformer_blocks.{i}":
-            print(k, diffusers_key)
         _convert(k, diffusers_key, state_dict, new_state_dict)
 
+    if has_te_keys:
+        layer_pattern = re.compile(r"lora_te1_text_model_encoder_layers_(\d+)")
+        attn_mapping = {
+            "q_proj": ".self_attn.q_proj",
+            "k_proj": ".self_attn.k_proj",
+            "v_proj": ".self_attn.v_proj",
+            "out_proj": ".self_attn.out_proj",
+        }
+        mlp_mapping = {"fc1": ".mlp.fc1", "fc2": ".mlp.fc2"}
+        for k in all_unique_keys:
+            if not k.startswith("lora_te1_"):
+                continue
+
+            match = layer_pattern.search(k)
+            if not match:
+                continue
+            i = int(match.group(1))
+            diffusers_key = f"text_model.encoder.layers.{i}"
+
+            if "attn" in k:
+                for key_fragment, suffix in attn_mapping.items():
+                    if key_fragment in k:
+                        diffusers_key += suffix
+                        break
+            elif "mlp" in k:
+                for key_fragment, suffix in mlp_mapping.items():
+                    if key_fragment in k:
+                        diffusers_key += suffix
+                        break
+
+            _convert(k, diffusers_key, state_dict, new_state_dict)
+
+    if state_dict:
+        remaining_all_unet = all(k.startswith("lora_unet_") for k in state_dict)
+        if remaining_all_unet:
+            keys = list(state_dict.keys())
+            for k in keys:
+                state_dict.pop(k)
+
     if len(state_dict) > 0:
         raise ValueError(
             f"Expected an empty state dict at this point but its has these keys which couldn't be parsed: {list(state_dict.keys())}."
         )
 
-    new_state_dict = {f"transformer.{k}": v for k, v in new_state_dict.items()}
-    return new_state_dict
+    transformer_state_dict = {
+        f"transformer.{k}": v for k, v in new_state_dict.items() if not k.startswith("text_model.")
+    }
+    te_state_dict = {f"text_encoder.{k}": v for k, v in new_state_dict.items() if k.startswith("text_model.")}
+    return {**transformer_state_dict, **te_state_dict}
 
     # This is weird.
     # https://huggingface.co/sayakpaul/different-lora-from-civitai/tree/main?show_file_info=sharp_detailed_foot.safetensors
@@ -640,6 +685,7 @@ def _convert(original_key, diffusers_key, state_dict, new_state_dict):
     )
     if has_mixture:
         return _convert_mixture_state_dict_to_diffusers(state_dict)
+
     return _convert_sd_scripts_to_ai_toolkit(state_dict)
 
 
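For context on the lora_te1_* branch added above: it turns kohya-style CLIP text-encoder keys into diffusers module paths, which _convert then completes (renaming .lora_down./.lora_up. to .lora_A./.lora_B., per the context line at the top of the first hunk) before the final return prefixes them with "text_encoder.". The snippet below is a minimal standalone sketch of just that path mapping under those assumptions; map_te1_module_path is a hypothetical helper written for illustration, not code from the diff.

import re

# Same layer pattern and suffix tables as in the diff above.
LAYER_PATTERN = re.compile(r"lora_te1_text_model_encoder_layers_(\d+)")
ATTN_MAPPING = {
    "q_proj": ".self_attn.q_proj",
    "k_proj": ".self_attn.k_proj",
    "v_proj": ".self_attn.v_proj",
    "out_proj": ".self_attn.out_proj",
}
MLP_MAPPING = {"fc1": ".mlp.fc1", "fc2": ".mlp.fc2"}


def map_te1_module_path(original_key: str) -> str:
    """Hypothetical helper mirroring the lora_te1_* mapping in the diff."""
    match = LAYER_PATTERN.search(original_key)
    if not match:
        raise ValueError(f"unrecognized text-encoder key: {original_key}")
    diffusers_key = f"text_model.encoder.layers.{match.group(1)}"
    if "attn" in original_key:
        for fragment, suffix in ATTN_MAPPING.items():
            if fragment in original_key:
                diffusers_key += suffix
                break
    elif "mlp" in original_key:
        for fragment, suffix in MLP_MAPPING.items():
            if fragment in original_key:
                diffusers_key += suffix
                break
    return diffusers_key


# e.g. "lora_te1_text_model_encoder_layers_0_self_attn_q_proj.lora_down.weight"
# -> "text_model.encoder.layers.0.self_attn.q_proj"; in the diff, _convert and the
# final return then presumably produce keys of the form
# "text_encoder.text_model.encoder.layers.0.self_attn.q_proj.lora_A.weight".
print(map_te1_module_path("lora_te1_text_model_encoder_layers_0_self_attn_q_proj.lora_down.weight"))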