@@ -558,13 +558,88 @@ def _convert_sd_scripts_to_ai_toolkit(sds_sd):
558558        new_state_dict  =  {** ait_sd , ** te_state_dict }
559559        return  new_state_dict 
560560
def _convert_mixture_state_dict_to_diffusers(state_dict):
    """Convert a ``lora_transformer_*`` LoRA state dict to ``transformer.*`` keys.

    Each LoRA module in ``state_dict`` must contribute three entries:
    ``<module>.lora_down.weight``, ``<module>.lora_up.weight`` and
    ``<module>.alpha``. ``<module>`` has to start with either
    ``lora_transformer_single_transformer_blocks_<i>`` or
    ``lora_transformer_transformer_blocks_<i>``.

    The ``alpha / rank`` scale is folded into the weights, so the output only
    holds ``...lora_A.weight`` (scaled down-projection) and
    ``...lora_B.weight`` (scaled up-projection) entries.

    Args:
        state_dict: Mapping from key to weight. Entries are popped as they are
            consumed, so the caller's dict is emptied by a successful call.

    Returns:
        A new dict whose keys are prefixed with ``transformer.``.

    Raises:
        AssertionError: If a module name does not contain ``lora_transformer_``.
        NotImplementedError: If a module name starts with neither supported
            block prefix.
        KeyError: If one of the three entries of a module is missing.
        ValueError: If entries are left over after every module was converted
            (or if the block index in a key is not an integer).
    """
    # Local import so the block does not depend on the file's import section.
    import warnings

    new_state_dict = {}

    def _convert(original_key, diffusers_key, state_dict, new_state_dict):
        # Move one module's down/up/alpha triple into `new_state_dict`.
        down_weight = state_dict.pop(f"{original_key}.lora_down.weight")
        lora_rank = down_weight.shape[0]
        up_weight = state_dict.pop(f"{original_key}.lora_up.weight")
        alpha = state_dict.pop(f"{original_key}.alpha")

        # Fold alpha / rank into the weights. The loop keeps
        # scale_down * scale_up == scale while moving factors of two from the
        # down side to the up side, so neither matrix takes the whole scale.
        scale = alpha / lora_rank
        scale_down = scale
        scale_up = 1.0
        while scale_down * 2 < scale_up:
            scale_down *= 2
            scale_up /= 2

        diffusers_down_key = f"{diffusers_key}.lora_A.weight"
        new_state_dict[diffusers_down_key] = down_weight * scale_down
        new_state_dict[diffusers_down_key.replace(".lora_A.", ".lora_B.")] = up_weight * scale_up

    single_prefix = "lora_transformer_single_transformer_blocks_"
    double_prefix = "lora_transformer_transformer_blocks_"
    attn_projections = ("to_q", "to_k", "to_v", "add_q_proj", "add_k_proj", "add_v_proj")

    # Collapse the three entries of each module into one module name.
    all_unique_keys = sorted(
        {
            k.replace(".lora_down.weight", "").replace(".lora_up.weight", "").replace(".alpha", "")
            for k in state_dict
        }
    )
    assert all("lora_transformer_" in k for k in all_unique_keys), f"{all_unique_keys=}"

    for k in all_unique_keys:
        if k.startswith(single_prefix):
            i = int(k.split(single_prefix)[-1].split("_")[0])
            block_key = f"single_transformer_blocks.{i}"
        elif k.startswith(double_prefix):
            i = int(k.split(double_prefix)[-1].split("_")[0])
            block_key = f"transformer_blocks.{i}"
        else:
            raise NotImplementedError

        # Order matters: the output projections are checked before q/k/v.
        suffix = ""
        if "attn_" in k:
            if "_to_out_0" in k:
                suffix = ".attn.to_out.0"
            elif "_to_add_out" in k:
                suffix = ".attn.to_add_out"
            elif any(name in k for name in attn_projections):
                remaining = k.split("attn_")[-1]
                suffix = f".attn.{remaining}"

        if not suffix:
            # No known attention module matched; the weights end up under the
            # bare block key. Surface this instead of printing to stdout.
            warnings.warn(
                f"Could not map LoRA key '{k}' to a known module; using '{block_key}' as is.",
                stacklevel=2,
            )

        _convert(k, block_key + suffix, state_dict, new_state_dict)

    if state_dict:
        raise ValueError(
            f"Expected an empty state dict at this point but its has these keys which couldn't be parsed: {list(state_dict.keys())}."
        )

    return {f"transformer.{k}": v for k, v in new_state_dict.items()}
629+ 
561630    # This is  weird. 
562631    # https://huggingface.co/sayakpaul/different-lora-from-civitai/tree/main?show_file_info=sharp_detailed_foot.safetensors 
563632    # has both `peft` and non-peft state dict. 
564633    has_peft_state_dict  =  any (k .startswith ("transformer." ) for  k  in  state_dict )
565634    if  has_peft_state_dict :
566635        state_dict  =  {k : v  for  k , v  in  state_dict .items () if  k .startswith ("transformer." )}
567636        return  state_dict 
637+     # Another weird one. 
638+     has_mixture  =  any (
639+         k .startswith ("lora_transformer_" ) and  ("lora_down"  in  k  or  "lora_up"  in  k  or  "alpha"  in  k ) for  k  in  state_dict 
640+     )
641+     if  has_mixture :
642+         return  _convert_mixture_state_dict_to_diffusers (state_dict )
568643    return  _convert_sd_scripts_to_ai_toolkit (state_dict )
569644
570645
0 commit comments