@@ -663,3 +663,248 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
        raise ValueError(f"`old_state_dict` should be empty at this point but has: {list(old_state_dict.keys())}.")

    return new_state_dict


def _convert_non_diffusers_sd3_lora_to_diffusers(state_dict, prefix=None):
    new_state_dict = {}

    # In the original SD3 implementation of AdaLayerNormContinuous, the linear projection output is split into
    # (shift, scale), while diffusers splits it into (scale, shift). Swap the rows of the projection weight so the
    # checkpoint can be used with the diffusers implementation.
    def swap_scale_shift(weight):
        shift, scale = weight.chunk(2, dim=0)
        new_weight = torch.cat([scale, shift], dim=0)
        return new_weight

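    # A minimal sketch of what `swap_scale_shift` above does (hypothetical toy input, single column):
    #   weight = torch.arange(4.0).unsqueeze(1)      # rows [0, 1] hold shift, rows [2, 3] hold scale
    #   swap_scale_shift(weight).squeeze(1)          # tensor([2., 3., 0., 1.]) -> (scale, shift) ordering
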
    def calculate_scales(key):
        lora_rank = state_dict[f"{key}.lora_down.weight"].shape[0]
        alpha = state_dict.pop(key + ".alpha")
        scale = alpha / lora_rank

        # calculate scale_down and scale_up
        scale_down = scale
        scale_up = 1.0
        while scale_down * 2 < scale_up:
            scale_down *= 2
            scale_up /= 2

        return scale_down, scale_up

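    # Worked example for `calculate_scales` above (hypothetical values): alpha=4 and rank=16 give scale=0.25; the loop
    # rebalances this into scale_down=0.5 and scale_up=0.5, so scale_down * scale_up still equals alpha / rank while
    # neither multiplier drifts far from 1 once it is folded into the down/up weights.
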
    def weight_is_sparse(key, rank, num_splits, up_weight):
        dims = [up_weight.shape[0] // num_splits] * num_splits

        is_sparse = False
        requested_rank = rank
        if rank % num_splits == 0:
            requested_rank = rank // num_splits
            is_sparse = True
            i = 0
            for j in range(len(dims)):
                for k in range(len(dims)):
                    if j == k:
                        continue
                    is_sparse = is_sparse and torch.all(
                        up_weight[i : i + dims[j], k * requested_rank : (k + 1) * requested_rank] == 0
                    )
                i += dims[j]
            if is_sparse:
                logger.info(f"weight is sparse: {key}")

        return is_sparse, requested_rank

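    # `weight_is_sparse` above treats a fused qkv lora_up weight as sparse when it is block diagonal, e.g. for
    # num_splits=3 (illustrative layout):
    #
    #   [ B_q   0     0  ]
    #   [  0   B_k    0  ]   the rank divides evenly by 3 and every off-diagonal block is zero, so each projection
    #   [  0    0    B_v ]   can keep just its own (dims[j] x requested_rank) diagonal block.
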
    # handle only transformer blocks for now.
    layers = set()
    for k in state_dict:
        if "joint_blocks" in k:
            idx = int(k.split("_", 4)[-1].split("_", 1)[0])
            layers.add(idx)
    num_layers = max(layers) + 1

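    # For example, a key such as "lora_unet_joint_blocks_11_x_block_attn_qkv.lora_down.weight" (suffix illustrative)
    # splits to "11_x_block_attn_qkv.lora_down.weight" and then to "11", so `layers` collects every block index.
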
    for i in range(num_layers):
        # norms
        for diffusers_key, orig_key in [
            (f"transformer_blocks.{i}.norm1.linear", f"lora_unet_joint_blocks_{i}_x_block_adaLN_modulation_1")
        ]:
            scale_down, scale_up = calculate_scales(orig_key)
            new_state_dict[f"{diffusers_key}.lora_A.weight"] = (
                state_dict.pop(f"{orig_key}.lora_down.weight") * scale_down
            )
            new_state_dict[f"{diffusers_key}.lora_B.weight"] = state_dict.pop(f"{orig_key}.lora_up.weight") * scale_up

        if not (i == num_layers - 1):
            for diffusers_key, orig_key in [
                (
                    f"transformer_blocks.{i}.norm1_context.linear",
                    f"lora_unet_joint_blocks_{i}_context_block_adaLN_modulation_1",
                )
            ]:
                scale_down, scale_up = calculate_scales(orig_key)
                new_state_dict[f"{diffusers_key}.lora_A.weight"] = (
                    state_dict.pop(f"{orig_key}.lora_down.weight") * scale_down
                )
                new_state_dict[f"{diffusers_key}.lora_B.weight"] = (
                    state_dict.pop(f"{orig_key}.lora_up.weight") * scale_up
                )
        else:
            for diffusers_key, orig_key in [
                (
                    f"transformer_blocks.{i}.norm1_context.linear",
                    f"lora_unet_joint_blocks_{i}_context_block_adaLN_modulation_1",
                )
            ]:
                scale_down, scale_up = calculate_scales(orig_key)
                new_state_dict[f"{diffusers_key}.lora_A.weight"] = (
                    swap_scale_shift(state_dict.pop(f"{orig_key}.lora_down.weight")) * scale_down
                )
                new_state_dict[f"{diffusers_key}.lora_B.weight"] = (
                    swap_scale_shift(state_dict.pop(f"{orig_key}.lora_up.weight")) * scale_up
                )

        # output projections
        for diffusers_key, orig_key in [
            (f"transformer_blocks.{i}.attn.to_out.0", f"lora_unet_joint_blocks_{i}_x_block_attn_proj")
        ]:
            scale_down, scale_up = calculate_scales(orig_key)
            new_state_dict[f"{diffusers_key}.lora_A.weight"] = (
                state_dict.pop(f"{orig_key}.lora_down.weight") * scale_down
            )
            new_state_dict[f"{diffusers_key}.lora_B.weight"] = state_dict.pop(f"{orig_key}.lora_up.weight") * scale_up
        if not (i == num_layers - 1):
            for diffusers_key, orig_key in [
                (f"transformer_blocks.{i}.attn.to_add_out", f"lora_unet_joint_blocks_{i}_context_block_attn_proj")
            ]:
                scale_down, scale_up = calculate_scales(orig_key)
                new_state_dict[f"{diffusers_key}.lora_A.weight"] = (
                    state_dict.pop(f"{orig_key}.lora_down.weight") * scale_down
                )
                new_state_dict[f"{diffusers_key}.lora_B.weight"] = (
                    state_dict.pop(f"{orig_key}.lora_up.weight") * scale_up
                )

        # feed-forward layers
        for diffusers_key, orig_key in [
            (f"transformer_blocks.{i}.ff.net.0.proj", f"lora_unet_joint_blocks_{i}_x_block_mlp_fc1")
        ]:
            scale_down, scale_up = calculate_scales(orig_key)
            new_state_dict[f"{diffusers_key}.lora_A.weight"] = (
                state_dict.pop(f"{orig_key}.lora_down.weight") * scale_down
            )
            new_state_dict[f"{diffusers_key}.lora_B.weight"] = state_dict.pop(f"{orig_key}.lora_up.weight") * scale_up

        for diffusers_key, orig_key in [
            (f"transformer_blocks.{i}.ff.net.2", f"lora_unet_joint_blocks_{i}_x_block_mlp_fc2")
        ]:
            scale_down, scale_up = calculate_scales(orig_key)
            new_state_dict[f"{diffusers_key}.lora_A.weight"] = (
                state_dict.pop(f"{orig_key}.lora_down.weight") * scale_down
            )
            new_state_dict[f"{diffusers_key}.lora_B.weight"] = state_dict.pop(f"{orig_key}.lora_up.weight") * scale_up

        if not (i == num_layers - 1):
            for diffusers_key, orig_key in [
                (f"transformer_blocks.{i}.ff_context.net.0.proj", f"lora_unet_joint_blocks_{i}_context_block_mlp_fc1")
            ]:
                scale_down, scale_up = calculate_scales(orig_key)
                new_state_dict[f"{diffusers_key}.lora_A.weight"] = (
                    state_dict.pop(f"{orig_key}.lora_down.weight") * scale_down
                )
                new_state_dict[f"{diffusers_key}.lora_B.weight"] = (
                    state_dict.pop(f"{orig_key}.lora_up.weight") * scale_up
                )

            for diffusers_key, orig_key in [
                (f"transformer_blocks.{i}.ff_context.net.2", f"lora_unet_joint_blocks_{i}_context_block_mlp_fc2")
            ]:
                scale_down, scale_up = calculate_scales(orig_key)
                new_state_dict[f"{diffusers_key}.lora_A.weight"] = (
                    state_dict.pop(f"{orig_key}.lora_down.weight") * scale_down
                )
                new_state_dict[f"{diffusers_key}.lora_B.weight"] = (
                    state_dict.pop(f"{orig_key}.lora_up.weight") * scale_up
                )

        # core transformer blocks.
        # sample (x_block) attention qkv.
        scale_down, scale_up = calculate_scales(f"lora_unet_joint_blocks_{i}_x_block_attn_qkv")
        is_sparse, requested_rank = weight_is_sparse(
            key=f"lora_unet_joint_blocks_{i}_x_block_attn_qkv",
            rank=state_dict[f"lora_unet_joint_blocks_{i}_x_block_attn_qkv.lora_down.weight"].shape[0],
            num_splits=3,
            up_weight=state_dict[f"lora_unet_joint_blocks_{i}_x_block_attn_qkv.lora_up.weight"],
        )
        num_splits = 3
        sample_qkv_lora_down = (
            state_dict.pop(f"lora_unet_joint_blocks_{i}_x_block_attn_qkv.lora_down.weight") * scale_down
        )
        sample_qkv_lora_up = state_dict.pop(f"lora_unet_joint_blocks_{i}_x_block_attn_qkv.lora_up.weight") * scale_up
        dims = [sample_qkv_lora_up.shape[0] // num_splits] * num_splits  # 3 = num_splits
        if not is_sparse:
            for attn_k in ["to_q", "to_k", "to_v"]:
                new_state_dict[f"transformer_blocks.{i}.attn.{attn_k}.lora_A.weight"] = sample_qkv_lora_down
            for attn_k, v in zip(["to_q", "to_k", "to_v"], torch.split(sample_qkv_lora_up, dims, dim=0)):
                new_state_dict[f"transformer_blocks.{i}.attn.{attn_k}.lora_B.weight"] = v
        else:
            # down_weight is chunked to each split
            new_state_dict.update(
                {
                    f"transformer_blocks.{i}.attn.{k}.lora_A.weight": v
                    for k, v in zip(["to_q", "to_k", "to_v"], torch.chunk(sample_qkv_lora_down, num_splits, dim=0))
                }
            )  # noqa: C416

            # up_weight is sparse: only non-zero values are copied to each split
            # (track the row offset separately so the block index `i` is not clobbered)
            row = 0
            for j, attn_k in enumerate(["to_q", "to_k", "to_v"]):
                new_state_dict[f"transformer_blocks.{i}.attn.{attn_k}.lora_B.weight"] = sample_qkv_lora_up[
                    row : row + dims[j], j * requested_rank : (j + 1) * requested_rank
                ].contiguous()
                row += dims[j]

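        # In the non-sparse branch above, each of to_q/to_k/to_v reuses the full lora_A (down) weight and takes its
        # own row slice of lora_B (up), which reproduces the fused qkv delta exactly; in the sparse branch both
        # factors can be split because the up weight is block diagonal.
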
        # context (context_block) attention qkv.
        scale_down, scale_up = calculate_scales(f"lora_unet_joint_blocks_{i}_context_block_attn_qkv")
        is_sparse, requested_rank = weight_is_sparse(
            key=f"lora_unet_joint_blocks_{i}_context_block_attn_qkv",
            rank=state_dict[f"lora_unet_joint_blocks_{i}_context_block_attn_qkv.lora_down.weight"].shape[0],
            num_splits=3,
            up_weight=state_dict[f"lora_unet_joint_blocks_{i}_context_block_attn_qkv.lora_up.weight"],
        )
        num_splits = 3
        sample_qkv_lora_down = (
            state_dict.pop(f"lora_unet_joint_blocks_{i}_context_block_attn_qkv.lora_down.weight") * scale_down
        )
        sample_qkv_lora_up = (
            state_dict.pop(f"lora_unet_joint_blocks_{i}_context_block_attn_qkv.lora_up.weight") * scale_up
        )
        dims = [sample_qkv_lora_up.shape[0] // num_splits] * num_splits  # 3 = num_splits
        if not is_sparse:
            for attn_k in ["add_q_proj", "add_k_proj", "add_v_proj"]:
                new_state_dict[f"transformer_blocks.{i}.attn.{attn_k}.lora_A.weight"] = sample_qkv_lora_down
            for attn_k, v in zip(
                ["add_q_proj", "add_k_proj", "add_v_proj"], torch.split(sample_qkv_lora_up, dims, dim=0)
            ):
                new_state_dict[f"transformer_blocks.{i}.attn.{attn_k}.lora_B.weight"] = v
        else:
            # down_weight is chunked to each split
            new_state_dict.update(
                {
                    f"transformer_blocks.{i}.attn.{k}.lora_A.weight": v
                    for k, v in zip(
                        ["add_q_proj", "add_k_proj", "add_v_proj"],
                        torch.chunk(sample_qkv_lora_down, num_splits, dim=0),
                    )
                }
            )  # noqa: C416

            # up_weight is sparse: only non-zero values are copied to each split
            # (again, keep the row offset separate from the block index `i`)
            row = 0
            for j, attn_k in enumerate(["add_q_proj", "add_k_proj", "add_v_proj"]):
                new_state_dict[f"transformer_blocks.{i}.attn.{attn_k}.lora_B.weight"] = sample_qkv_lora_up[
                    row : row + dims[j], j * requested_rank : (j + 1) * requested_rank
                ].contiguous()
                row += dims[j]

    if len(state_dict) > 0:
        raise ValueError(f"`state_dict` should be empty at this point but has: {list(state_dict.keys())}.")

    prefix = prefix or "transformer"
    new_state_dict = {f"{prefix}.{k}": v for k, v in new_state_dict.items()}
    return new_state_dict
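

# Minimal usage sketch (illustrative only; the file path is hypothetical and the checkpoint is assumed to use the
# kohya-style "lora_unet_joint_blocks_*" naming handled above):
#
#   import safetensors.torch
#
#   raw_sd = safetensors.torch.load_file("sd3_lora.safetensors")
#   converted = _convert_non_diffusers_sd3_lora_to_diffusers(raw_sd, prefix="transformer")
#   # keys now follow the lora_A/lora_B convention, e.g. "transformer.transformer_blocks.0.attn.to_q.lora_A.weight"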