@@ -666,7 +666,7 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
666666
667667
668668def _convert_bfl_flux_control_lora_to_diffusers (old_state_dict ):
669- # in SD3 original implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale;
669+ # in the original Flux implementation of AdaLayerNormContinuous, the linear projection output is split into shift, scale;
670670 # while in diffusers it is split into scale, shift. Here we swap the linear projection weights so that the diffusers implementation can be used
671671 def swap_scale_shift (weight ):
672672 shift , scale = weight .chunk (2 , dim = 0 )
@@ -755,15 +755,13 @@ def remap_double_blocks(key, converted_state_dict, original_state_dict):
755755 new_key = new_key .replace ("img_mod.lin" , "norm1.linear" )
756756 new_key = new_key .replace ("txt_mod.lin" , "norm1_context.linear" )
757757 new_key = new_key .replace ("img_mlp.0" , "ff.net.0.proj" )
758- new_key = new_key .replace ("img_mlp.2" , "ff.net.2.proj " )
758+ new_key = new_key .replace ("img_mlp.2" , "ff.net.2" )
759759 new_key = new_key .replace ("txt_mlp.0" , "ff_context.net.0.proj" )
760- new_key = new_key .replace ("txt_mlp.2" , "ff_context.net.2.proj " )
760+ new_key = new_key .replace ("txt_mlp.2" , "ff_context.net.2" )
761761 new_key = new_key .replace ("img_attn.proj" , "attn.to_out.0" )
762762 new_key = new_key .replace ("img_attn.norm.query_norm.scale" , "attn.norm_q.weight" )
763763 new_key = new_key .replace ("img_attn.norm.key_norm.scale" , "attn.norm_k.weight" )
764- new_key = new_key .replace ("txt_attn.proj" , "attn.to_add_out.0" )
765- # new_key = new_key.replace("txt_attn.norm.query_norm.scale", "attn.norm_added_q.weight")
766- # new_key = new_key.replace("txt_attn.norm.key_norm.scale", "attn.norm_added_k.weight")
764+ new_key = new_key .replace ("txt_attn.proj" , "attn.to_add_out" )
767765 converted_state_dict [new_key ] = original_state_dict .pop (key )
768766
769767 def remap_single_blocks (key , converted_state_dict , original_state_dict ):
@@ -817,10 +815,9 @@ def remap_single_blocks(key, converted_state_dict, original_state_dict):
817815 )
818816
819817 else :
820- new_key = key .replace ("modulation.lin" , "norm.linear" )
818+ new_key = key .replace ("single_blocks" , "single_transformer_blocks" )
819+ new_key = new_key .replace ("modulation.lin" , "norm.linear" )
821820 new_key = new_key .replace ("linear2" , "proj_out" )
822- # new_key = new_key.replace("norm.query_norm.scale", "attn.norm_q.weight")
823- # new_key = new_key.replace("norm.key_norm.scale", "attn.norm_k.weight")
824821 converted_state_dict [new_key ] = original_state_dict .pop (key )
825822
826823 def remap_final_layer (key , converted_state_dict , original_state_dict ):
0 commit comments