Commit 091b185

support sd3.5 non-diffusers loras.
1 parent 6131a93 commit 091b185

File tree

2 files changed (+276, −4 lines)


src/diffusers/loaders/lora_conversion_utils.py

Lines changed: 245 additions & 0 deletions
@@ -663,3 +663,248 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
         raise ValueError(f"`old_state_dict` should be empty at this point but has: {list(old_state_dict.keys())}.")
 
     return new_state_dict
+
+
+def _convert_non_diffusers_sd3_lora_to_diffusers(state_dict, prefix=None):
+    new_state_dict = {}
+
+    # The original SD3 implementation of AdaLayerNormContinuous splits the linear projection output into
+    # (shift, scale), while diffusers splits it into (scale, shift). Swap the linear projection weights here
+    # so that the diffusers implementation can be used as-is.
+    def swap_scale_shift(weight):
+        shift, scale = weight.chunk(2, dim=0)
+        new_weight = torch.cat([scale, shift], dim=0)
+        return new_weight
+
+    def calculate_scales(key):
+        lora_rank = state_dict[f"{key}.lora_down.weight"].shape[0]
+        alpha = state_dict.pop(key + ".alpha")
+        scale = alpha / lora_rank
+
+        # calculate scale_down and scale_up
+        scale_down = scale
+        scale_up = 1.0
+        while scale_down * 2 < scale_up:
+            scale_down *= 2
+            scale_up /= 2
+
+        return scale_down, scale_up
+
+    def weight_is_sparse(key, rank, num_splits, up_weight):
+        dims = [up_weight.shape[0] // num_splits] * num_splits
+
+        is_sparse = False
+        requested_rank = rank
+        if rank % num_splits == 0:
+            requested_rank = rank // num_splits
+            is_sparse = True
+            i = 0
+            for j in range(len(dims)):
+                for k in range(len(dims)):
+                    if j == k:
+                        continue
+                    is_sparse = is_sparse and torch.all(
+                        up_weight[i : i + dims[j], k * requested_rank : (k + 1) * requested_rank] == 0
+                    )
+                i += dims[j]
+            if is_sparse:
+                logger.info(f"weight is sparse: {key}")
+
+        return is_sparse, requested_rank
+
+    # handle only transformer blocks for now.
+    layers = set()
+    for k in state_dict:
+        if "joint_blocks" in k:
+            idx = int(k.split("_", 4)[-1].split("_", 1)[0])
+            layers.add(idx)
+    num_layers = max(layers) + 1
+
+    for i in range(num_layers):
+        # norms
+        for diffusers_key, orig_key in [
+            (f"transformer_blocks.{i}.norm1.linear", f"lora_unet_joint_blocks_{i}_x_block_adaLN_modulation_1")
+        ]:
+            scale_down, scale_up = calculate_scales(orig_key)
+            new_state_dict[f"{diffusers_key}.lora_A.weight"] = (
+                state_dict.pop(f"{orig_key}.lora_down.weight") * scale_down
+            )
+            new_state_dict[f"{diffusers_key}.lora_B.weight"] = state_dict.pop(f"{orig_key}.lora_up.weight") * scale_up
+
+        if not (i == num_layers - 1):
+            for diffusers_key, orig_key in [
+                (
+                    f"transformer_blocks.{i}.norm1_context.linear",
+                    f"lora_unet_joint_blocks_{i}_context_block_adaLN_modulation_1",
+                )
+            ]:
+                scale_down, scale_up = calculate_scales(orig_key)
+                new_state_dict[f"{diffusers_key}.lora_A.weight"] = (
+                    state_dict.pop(f"{orig_key}.lora_down.weight") * scale_down
+                )
+                new_state_dict[f"{diffusers_key}.lora_B.weight"] = (
+                    state_dict.pop(f"{orig_key}.lora_up.weight") * scale_up
+                )
+        else:
+            for diffusers_key, orig_key in [
+                (
+                    f"transformer_blocks.{i}.norm1_context.linear",
+                    f"lora_unet_joint_blocks_{i}_context_block_adaLN_modulation_1",
+                )
+            ]:
+                scale_down, scale_up = calculate_scales(orig_key)
+                new_state_dict[f"{diffusers_key}.lora_A.weight"] = (
+                    swap_scale_shift(state_dict.pop(f"{orig_key}.lora_down.weight")) * scale_down
+                )
+                new_state_dict[f"{diffusers_key}.lora_B.weight"] = (
+                    swap_scale_shift(state_dict.pop(f"{orig_key}.lora_up.weight")) * scale_up
+                )
+
+        # output projections
+        for diffusers_key, orig_key in [
+            (f"transformer_blocks.{i}.attn.to_out.0", f"lora_unet_joint_blocks_{i}_x_block_attn_proj")
+        ]:
+            scale_down, scale_up = calculate_scales(orig_key)
+            new_state_dict[f"{diffusers_key}.lora_A.weight"] = (
+                state_dict.pop(f"{orig_key}.lora_down.weight") * scale_down
+            )
+            new_state_dict[f"{diffusers_key}.lora_B.weight"] = state_dict.pop(f"{orig_key}.lora_up.weight") * scale_up
+        if not (i == num_layers - 1):
+            for diffusers_key, orig_key in [
+                (f"transformer_blocks.{i}.attn.to_add_out", f"lora_unet_joint_blocks_{i}_context_block_attn_proj")
+            ]:
+                scale_down, scale_up = calculate_scales(orig_key)
+                new_state_dict[f"{diffusers_key}.lora_A.weight"] = (
+                    state_dict.pop(f"{orig_key}.lora_down.weight") * scale_down
+                )
+                new_state_dict[f"{diffusers_key}.lora_B.weight"] = (
+                    state_dict.pop(f"{orig_key}.lora_up.weight") * scale_up
+                )
+
+        # ffs
+        for diffusers_key, orig_key in [
+            (f"transformer_blocks.{i}.ff.net.0.proj", f"lora_unet_joint_blocks_{i}_x_block_mlp_fc1")
+        ]:
+            scale_down, scale_up = calculate_scales(orig_key)
+            new_state_dict[f"{diffusers_key}.lora_A.weight"] = (
+                state_dict.pop(f"{orig_key}.lora_down.weight") * scale_down
+            )
+            new_state_dict[f"{diffusers_key}.lora_B.weight"] = state_dict.pop(f"{orig_key}.lora_up.weight") * scale_up
+
+        for diffusers_key, orig_key in [
+            (f"transformer_blocks.{i}.ff.net.2", f"lora_unet_joint_blocks_{i}_x_block_mlp_fc2")
+        ]:
+            scale_down, scale_up = calculate_scales(orig_key)
+            new_state_dict[f"{diffusers_key}.lora_A.weight"] = (
+                state_dict.pop(f"{orig_key}.lora_down.weight") * scale_down
+            )
+            new_state_dict[f"{diffusers_key}.lora_B.weight"] = state_dict.pop(f"{orig_key}.lora_up.weight") * scale_up
+
+        if not (i == num_layers - 1):
+            for diffusers_key, orig_key in [
+                (f"transformer_blocks.{i}.ff_context.net.0.proj", f"lora_unet_joint_blocks_{i}_context_block_mlp_fc1")
+            ]:
+                scale_down, scale_up = calculate_scales(orig_key)
+                new_state_dict[f"{diffusers_key}.lora_A.weight"] = (
+                    state_dict.pop(f"{orig_key}.lora_down.weight") * scale_down
+                )
+                new_state_dict[f"{diffusers_key}.lora_B.weight"] = (
+                    state_dict.pop(f"{orig_key}.lora_up.weight") * scale_up
+                )
+
+            for diffusers_key, orig_key in [
+                (f"transformer_blocks.{i}.ff_context.net.2", f"lora_unet_joint_blocks_{i}_context_block_mlp_fc2")
+            ]:
+                scale_down, scale_up = calculate_scales(orig_key)
+                new_state_dict[f"{diffusers_key}.lora_A.weight"] = (
+                    state_dict.pop(f"{orig_key}.lora_down.weight") * scale_down
+                )
+                new_state_dict[f"{diffusers_key}.lora_B.weight"] = (
+                    state_dict.pop(f"{orig_key}.lora_up.weight") * scale_up
+                )
+
+        # core transformer blocks.
+        # sample blocks.
+        scale_down, scale_up = calculate_scales(f"lora_unet_joint_blocks_{i}_x_block_attn_qkv")
+        is_sparse, requested_rank = weight_is_sparse(
+            key=f"lora_unet_joint_blocks_{i}_x_block_attn_qkv",
+            rank=state_dict[f"lora_unet_joint_blocks_{i}_x_block_attn_qkv.lora_down.weight"].shape[0],
+            num_splits=3,
+            up_weight=state_dict[f"lora_unet_joint_blocks_{i}_x_block_attn_qkv.lora_up.weight"],
+        )
+        num_splits = 3
+        sample_qkv_lora_down = (
+            state_dict.pop(f"lora_unet_joint_blocks_{i}_x_block_attn_qkv.lora_down.weight") * scale_down
+        )
+        sample_qkv_lora_up = state_dict.pop(f"lora_unet_joint_blocks_{i}_x_block_attn_qkv.lora_up.weight") * scale_up
+        dims = [sample_qkv_lora_up.shape[0] // num_splits] * num_splits  # 3 = num_splits
+        if not is_sparse:
+            for attn_k in ["to_q", "to_k", "to_v"]:
+                new_state_dict[f"transformer_blocks.{i}.attn.{attn_k}.lora_A.weight"] = sample_qkv_lora_down
+            for attn_k, v in zip(["to_q", "to_k", "to_v"], torch.split(sample_qkv_lora_up, dims, dim=0)):
+                new_state_dict[f"transformer_blocks.{i}.attn.{attn_k}.lora_B.weight"] = v
+        else:
+            # down_weight is chunked to each split
+            new_state_dict.update(
+                {
+                    f"transformer_blocks.{i}.attn.{k}.lora_A.weight": v
+                    for k, v in zip(["to_q", "to_k", "to_v"], torch.chunk(sample_qkv_lora_down, num_splits, dim=0))
+                }
+            )  # noqa: C416
+
+            # up_weight is sparse: only non-zero values are copied to each split
+            offset = 0
+            for j, attn_k in enumerate(["to_q", "to_k", "to_v"]):
+                new_state_dict[f"transformer_blocks.{i}.attn.{attn_k}.lora_B.weight"] = sample_qkv_lora_up[
+                    offset : offset + dims[j], j * requested_rank : (j + 1) * requested_rank
+                ].contiguous()
+                offset += dims[j]
+
+        # context blocks.
+        scale_down, scale_up = calculate_scales(f"lora_unet_joint_blocks_{i}_context_block_attn_qkv")
+        is_sparse, requested_rank = weight_is_sparse(
+            key=f"lora_unet_joint_blocks_{i}_context_block_attn_qkv",
+            rank=state_dict[f"lora_unet_joint_blocks_{i}_context_block_attn_qkv.lora_down.weight"].shape[0],
+            num_splits=3,
+            up_weight=state_dict[f"lora_unet_joint_blocks_{i}_context_block_attn_qkv.lora_up.weight"],
+        )
+        num_splits = 3
+        sample_qkv_lora_down = (
+            state_dict.pop(f"lora_unet_joint_blocks_{i}_context_block_attn_qkv.lora_down.weight") * scale_down
+        )
+        sample_qkv_lora_up = (
+            state_dict.pop(f"lora_unet_joint_blocks_{i}_context_block_attn_qkv.lora_up.weight") * scale_up
+        )
+        dims = [sample_qkv_lora_up.shape[0] // num_splits] * num_splits  # 3 = num_splits
+        if not is_sparse:
+            for attn_k in ["add_q_proj", "add_k_proj", "add_v_proj"]:
+                new_state_dict[f"transformer_blocks.{i}.attn.{attn_k}.lora_A.weight"] = sample_qkv_lora_down
+            for attn_k, v in zip(
+                ["add_q_proj", "add_k_proj", "add_v_proj"], torch.split(sample_qkv_lora_up, dims, dim=0)
+            ):
+                new_state_dict[f"transformer_blocks.{i}.attn.{attn_k}.lora_B.weight"] = v
+        else:
+            # down_weight is chunked to each split
+            new_state_dict.update(
+                {
+                    f"transformer_blocks.{i}.attn.{k}.lora_A.weight": v
+                    for k, v in zip(
+                        ["add_q_proj", "add_k_proj", "add_v_proj"],
+                        torch.chunk(sample_qkv_lora_down, num_splits, dim=0),
+                    )
+                }
+            )  # noqa: C416
+
+            # up_weight is sparse: only non-zero values are copied to each split
+            offset = 0
+            for j, attn_k in enumerate(["add_q_proj", "add_k_proj", "add_v_proj"]):
+                new_state_dict[f"transformer_blocks.{i}.attn.{attn_k}.lora_B.weight"] = sample_qkv_lora_up[
+                    offset : offset + dims[j], j * requested_rank : (j + 1) * requested_rank
+                ].contiguous()
+                offset += dims[j]
+
+    if len(state_dict) > 0:
+        raise ValueError(f"`state_dict` should be empty at this point but has: {list(state_dict.keys())}.")
+
+    prefix = prefix or "transformer"
+    new_state_dict = {f"{prefix}.{k}": v for k, v in new_state_dict.items()}
+    return new_state_dict
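
For reference, the alpha handling in `calculate_scales` can be sketched in isolation: the effective LoRA scale alpha / rank is factored into `scale_down` and `scale_up` and folded directly into the two matrices, so no separate multiplier has to be carried at load time. The snippet below is a standalone illustration, not part of the commit; the rank, feature sizes, and random tensors are made up for the example.

import torch

# Fold the LoRA scale (alpha / rank) into the two factors, as the converter above does.
rank, in_features, out_features, alpha = 8, 16, 32, 1.0
lora_down = torch.randn(rank, in_features)
lora_up = torch.randn(out_features, rank)

scale = alpha / rank
scale_down, scale_up = scale, 1.0
while scale_down * 2 < scale_up:
    scale_down *= 2
    scale_up /= 2

# The folded factors produce the same weight delta as applying `scale` afterwards.
delta_folded = (lora_up * scale_up) @ (lora_down * scale_down)
delta_reference = scale * (lora_up @ lora_down)
assert torch.allclose(delta_folded, delta_reference)

# A fused attention lora_up (rows = q, k, v stacked) is split into equal thirds for
# to_q / to_k / to_v, mirroring the non-sparse branch of the converter above.
fused_up = torch.randn(3 * out_features, rank)
q_up, k_up, v_up = torch.split(fused_up, [out_features] * 3, dim=0)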

src/diffusers/loaders/lora_pipeline.py

Lines changed: 31 additions & 4 deletions
@@ -36,6 +36,7 @@
 from .lora_conversion_utils import (
     _convert_kohya_flux_lora_to_diffusers,
     _convert_non_diffusers_lora_to_diffusers,
+    _convert_non_diffusers_sd3_lora_to_diffusers,
     _convert_xlabs_flux_lora_to_diffusers,
     _maybe_map_sgm_blocks_to_diffusers,
 )
@@ -1211,6 +1212,27 @@ def lora_state_dict(
             logger.warning(warn_msg)
             state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
 
+        is_non_diffusers = any("lora_unet" in k for k in state_dict)
+        if is_non_diffusers:
+            has_only_transformer = all(k.startswith("lora_unet") for k in state_dict)
+            if not has_only_transformer:
+                state_dict = {k: v for k, v in state_dict.items() if k.startswith("lora_unet")}
+                logger.warning(
+                    "Some keys in the LoRA checkpoint are not related to transformer blocks and we will filter them out during loading. Please open a new issue with the LoRA checkpoint you are trying to load with a reproducible snippet - https://github.com/huggingface/diffusers/issues/new."
+                )
+
+            all_joint_blocks = all("joint_blocks" in k for k in state_dict)
+            if not all_joint_blocks:
+                raise ValueError(
+                    "LoRAs containing only transformer blocks are supported at this point. Please open a new issue with the LoRA checkpoint you are trying to load with a reproducible snippet - https://github.com/huggingface/diffusers/issues/new."
+                )
+
+            has_dual_attention_layers = any("attn2" in k for k in state_dict)
+            if has_dual_attention_layers:
+                raise ValueError("LoRA state dicts with dual attention layers are not supported.")
+
+            state_dict = _convert_non_diffusers_sd3_lora_to_diffusers(state_dict, prefix=cls.transformer_name)
+
         return state_dict
 
     def load_lora_weights(
@@ -1255,12 +1277,11 @@ def load_lora_weights(
 
         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
         state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
-
         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")
 
-        transformer_state_dict = {k: v for k, v in state_dict.items() if "transformer." in k}
+        transformer_state_dict = {k: v for k, v in state_dict.items() if k.startswith(f"{self.transformer_name}.")}
         if len(transformer_state_dict) > 0:
             self.load_lora_into_transformer(
                 state_dict,
@@ -1271,8 +1292,10 @@ def load_lora_weights(
                 _pipeline=self,
                 low_cpu_mem_usage=low_cpu_mem_usage,
             )
+        else:
+            logger.debug("No LoRA keys were found for the transformer.")
 
-        text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
+        text_encoder_state_dict = {k: v for k, v in state_dict.items() if k.startswith(f"{self.text_encoder_name}.")}
         if len(text_encoder_state_dict) > 0:
             self.load_lora_into_text_encoder(
                 text_encoder_state_dict,
@@ -1284,8 +1307,10 @@ def load_lora_weights(
                 _pipeline=self,
                 low_cpu_mem_usage=low_cpu_mem_usage,
             )
+        else:
+            logger.debug("No LoRA keys were found for the first text encoder.")
 
-        text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
+        text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if k.startswith("text_encoder_2.")}
         if len(text_encoder_2_state_dict) > 0:
             self.load_lora_into_text_encoder(
                 text_encoder_2_state_dict,
@@ -1297,6 +1322,8 @@ def load_lora_weights(
                 _pipeline=self,
                 low_cpu_mem_usage=low_cpu_mem_usage,
             )
+        else:
+            logger.debug("No LoRA keys were found for the second text encoder.")
 
     @classmethod
     def load_lora_into_transformer(
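
End to end, the new conversion path is exercised through the existing loader API. Below is a minimal usage sketch, assuming a kohya-style (`lora_unet_*` keyed) SD3.5 LoRA file; the model id, LoRA file path, prompt, and output filename are placeholders, not taken from the commit.

import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large", torch_dtype=torch.bfloat16
).to("cuda")

# `lora_state_dict` detects the `lora_unet` prefix and routes the checkpoint through
# `_convert_non_diffusers_sd3_lora_to_diffusers` before the usual PEFT-based loading.
pipe.load_lora_weights("path/to/sd3.5-non-diffusers-lora.safetensors")

image = pipe("a photo of an astronaut riding a horse", num_inference_steps=28).images[0]
image.save("sd35_lora_sample.png")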
