Commit e7df197

fix lora conversion

1 parent 6ce181b · commit e7df197

File tree: 1 file changed (+6 −9 lines changed)

src/diffusers/loaders/lora_conversion_utils.py

Lines changed: 6 additions & 9 deletions
@@ -666,7 +666,7 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
 
 
 def _convert_bfl_flux_control_lora_to_diffusers(old_state_dict):
-    # in SD3 original implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale;
+    # in Flux original implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale;
     # while in diffusers it split into scale, shift. Here we swap the linear projection weights in order to be able to use diffusers implementation
     def swap_scale_shift(weight):
         shift, scale = weight.chunk(2, dim=0)
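
For context, a minimal sketch of the swap that the corrected comment describes, assuming the remainder of swap_scale_shift simply re-concatenates the two halves in the opposite order:

    import torch

    def swap_scale_shift(weight):
        # BFL orders the AdaLayerNormContinuous projection output as [shift, scale];
        # diffusers expects [scale, shift], so swap the two halves along dim 0.
        shift, scale = weight.chunk(2, dim=0)
        return torch.cat([scale, shift], dim=0)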
@@ -755,15 +755,13 @@ def remap_double_blocks(key, converted_state_dict, original_state_dict):
         new_key = new_key.replace("img_mod.lin", "norm1.linear")
         new_key = new_key.replace("txt_mod.lin", "norm1_context.linear")
         new_key = new_key.replace("img_mlp.0", "ff.net.0.proj")
-        new_key = new_key.replace("img_mlp.2", "ff.net.2.proj")
+        new_key = new_key.replace("img_mlp.2", "ff.net.2")
         new_key = new_key.replace("txt_mlp.0", "ff_context.net.0.proj")
-        new_key = new_key.replace("txt_mlp.2", "ff_context.net.2.proj")
+        new_key = new_key.replace("txt_mlp.2", "ff_context.net.2")
         new_key = new_key.replace("img_attn.proj", "attn.to_out.0")
         new_key = new_key.replace("img_attn.norm.query_norm.scale", "attn.norm_q.weight")
         new_key = new_key.replace("img_attn.norm.key_norm.scale", "attn.norm_k.weight")
-        new_key = new_key.replace("txt_attn.proj", "attn.to_add_out.0")
-        # new_key = new_key.replace("txt_attn.norm.query_norm.scale", "attn.norm_added_q.weight")
-        # new_key = new_key.replace("txt_attn.norm.key_norm.scale", "attn.norm_added_k.weight")
+        new_key = new_key.replace("txt_attn.proj", "attn.to_add_out")
         converted_state_dict[new_key] = original_state_dict.pop(key)
 
     def remap_single_blocks(key, converted_state_dict, original_state_dict):
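
The corrected targets match the diffusers module layout: in a FeedForward block only net.0 wraps its linear in a .proj attribute, the second linear is net.2 itself, and the context attention's output projection is the plain linear to_add_out rather than an indexed ModuleList. An illustrative sketch of the resulting mapping; the sample key and its lora_A suffix are assumptions about the BFL double-block LoRA layout, and the double_blocks-to-transformer_blocks prefix rename is assumed to happen earlier in remap_double_blocks:

    # Illustrative only: sample key is an assumption, not taken from this commit.
    bfl_key = "double_blocks.0.img_mlp.2.lora_A.weight"
    new_key = bfl_key.replace("double_blocks", "transformer_blocks")
    new_key = new_key.replace("img_mlp.2", "ff.net.2")  # second Linear lives at net.2, not net.2.proj
    print(new_key)  # transformer_blocks.0.ff.net.2.lora_A.weight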
@@ -817,10 +815,9 @@ def remap_single_blocks(key, converted_state_dict, original_state_dict):
             )
 
         else:
-            new_key = key.replace("modulation.lin", "norm.linear")
+            new_key = key.replace("single_blocks", "single_transformer_blocks")
+            new_key = new_key.replace("modulation.lin", "norm.linear")
             new_key = new_key.replace("linear2", "proj_out")
-            # new_key = new_key.replace("norm.query_norm.scale", "attn.norm_q.weight")
-            # new_key = new_key.replace("norm.key_norm.scale", "attn.norm_k.weight")
             converted_state_dict[new_key] = original_state_dict.pop(key)
 
     def remap_final_layer(key, converted_state_dict, original_state_dict):
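
The else branch previously started from key without renaming the block prefix, so single-block modulation keys kept the original single_blocks. prefix, which does not exist in the diffusers Flux transformer (single blocks live under single_transformer_blocks). A sketch of the corrected path, again with an assumed sample key:

    # Illustrative only: the sample key is an assumption about the BFL single-block LoRA layout.
    bfl_key = "single_blocks.0.modulation.lin.lora_A.weight"
    new_key = bfl_key.replace("single_blocks", "single_transformer_blocks")
    new_key = new_key.replace("modulation.lin", "norm.linear")
    print(new_key)  # single_transformer_blocks.0.norm.linear.lora_A.weight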
