@@ -663,3 +663,309 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
        raise ValueError(f"`old_state_dict` should be empty at this point but has: {list(old_state_dict.keys())}.")

    return new_state_dict


def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
    converted_state_dict = {}
    original_state_dict_keys = list(original_state_dict.keys())
    num_layers = 19
    num_single_layers = 38
    inner_dim = 3072
    mlp_ratio = 4.0

    def swap_scale_shift(weight):
        shift, scale = weight.chunk(2, dim=0)
        new_weight = torch.cat([scale, shift], dim=0)
        return new_weight
    for lora_key in ["lora_A", "lora_B"]:
        ## time_text_embed.timestep_embedder <- time_in
        converted_state_dict[
            f"time_text_embed.timestep_embedder.linear_1.{lora_key}.weight"
        ] = original_state_dict.pop(f"time_in.in_layer.{lora_key}.weight")
        if f"time_in.in_layer.{lora_key}.bias" in original_state_dict_keys:
            converted_state_dict[
                f"time_text_embed.timestep_embedder.linear_1.{lora_key}.bias"
            ] = original_state_dict.pop(f"time_in.in_layer.{lora_key}.bias")

        converted_state_dict[
            f"time_text_embed.timestep_embedder.linear_2.{lora_key}.weight"
        ] = original_state_dict.pop(f"time_in.out_layer.{lora_key}.weight")
        if f"time_in.out_layer.{lora_key}.bias" in original_state_dict_keys:
            converted_state_dict[
                f"time_text_embed.timestep_embedder.linear_2.{lora_key}.bias"
            ] = original_state_dict.pop(f"time_in.out_layer.{lora_key}.bias")

        ## time_text_embed.text_embedder <- vector_in
        converted_state_dict[f"time_text_embed.text_embedder.linear_1.{lora_key}.weight"] = original_state_dict.pop(
            f"vector_in.in_layer.{lora_key}.weight"
        )
        if f"vector_in.in_layer.{lora_key}.bias" in original_state_dict_keys:
            converted_state_dict[f"time_text_embed.text_embedder.linear_1.{lora_key}.bias"] = original_state_dict.pop(
                f"vector_in.in_layer.{lora_key}.bias"
            )

        converted_state_dict[f"time_text_embed.text_embedder.linear_2.{lora_key}.weight"] = original_state_dict.pop(
            f"vector_in.out_layer.{lora_key}.weight"
        )
        if f"vector_in.out_layer.{lora_key}.bias" in original_state_dict_keys:
            converted_state_dict[f"time_text_embed.text_embedder.linear_2.{lora_key}.bias"] = original_state_dict.pop(
                f"vector_in.out_layer.{lora_key}.bias"
            )

        # guidance
        has_guidance = any("guidance" in k for k in original_state_dict)
        if has_guidance:
            converted_state_dict[
                f"time_text_embed.guidance_embedder.linear_1.{lora_key}.weight"
            ] = original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.weight")
            if f"guidance_in.in_layer.{lora_key}.bias" in original_state_dict_keys:
                converted_state_dict[
                    f"time_text_embed.guidance_embedder.linear_1.{lora_key}.bias"
                ] = original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.bias")

            converted_state_dict[
                f"time_text_embed.guidance_embedder.linear_2.{lora_key}.weight"
            ] = original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.weight")
            if f"guidance_in.out_layer.{lora_key}.bias" in original_state_dict_keys:
                converted_state_dict[
                    f"time_text_embed.guidance_embedder.linear_2.{lora_key}.bias"
                ] = original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.bias")

        # context_embedder
        converted_state_dict[f"context_embedder.{lora_key}.weight"] = original_state_dict.pop(
            f"txt_in.{lora_key}.weight"
        )
        if f"txt_in.{lora_key}.bias" in original_state_dict_keys:
            converted_state_dict[f"context_embedder.{lora_key}.bias"] = original_state_dict.pop(
                f"txt_in.{lora_key}.bias"
            )

        # x_embedder
        converted_state_dict[f"x_embedder.{lora_key}.weight"] = original_state_dict.pop(f"img_in.{lora_key}.weight")
        if f"img_in.{lora_key}.bias" in original_state_dict_keys:
            converted_state_dict[f"x_embedder.{lora_key}.bias"] = original_state_dict.pop(f"img_in.{lora_key}.bias")

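    # The 19 double-stream blocks keep separate image and text branches: each BFL
    # "double_blocks.{i}" maps onto a diffusers "transformer_blocks.{i}", with the img_*
    # parameters going to the sample stream and the txt_* parameters to the context stream.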
    # double transformer blocks
    for i in range(num_layers):
        block_prefix = f"transformer_blocks.{i}."

        for lora_key in ["lora_A", "lora_B"]:
            # norms
            converted_state_dict[f"{block_prefix}norm1.linear.{lora_key}.weight"] = original_state_dict.pop(
                f"double_blocks.{i}.img_mod.lin.{lora_key}.weight"
            )
            if f"double_blocks.{i}.img_mod.lin.{lora_key}.bias" in original_state_dict_keys:
                converted_state_dict[f"{block_prefix}norm1.linear.{lora_key}.bias"] = original_state_dict.pop(
                    f"double_blocks.{i}.img_mod.lin.{lora_key}.bias"
                )

            converted_state_dict[f"{block_prefix}norm1_context.linear.{lora_key}.weight"] = original_state_dict.pop(
                f"double_blocks.{i}.txt_mod.lin.{lora_key}.weight"
            )
            if f"double_blocks.{i}.txt_mod.lin.{lora_key}.bias" in original_state_dict_keys:
                converted_state_dict[f"{block_prefix}norm1_context.linear.{lora_key}.bias"] = original_state_dict.pop(
                    f"double_blocks.{i}.txt_mod.lin.{lora_key}.bias"
                )

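            # In the BFL checkpoint the attention Q, K and V projections are fused into a single
            # "qkv" linear. The LoRA down-projection ("lora_A") maps the input to the rank and is
            # shared by Q, K and V, so it is duplicated; the up-projection ("lora_B") spans the
            # concatenated Q/K/V output and is chunked into three equal parts along dim 0.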
            # Q, K, V
            if lora_key == "lora_A":
                sample_lora_weight = original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.{lora_key}.weight")
                converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([sample_lora_weight])
                converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([sample_lora_weight])
                converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([sample_lora_weight])

                context_lora_weight = original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.{lora_key}.weight")
                converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.weight"] = torch.cat(
                    [context_lora_weight]
                )
                converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.weight"] = torch.cat(
                    [context_lora_weight]
                )
                converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.weight"] = torch.cat(
                    [context_lora_weight]
                )
            else:
                sample_q, sample_k, sample_v = torch.chunk(
                    original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.{lora_key}.weight"), 3, dim=0
                )
                converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([sample_q])
                converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([sample_k])
                converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([sample_v])

                context_q, context_k, context_v = torch.chunk(
                    original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.{lora_key}.weight"), 3, dim=0
                )
                converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.weight"] = torch.cat([context_q])
                converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.weight"] = torch.cat([context_k])
                converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.weight"] = torch.cat([context_v])

            if f"double_blocks.{i}.img_attn.qkv.{lora_key}.bias" in original_state_dict_keys:
                sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
                    original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.{lora_key}.bias"), 3, dim=0
                )
                converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([sample_q_bias])
                converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([sample_k_bias])
                converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([sample_v_bias])

            if f"double_blocks.{i}.txt_attn.qkv.{lora_key}.bias" in original_state_dict_keys:
                context_q_bias, context_k_bias, context_v_bias = torch.chunk(
                    original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.{lora_key}.bias"), 3, dim=0
                )
                converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.bias"] = torch.cat([context_q_bias])
                converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.bias"] = torch.cat([context_k_bias])
                converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.bias"] = torch.cat([context_v_bias])

            # ff img_mlp
            converted_state_dict[f"{block_prefix}ff.net.0.proj.{lora_key}.weight"] = original_state_dict.pop(
                f"double_blocks.{i}.img_mlp.0.{lora_key}.weight"
            )
            if f"double_blocks.{i}.img_mlp.0.{lora_key}.bias" in original_state_dict_keys:
                converted_state_dict[f"{block_prefix}ff.net.0.proj.{lora_key}.bias"] = original_state_dict.pop(
                    f"double_blocks.{i}.img_mlp.0.{lora_key}.bias"
                )

            converted_state_dict[f"{block_prefix}ff.net.2.{lora_key}.weight"] = original_state_dict.pop(
                f"double_blocks.{i}.img_mlp.2.{lora_key}.weight"
            )
            if f"double_blocks.{i}.img_mlp.2.{lora_key}.bias" in original_state_dict_keys:
                converted_state_dict[f"{block_prefix}ff.net.2.{lora_key}.bias"] = original_state_dict.pop(
                    f"double_blocks.{i}.img_mlp.2.{lora_key}.bias"
                )

            converted_state_dict[f"{block_prefix}ff_context.net.0.proj.{lora_key}.weight"] = original_state_dict.pop(
                f"double_blocks.{i}.txt_mlp.0.{lora_key}.weight"
            )
            if f"double_blocks.{i}.txt_mlp.0.{lora_key}.bias" in original_state_dict_keys:
                converted_state_dict[f"{block_prefix}ff_context.net.0.proj.{lora_key}.bias"] = original_state_dict.pop(
                    f"double_blocks.{i}.txt_mlp.0.{lora_key}.bias"
                )

            converted_state_dict[f"{block_prefix}ff_context.net.2.{lora_key}.weight"] = original_state_dict.pop(
                f"double_blocks.{i}.txt_mlp.2.{lora_key}.weight"
            )
            if f"double_blocks.{i}.txt_mlp.2.{lora_key}.bias" in original_state_dict_keys:
                converted_state_dict[f"{block_prefix}ff_context.net.2.{lora_key}.bias"] = original_state_dict.pop(
                    f"double_blocks.{i}.txt_mlp.2.{lora_key}.bias"
                )

            # output projections.
            converted_state_dict[f"{block_prefix}attn.to_out.0.{lora_key}.weight"] = original_state_dict.pop(
                f"double_blocks.{i}.img_attn.proj.{lora_key}.weight"
            )
            if f"double_blocks.{i}.img_attn.proj.{lora_key}.bias" in original_state_dict_keys:
                converted_state_dict[f"{block_prefix}attn.to_out.0.{lora_key}.bias"] = original_state_dict.pop(
                    f"double_blocks.{i}.img_attn.proj.{lora_key}.bias"
                )
            converted_state_dict[f"{block_prefix}attn.to_add_out.{lora_key}.weight"] = original_state_dict.pop(
                f"double_blocks.{i}.txt_attn.proj.{lora_key}.weight"
            )
            if f"double_blocks.{i}.txt_attn.proj.{lora_key}.bias" in original_state_dict_keys:
                converted_state_dict[f"{block_prefix}attn.to_add_out.{lora_key}.bias"] = original_state_dict.pop(
                    f"double_blocks.{i}.txt_attn.proj.{lora_key}.bias"
                )

        # qk_norm
        converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.img_attn.norm.query_norm.scale"
        )
        converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.img_attn.norm.key_norm.scale"
        )
        converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.txt_attn.norm.query_norm.scale"
        )
        converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.txt_attn.norm.key_norm.scale"
        )

    # single transformer blocks
    for i in range(num_single_layers):
        block_prefix = f"single_transformer_blocks.{i}."

        for lora_key in ["lora_A", "lora_B"]:
            # norm.linear  <- single_blocks.0.modulation.lin
            converted_state_dict[f"{block_prefix}norm.linear.{lora_key}.weight"] = original_state_dict.pop(
                f"single_blocks.{i}.modulation.lin.{lora_key}.weight"
            )
            if f"single_blocks.{i}.modulation.lin.{lora_key}.bias" in original_state_dict_keys:
                converted_state_dict[f"{block_prefix}norm.linear.{lora_key}.bias"] = original_state_dict.pop(
                    f"single_blocks.{i}.modulation.lin.{lora_key}.bias"
                )

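            # "linear1" in a single block fuses the attention Q/K/V projections and the MLP input
            # projection into one matrix, so the LoRA up-projection is split with the
            # (inner_dim, inner_dim, inner_dim, mlp_hidden_dim) sizes below, while the shared
            # down-projection is duplicated across all four targets.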
            # Q, K, V, mlp
            mlp_hidden_dim = int(inner_dim * mlp_ratio)
            split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)

            if lora_key == "lora_A":
                lora_weight = original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.weight")
                converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([lora_weight])
                converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([lora_weight])
                converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([lora_weight])
                converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.weight"] = torch.cat([lora_weight])

                if f"single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict_keys:
                    lora_bias = original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.bias")
                    converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([lora_bias])
                    converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([lora_bias])
                    converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([lora_bias])
                    converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.bias"] = torch.cat([lora_bias])
            else:
                q, k, v, mlp = torch.split(
                    original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.weight"), split_size, dim=0
                )
                converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([q])
                converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([k])
                converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([v])
                converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.weight"] = torch.cat([mlp])

                if f"single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict_keys:
                    q_bias, k_bias, v_bias, mlp_bias = torch.split(
                        original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.bias"), split_size, dim=0
                    )
                    converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([q_bias])
                    converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([k_bias])
                    converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([v_bias])
                    converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.bias"] = torch.cat([mlp_bias])

            # output projections.
            converted_state_dict[f"{block_prefix}proj_out.{lora_key}.weight"] = original_state_dict.pop(
                f"single_blocks.{i}.linear2.{lora_key}.weight"
            )
            if f"single_blocks.{i}.linear2.{lora_key}.bias" in original_state_dict_keys:
                converted_state_dict[f"{block_prefix}proj_out.{lora_key}.bias"] = original_state_dict.pop(
                    f"single_blocks.{i}.linear2.{lora_key}.bias"
                )

        # qk norm
        converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop(
            f"single_blocks.{i}.norm.query_norm.scale"
        )
        converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop(
            f"single_blocks.{i}.norm.key_norm.scale"
        )

    for lora_key in ["lora_A", "lora_B"]:
        converted_state_dict[f"proj_out.{lora_key}.weight"] = original_state_dict.pop(
            f"final_layer.linear.{lora_key}.weight"
        )
        if f"final_layer.linear.{lora_key}.bias" in original_state_dict_keys:
            converted_state_dict[f"proj_out.{lora_key}.bias"] = original_state_dict.pop(
                f"final_layer.linear.{lora_key}.bias"
            )

        converted_state_dict[f"norm_out.linear.{lora_key}.weight"] = swap_scale_shift(
            original_state_dict.pop(f"final_layer.adaLN_modulation.1.{lora_key}.weight")
        )
        if f"final_layer.adaLN_modulation.1.{lora_key}.bias" in original_state_dict_keys:
            converted_state_dict[f"norm_out.linear.{lora_key}.bias"] = swap_scale_shift(
                original_state_dict.pop(f"final_layer.adaLN_modulation.1.{lora_key}.bias")
            )

    if len(original_state_dict) > 0:
        raise ValueError(f"`original_state_dict` should be empty at this point but has {original_state_dict.keys()=}.")

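    # Prefix every key with "transformer." so the LoRA loader assigns these weights to the
    # pipeline's transformer component.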
    for key in list(converted_state_dict.keys()):
        converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)

    return converted_state_dict
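

# Illustrative usage (a sketch, not part of the library's public API; the file name below is
# hypothetical): the converter takes a BFL-format Flux Control LoRA state dict and returns one
# keyed for the diffusers Flux transformer.
#
#     from safetensors.torch import load_file
#
#     original_sd = load_file("flux1-canny-dev-lora.safetensors")  # hypothetical local path
#     diffusers_sd = _convert_bfl_flux_control_lora_to_diffusers(original_sd)
#     # keys now look like "transformer.transformer_blocks.0.attn.to_q.lora_A.weight"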