Commit 8765e1b

Merge branch 'flux-control-lora' into flux-control-lora-training-script
2 parents: 90708fa + f688ecf

5 files changed: +833 / -11 lines

src/diffusers/loaders/lora_conversion_utils.py

Lines changed: 306 additions & 0 deletions
@@ -663,3 +663,309 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
        raise ValueError(f"`old_state_dict` should be empty at this point but has: {list(old_state_dict.keys())}.")

    return new_state_dict


def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
    converted_state_dict = {}
    original_state_dict_keys = list(original_state_dict.keys())
    num_layers = 19
    num_single_layers = 38
    inner_dim = 3072
    mlp_ratio = 4.0

    def swap_scale_shift(weight):
        shift, scale = weight.chunk(2, dim=0)
        new_weight = torch.cat([scale, shift], dim=0)
        return new_weight

    for lora_key in ["lora_A", "lora_B"]:
        ## time_text_embed.timestep_embedder <- time_in
        converted_state_dict[
            f"time_text_embed.timestep_embedder.linear_1.{lora_key}.weight"
        ] = original_state_dict.pop(f"time_in.in_layer.{lora_key}.weight")
        if f"time_in.in_layer.{lora_key}.bias" in original_state_dict_keys:
            converted_state_dict[
                f"time_text_embed.timestep_embedder.linear_1.{lora_key}.bias"
            ] = original_state_dict.pop(f"time_in.in_layer.{lora_key}.bias")

        converted_state_dict[
            f"time_text_embed.timestep_embedder.linear_2.{lora_key}.weight"
        ] = original_state_dict.pop(f"time_in.out_layer.{lora_key}.weight")
        if f"time_in.out_layer.{lora_key}.bias" in original_state_dict_keys:
            converted_state_dict[
                f"time_text_embed.timestep_embedder.linear_2.{lora_key}.bias"
            ] = original_state_dict.pop(f"time_in.out_layer.{lora_key}.bias")

        ## time_text_embed.text_embedder <- vector_in
        converted_state_dict[f"time_text_embed.text_embedder.linear_1.{lora_key}.weight"] = original_state_dict.pop(
            f"vector_in.in_layer.{lora_key}.weight"
        )
        if f"vector_in.in_layer.{lora_key}.bias" in original_state_dict_keys:
            converted_state_dict[f"time_text_embed.text_embedder.linear_1.{lora_key}.bias"] = original_state_dict.pop(
                f"vector_in.in_layer.{lora_key}.bias"
            )

        converted_state_dict[f"time_text_embed.text_embedder.linear_2.{lora_key}.weight"] = original_state_dict.pop(
            f"vector_in.out_layer.{lora_key}.weight"
        )
        if f"vector_in.out_layer.{lora_key}.bias" in original_state_dict_keys:
            converted_state_dict[f"time_text_embed.text_embedder.linear_2.{lora_key}.bias"] = original_state_dict.pop(
                f"vector_in.out_layer.{lora_key}.bias"
            )

        # guidance
        has_guidance = any("guidance" in k for k in original_state_dict)
        if has_guidance:
            converted_state_dict[
                f"time_text_embed.guidance_embedder.linear_1.{lora_key}.weight"
            ] = original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.weight")
            if f"guidance_in.in_layer.{lora_key}.bias" in original_state_dict_keys:
                converted_state_dict[
                    f"time_text_embed.guidance_embedder.linear_1.{lora_key}.bias"
                ] = original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.bias")

            converted_state_dict[
                f"time_text_embed.guidance_embedder.linear_2.{lora_key}.weight"
            ] = original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.weight")
            if f"guidance_in.out_layer.{lora_key}.bias" in original_state_dict_keys:
                converted_state_dict[
                    f"time_text_embed.guidance_embedder.linear_2.{lora_key}.bias"
                ] = original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.bias")

        # context_embedder
        converted_state_dict[f"context_embedder.{lora_key}.weight"] = original_state_dict.pop(
            f"txt_in.{lora_key}.weight"
        )
        if f"txt_in.{lora_key}.bias" in original_state_dict_keys:
            converted_state_dict[f"context_embedder.{lora_key}.bias"] = original_state_dict.pop(
                f"txt_in.{lora_key}.bias"
            )

        # x_embedder
        converted_state_dict[f"x_embedder.{lora_key}.weight"] = original_state_dict.pop(f"img_in.{lora_key}.weight")
        if f"img_in.{lora_key}.bias" in original_state_dict_keys:
            converted_state_dict[f"x_embedder.{lora_key}.bias"] = original_state_dict.pop(f"img_in.{lora_key}.bias")

    # double transformer blocks
    for i in range(num_layers):
        block_prefix = f"transformer_blocks.{i}."

        for lora_key in ["lora_A", "lora_B"]:
            # norms
            converted_state_dict[f"{block_prefix}norm1.linear.{lora_key}.weight"] = original_state_dict.pop(
                f"double_blocks.{i}.img_mod.lin.{lora_key}.weight"
            )
            if f"double_blocks.{i}.img_mod.lin.{lora_key}.bias" in original_state_dict_keys:
                converted_state_dict[f"{block_prefix}norm1.linear.{lora_key}.bias"] = original_state_dict.pop(
                    f"double_blocks.{i}.img_mod.lin.{lora_key}.bias"
                )

            converted_state_dict[f"{block_prefix}norm1_context.linear.{lora_key}.weight"] = original_state_dict.pop(
                f"double_blocks.{i}.txt_mod.lin.{lora_key}.weight"
            )
            if f"double_blocks.{i}.txt_mod.lin.{lora_key}.bias" in original_state_dict_keys:
                converted_state_dict[f"{block_prefix}norm1_context.linear.{lora_key}.bias"] = original_state_dict.pop(
                    f"double_blocks.{i}.txt_mod.lin.{lora_key}.bias"
                )

            # Q, K, V
            if lora_key == "lora_A":
                sample_lora_weight = original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.{lora_key}.weight")
                converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([sample_lora_weight])
                converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([sample_lora_weight])
                converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([sample_lora_weight])

                context_lora_weight = original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.{lora_key}.weight")
                converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.weight"] = torch.cat(
                    [context_lora_weight]
                )
                converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.weight"] = torch.cat(
                    [context_lora_weight]
                )
                converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.weight"] = torch.cat(
                    [context_lora_weight]
                )
            else:
                sample_q, sample_k, sample_v = torch.chunk(
                    original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.{lora_key}.weight"), 3, dim=0
                )
                converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([sample_q])
                converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([sample_k])
                converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([sample_v])

                context_q, context_k, context_v = torch.chunk(
                    original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.{lora_key}.weight"), 3, dim=0
                )
                converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.weight"] = torch.cat([context_q])
                converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.weight"] = torch.cat([context_k])
                converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.weight"] = torch.cat([context_v])

            if f"double_blocks.{i}.img_attn.qkv.{lora_key}.bias" in original_state_dict_keys:
                sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
                    original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.{lora_key}.bias"), 3, dim=0
                )
                converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([sample_q_bias])
                converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([sample_k_bias])
                converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([sample_v_bias])

            if f"double_blocks.{i}.txt_attn.qkv.{lora_key}.bias" in original_state_dict_keys:
                context_q_bias, context_k_bias, context_v_bias = torch.chunk(
                    original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.{lora_key}.bias"), 3, dim=0
                )
                converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.bias"] = torch.cat([context_q_bias])
                converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.bias"] = torch.cat([context_k_bias])
                converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.bias"] = torch.cat([context_v_bias])

            # ff img_mlp
            converted_state_dict[f"{block_prefix}ff.net.0.proj.{lora_key}.weight"] = original_state_dict.pop(
                f"double_blocks.{i}.img_mlp.0.{lora_key}.weight"
            )
            if f"double_blocks.{i}.img_mlp.0.{lora_key}.bias" in original_state_dict_keys:
                converted_state_dict[f"{block_prefix}ff.net.0.proj.{lora_key}.bias"] = original_state_dict.pop(
                    f"double_blocks.{i}.img_mlp.0.{lora_key}.bias"
                )

            converted_state_dict[f"{block_prefix}ff.net.2.{lora_key}.weight"] = original_state_dict.pop(
                f"double_blocks.{i}.img_mlp.2.{lora_key}.weight"
            )
            if f"double_blocks.{i}.img_mlp.2.{lora_key}.bias" in original_state_dict_keys:
                converted_state_dict[f"{block_prefix}ff.net.2.{lora_key}.bias"] = original_state_dict.pop(
                    f"double_blocks.{i}.img_mlp.2.{lora_key}.bias"
                )

            converted_state_dict[f"{block_prefix}ff_context.net.0.proj.{lora_key}.weight"] = original_state_dict.pop(
                f"double_blocks.{i}.txt_mlp.0.{lora_key}.weight"
            )
            if f"double_blocks.{i}.txt_mlp.0.{lora_key}.bias" in original_state_dict_keys:
                converted_state_dict[f"{block_prefix}ff_context.net.0.proj.{lora_key}.bias"] = original_state_dict.pop(
                    f"double_blocks.{i}.txt_mlp.0.{lora_key}.bias"
                )

            converted_state_dict[f"{block_prefix}ff_context.net.2.{lora_key}.weight"] = original_state_dict.pop(
                f"double_blocks.{i}.txt_mlp.2.{lora_key}.weight"
            )
            if f"double_blocks.{i}.txt_mlp.2.{lora_key}.bias" in original_state_dict_keys:
                converted_state_dict[f"{block_prefix}ff_context.net.2.{lora_key}.bias"] = original_state_dict.pop(
                    f"double_blocks.{i}.txt_mlp.2.{lora_key}.bias"
                )

            # output projections.
            converted_state_dict[f"{block_prefix}attn.to_out.0.{lora_key}.weight"] = original_state_dict.pop(
                f"double_blocks.{i}.img_attn.proj.{lora_key}.weight"
            )
            if f"double_blocks.{i}.img_attn.proj.{lora_key}.bias" in original_state_dict_keys:
                converted_state_dict[f"{block_prefix}attn.to_out.0.{lora_key}.bias"] = original_state_dict.pop(
                    f"double_blocks.{i}.img_attn.proj.{lora_key}.bias"
                )
            converted_state_dict[f"{block_prefix}attn.to_add_out.{lora_key}.weight"] = original_state_dict.pop(
                f"double_blocks.{i}.txt_attn.proj.{lora_key}.weight"
            )
            if f"double_blocks.{i}.txt_attn.proj.{lora_key}.bias" in original_state_dict_keys:
                converted_state_dict[f"{block_prefix}attn.to_add_out.{lora_key}.bias"] = original_state_dict.pop(
                    f"double_blocks.{i}.txt_attn.proj.{lora_key}.bias"
                )

        # qk_norm
        converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.img_attn.norm.query_norm.scale"
        )
        converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.img_attn.norm.key_norm.scale"
        )
        converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.txt_attn.norm.query_norm.scale"
        )
        converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.txt_attn.norm.key_norm.scale"
        )

    # single transformer blocks
    for i in range(num_single_layers):
        block_prefix = f"single_transformer_blocks.{i}."

        for lora_key in ["lora_A", "lora_B"]:
            # norm.linear <- single_blocks.0.modulation.lin
            converted_state_dict[f"{block_prefix}norm.linear.{lora_key}.weight"] = original_state_dict.pop(
                f"single_blocks.{i}.modulation.lin.{lora_key}.weight"
            )
            if f"single_blocks.{i}.modulation.lin.{lora_key}.bias" in original_state_dict_keys:
                converted_state_dict[f"{block_prefix}norm.linear.{lora_key}.bias"] = original_state_dict.pop(
                    f"single_blocks.{i}.modulation.lin.{lora_key}.bias"
                )

            # Q, K, V, mlp
            mlp_hidden_dim = int(inner_dim * mlp_ratio)
            split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)

            if lora_key == "lora_A":
                lora_weight = original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.weight")
                converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([lora_weight])
                converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([lora_weight])
                converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([lora_weight])
                converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.weight"] = torch.cat([lora_weight])

                if f"single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict_keys:
                    lora_bias = original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.bias")
                    converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([lora_bias])
                    converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([lora_bias])
                    converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([lora_bias])
                    converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.bias"] = torch.cat([lora_bias])
            else:
                q, k, v, mlp = torch.split(
                    original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.weight"), split_size, dim=0
                )
                converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([q])
                converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([k])
                converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([v])
                converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.weight"] = torch.cat([mlp])

                if f"single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict_keys:
                    q_bias, k_bias, v_bias, mlp_bias = torch.split(
                        original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.bias"), split_size, dim=0
                    )
                    converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([q_bias])
                    converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([k_bias])
                    converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([v_bias])
                    converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.bias"] = torch.cat([mlp_bias])

            # output projections.
            converted_state_dict[f"{block_prefix}proj_out.{lora_key}.weight"] = original_state_dict.pop(
                f"single_blocks.{i}.linear2.{lora_key}.weight"
            )
            if f"single_blocks.{i}.linear2.{lora_key}.bias" in original_state_dict_keys:
                converted_state_dict[f"{block_prefix}proj_out.{lora_key}.bias"] = original_state_dict.pop(
                    f"single_blocks.{i}.linear2.{lora_key}.bias"
                )

        # qk norm
        converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop(
            f"single_blocks.{i}.norm.query_norm.scale"
        )
        converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop(
            f"single_blocks.{i}.norm.key_norm.scale"
        )

    for lora_key in ["lora_A", "lora_B"]:
        converted_state_dict[f"proj_out.{lora_key}.weight"] = original_state_dict.pop(
            f"final_layer.linear.{lora_key}.weight"
        )
        if f"final_layer.linear.{lora_key}.bias" in original_state_dict_keys:
            converted_state_dict[f"proj_out.{lora_key}.bias"] = original_state_dict.pop(
                f"final_layer.linear.{lora_key}.bias"
            )

        converted_state_dict[f"norm_out.linear.{lora_key}.weight"] = swap_scale_shift(
            original_state_dict.pop(f"final_layer.adaLN_modulation.1.{lora_key}.weight")
        )
        if f"final_layer.adaLN_modulation.1.{lora_key}.bias" in original_state_dict_keys:
            converted_state_dict[f"norm_out.linear.{lora_key}.bias"] = swap_scale_shift(
                original_state_dict.pop(f"final_layer.adaLN_modulation.1.{lora_key}.bias")
            )

    if len(original_state_dict) > 0:
        raise ValueError(f"`original_state_dict` should be empty at this point but has {original_state_dict.keys()=}.")

    for key in list(converted_state_dict.keys()):
        converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)

    return converted_state_dict
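
The crux of the conversion is how the fused attention layers are remapped: the lora_A (down-projection) factor of a fused qkv / linear1 layer is copied to every target projection, since all targets share the same rank-space input, while the lora_B (up-projection) factor is split row-wise so each target keeps only its own output slice. A minimal, self-contained sketch of that split on toy tensors (the rank, dimensions, and variable names here are illustrative, not part of the diff):

import torch

# Toy dimensions; the real values in the conversion above are inner_dim = 3072, mlp_ratio = 4.0.
rank, inner_dim = 4, 8
mlp_hidden_dim = int(inner_dim * 4.0)

# Fused LoRA factors of a single-block `linear1`, laid out as in the BFL checkpoint.
lora_A = torch.randn(rank, inner_dim)                       # down projection, shared rank space
lora_B = torch.randn(3 * inner_dim + mlp_hidden_dim, rank)  # up projection, rows grouped as q|k|v|mlp

# The lora_A factor is reused verbatim for every target projection ...
to_q_A = to_k_A = to_v_A = proj_mlp_A = lora_A

# ... while lora_B is split row-wise so each target keeps only its own output slice.
split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
to_q_B, to_k_B, to_v_B, proj_mlp_B = torch.split(lora_B, split_size, dim=0)

# Sanity check: each per-target pair reproduces the matching slice of the fused update B @ A.
fused_update = lora_B @ lora_A
assert torch.allclose(to_q_B @ to_q_A, fused_update[:inner_dim])
assert torch.allclose(proj_mlp_B @ proj_mlp_A, fused_update[3 * inner_dim :])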

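For quick reference, a few representative key renames the new function performs, transcribed by hand from the code above as an illustrative subset (not produced by running it). Every converted key is finally prefixed with "transformer." so downstream loading code can tell which model component the LoRA targets.

# Illustrative subset of the BFL -> Diffusers key renames implemented above.
EXAMPLE_KEY_MAP = {
    "time_in.in_layer.lora_A.weight": "transformer.time_text_embed.timestep_embedder.linear_1.lora_A.weight",
    "vector_in.out_layer.lora_B.weight": "transformer.time_text_embed.text_embedder.linear_2.lora_B.weight",
    "txt_in.lora_B.weight": "transformer.context_embedder.lora_B.weight",
    # Fused attention layers fan out to several targets; only the q slice is shown here.
    "double_blocks.0.img_attn.qkv.lora_B.weight": "transformer.transformer_blocks.0.attn.to_q.lora_B.weight",
    "single_blocks.0.linear1.lora_B.weight": "transformer.single_transformer_blocks.0.attn.to_q.lora_B.weight",
    # The final AdaLN modulation additionally has its shift/scale halves swapped via swap_scale_shift.
    "final_layer.adaLN_modulation.1.lora_B.weight": "transformer.norm_out.linear.lora_B.weight",
}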