Skip to content

Commit 2a3fbc2

Browse files
sayakpaulapolinarioLeommm-byte
authored
[LoRA] support kohya and xlabs loras for flux. (#9295)
* support kohya lora in flux. * format * support xlabs * diffusion_model prefix. * Apply suggestions from code review Co-authored-by: apolinário <[email protected]> * empty commit. Co-authored-by: Leommm-byte <[email protected]> --------- Co-authored-by: apolinário <[email protected]> Co-authored-by: Leommm-byte <[email protected]>
1 parent 089cf79 commit 2a3fbc2

File tree

2 files changed

+313
-1
lines changed

2 files changed

+313
-1
lines changed

src/diffusers/loaders/lora_conversion_utils.py

Lines changed: 293 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
import re
1616

17+
import torch
18+
1719
from ..utils import is_peft_version, logging
1820

1921

@@ -326,3 +328,294 @@ def _get_alpha_name(lora_name_alpha, diffusers_name, alpha):
326328
prefix = "text_encoder_2."
327329
new_name = prefix + diffusers_name.split(".lora.")[0] + ".alpha"
328330
return {new_name: alpha}
331+
332+
333+
# The utilities under `_convert_kohya_flux_lora_to_diffusers()`
334+
# are taken from https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py
335+
# All credits go to `kohya-ss`.
336+
def _convert_kohya_flux_lora_to_diffusers(state_dict):
337+
def _convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key):
338+
if sds_key + ".lora_down.weight" not in sds_sd:
339+
return
340+
down_weight = sds_sd.pop(sds_key + ".lora_down.weight")
341+
342+
# scale weight by alpha and dim
343+
rank = down_weight.shape[0]
344+
alpha = sds_sd.pop(sds_key + ".alpha").item() # alpha is scalar
345+
scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
346+
347+
# calculate scale_down and scale_up to keep the same value. if scale is 4, scale_down is 2 and scale_up is 2
348+
scale_down = scale
349+
scale_up = 1.0
350+
while scale_down * 2 < scale_up:
351+
scale_down *= 2
352+
scale_up /= 2
353+
354+
ait_sd[ait_key + ".lora_A.weight"] = down_weight * scale_down
355+
ait_sd[ait_key + ".lora_B.weight"] = sds_sd.pop(sds_key + ".lora_up.weight") * scale_up
356+
357+
def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
358+
if sds_key + ".lora_down.weight" not in sds_sd:
359+
return
360+
down_weight = sds_sd.pop(sds_key + ".lora_down.weight")
361+
up_weight = sds_sd.pop(sds_key + ".lora_up.weight")
362+
sd_lora_rank = down_weight.shape[0]
363+
364+
# scale weight by alpha and dim
365+
alpha = sds_sd.pop(sds_key + ".alpha")
366+
scale = alpha / sd_lora_rank
367+
368+
# calculate scale_down and scale_up
369+
scale_down = scale
370+
scale_up = 1.0
371+
while scale_down * 2 < scale_up:
372+
scale_down *= 2
373+
scale_up /= 2
374+
375+
down_weight = down_weight * scale_down
376+
up_weight = up_weight * scale_up
377+
378+
# calculate dims if not provided
379+
num_splits = len(ait_keys)
380+
if dims is None:
381+
dims = [up_weight.shape[0] // num_splits] * num_splits
382+
else:
383+
assert sum(dims) == up_weight.shape[0]
384+
385+
# check upweight is sparse or not
386+
is_sparse = False
387+
if sd_lora_rank % num_splits == 0:
388+
ait_rank = sd_lora_rank // num_splits
389+
is_sparse = True
390+
i = 0
391+
for j in range(len(dims)):
392+
for k in range(len(dims)):
393+
if j == k:
394+
continue
395+
is_sparse = is_sparse and torch.all(
396+
up_weight[i : i + dims[j], k * ait_rank : (k + 1) * ait_rank] == 0
397+
)
398+
i += dims[j]
399+
if is_sparse:
400+
logger.info(f"weight is sparse: {sds_key}")
401+
402+
# make ai-toolkit weight
403+
ait_down_keys = [k + ".lora_A.weight" for k in ait_keys]
404+
ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
405+
if not is_sparse:
406+
# down_weight is copied to each split
407+
ait_sd.update({k: down_weight for k in ait_down_keys})
408+
409+
# up_weight is split to each split
410+
ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) # noqa: C416
411+
else:
412+
# down_weight is chunked to each split
413+
ait_sd.update({k: v for k, v in zip(ait_down_keys, torch.chunk(down_weight, num_splits, dim=0))}) # noqa: C416
414+
415+
# up_weight is sparse: only non-zero values are copied to each split
416+
i = 0
417+
for j in range(len(dims)):
418+
ait_sd[ait_up_keys[j]] = up_weight[i : i + dims[j], j * ait_rank : (j + 1) * ait_rank].contiguous()
419+
i += dims[j]
420+
421+
def _convert_sd_scripts_to_ai_toolkit(sds_sd):
422+
ait_sd = {}
423+
for i in range(19):
424+
_convert_to_ai_toolkit(
425+
sds_sd,
426+
ait_sd,
427+
f"lora_unet_double_blocks_{i}_img_attn_proj",
428+
f"transformer.transformer_blocks.{i}.attn.to_out.0",
429+
)
430+
_convert_to_ai_toolkit_cat(
431+
sds_sd,
432+
ait_sd,
433+
f"lora_unet_double_blocks_{i}_img_attn_qkv",
434+
[
435+
f"transformer.transformer_blocks.{i}.attn.to_q",
436+
f"transformer.transformer_blocks.{i}.attn.to_k",
437+
f"transformer.transformer_blocks.{i}.attn.to_v",
438+
],
439+
)
440+
_convert_to_ai_toolkit(
441+
sds_sd,
442+
ait_sd,
443+
f"lora_unet_double_blocks_{i}_img_mlp_0",
444+
f"transformer.transformer_blocks.{i}.ff.net.0.proj",
445+
)
446+
_convert_to_ai_toolkit(
447+
sds_sd,
448+
ait_sd,
449+
f"lora_unet_double_blocks_{i}_img_mlp_2",
450+
f"transformer.transformer_blocks.{i}.ff.net.2",
451+
)
452+
_convert_to_ai_toolkit(
453+
sds_sd,
454+
ait_sd,
455+
f"lora_unet_double_blocks_{i}_img_mod_lin",
456+
f"transformer.transformer_blocks.{i}.norm1.linear",
457+
)
458+
_convert_to_ai_toolkit(
459+
sds_sd,
460+
ait_sd,
461+
f"lora_unet_double_blocks_{i}_txt_attn_proj",
462+
f"transformer.transformer_blocks.{i}.attn.to_add_out",
463+
)
464+
_convert_to_ai_toolkit_cat(
465+
sds_sd,
466+
ait_sd,
467+
f"lora_unet_double_blocks_{i}_txt_attn_qkv",
468+
[
469+
f"transformer.transformer_blocks.{i}.attn.add_q_proj",
470+
f"transformer.transformer_blocks.{i}.attn.add_k_proj",
471+
f"transformer.transformer_blocks.{i}.attn.add_v_proj",
472+
],
473+
)
474+
_convert_to_ai_toolkit(
475+
sds_sd,
476+
ait_sd,
477+
f"lora_unet_double_blocks_{i}_txt_mlp_0",
478+
f"transformer.transformer_blocks.{i}.ff_context.net.0.proj",
479+
)
480+
_convert_to_ai_toolkit(
481+
sds_sd,
482+
ait_sd,
483+
f"lora_unet_double_blocks_{i}_txt_mlp_2",
484+
f"transformer.transformer_blocks.{i}.ff_context.net.2",
485+
)
486+
_convert_to_ai_toolkit(
487+
sds_sd,
488+
ait_sd,
489+
f"lora_unet_double_blocks_{i}_txt_mod_lin",
490+
f"transformer.transformer_blocks.{i}.norm1_context.linear",
491+
)
492+
493+
for i in range(38):
494+
_convert_to_ai_toolkit_cat(
495+
sds_sd,
496+
ait_sd,
497+
f"lora_unet_single_blocks_{i}_linear1",
498+
[
499+
f"transformer.single_transformer_blocks.{i}.attn.to_q",
500+
f"transformer.single_transformer_blocks.{i}.attn.to_k",
501+
f"transformer.single_transformer_blocks.{i}.attn.to_v",
502+
f"transformer.single_transformer_blocks.{i}.proj_mlp",
503+
],
504+
dims=[3072, 3072, 3072, 12288],
505+
)
506+
_convert_to_ai_toolkit(
507+
sds_sd,
508+
ait_sd,
509+
f"lora_unet_single_blocks_{i}_linear2",
510+
f"transformer.single_transformer_blocks.{i}.proj_out",
511+
)
512+
_convert_to_ai_toolkit(
513+
sds_sd,
514+
ait_sd,
515+
f"lora_unet_single_blocks_{i}_modulation_lin",
516+
f"transformer.single_transformer_blocks.{i}.norm.linear",
517+
)
518+
519+
if len(sds_sd) > 0:
520+
logger.warning(f"Unsuppored keys for ai-toolkit: {sds_sd.keys()}")
521+
522+
return ait_sd
523+
524+
return _convert_sd_scripts_to_ai_toolkit(state_dict)
525+
526+
527+
# Adapted from https://gist.github.com/Leommm-byte/6b331a1e9bd53271210b26543a7065d6
528+
# Some utilities were reused from
529+
# https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py
530+
def _convert_xlabs_flux_lora_to_diffusers(old_state_dict):
531+
new_state_dict = {}
532+
orig_keys = list(old_state_dict.keys())
533+
534+
def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
535+
down_weight = sds_sd.pop(sds_key)
536+
up_weight = sds_sd.pop(sds_key.replace(".down.weight", ".up.weight"))
537+
538+
# calculate dims if not provided
539+
num_splits = len(ait_keys)
540+
if dims is None:
541+
dims = [up_weight.shape[0] // num_splits] * num_splits
542+
else:
543+
assert sum(dims) == up_weight.shape[0]
544+
545+
# make ai-toolkit weight
546+
ait_down_keys = [k + ".lora_A.weight" for k in ait_keys]
547+
ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
548+
549+
# down_weight is copied to each split
550+
ait_sd.update({k: down_weight for k in ait_down_keys})
551+
552+
# up_weight is split to each split
553+
ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) # noqa: C416
554+
555+
for old_key in orig_keys:
556+
# Handle double_blocks
557+
if old_key.startswith(("diffusion_model.double_blocks", "double_blocks")):
558+
block_num = re.search(r"double_blocks\.(\d+)", old_key).group(1)
559+
new_key = f"transformer.transformer_blocks.{block_num}"
560+
561+
if "processor.proj_lora1" in old_key:
562+
new_key += ".attn.to_out.0"
563+
elif "processor.proj_lora2" in old_key:
564+
new_key += ".attn.to_add_out"
565+
elif "processor.qkv_lora1" in old_key and "up" not in old_key:
566+
handle_qkv(
567+
old_state_dict,
568+
new_state_dict,
569+
old_key,
570+
[
571+
f"transformer.transformer_blocks.{block_num}.attn.add_q_proj",
572+
f"transformer.transformer_blocks.{block_num}.attn.add_k_proj",
573+
f"transformer.transformer_blocks.{block_num}.attn.add_v_proj",
574+
],
575+
)
576+
# continue
577+
elif "processor.qkv_lora2" in old_key and "up" not in old_key:
578+
handle_qkv(
579+
old_state_dict,
580+
new_state_dict,
581+
old_key,
582+
[
583+
f"transformer.transformer_blocks.{block_num}.attn.to_q",
584+
f"transformer.transformer_blocks.{block_num}.attn.to_k",
585+
f"transformer.transformer_blocks.{block_num}.attn.to_v",
586+
],
587+
)
588+
# continue
589+
590+
if "down" in old_key:
591+
new_key += ".lora_A.weight"
592+
elif "up" in old_key:
593+
new_key += ".lora_B.weight"
594+
595+
# Handle single_blocks
596+
elif old_key.startswith("diffusion_model.single_blocks", "single_blocks"):
597+
block_num = re.search(r"single_blocks\.(\d+)", old_key).group(1)
598+
new_key = f"transformer.single_transformer_blocks.{block_num}"
599+
600+
if "proj_lora1" in old_key or "proj_lora2" in old_key:
601+
new_key += ".proj_out"
602+
elif "qkv_lora1" in old_key or "qkv_lora2" in old_key:
603+
new_key += ".norm.linear"
604+
605+
if "down" in old_key:
606+
new_key += ".lora_A.weight"
607+
elif "up" in old_key:
608+
new_key += ".lora_B.weight"
609+
610+
else:
611+
# Handle other potential key patterns here
612+
new_key = old_key
613+
614+
# Since we already handle qkv above.
615+
if "qkv" not in old_key:
616+
new_state_dict[new_key] = old_state_dict.pop(old_key)
617+
618+
if len(old_state_dict) > 0:
619+
raise ValueError(f"`old_state_dict` should be at this point but has: {list(old_state_dict.keys())}.")
620+
621+
return new_state_dict

src/diffusers/loaders/lora_pipeline.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,12 @@
3131
scale_lora_layers,
3232
)
3333
from .lora_base import LoraBaseMixin
34-
from .lora_conversion_utils import _convert_non_diffusers_lora_to_diffusers, _maybe_map_sgm_blocks_to_diffusers
34+
from .lora_conversion_utils import (
35+
_convert_kohya_flux_lora_to_diffusers,
36+
_convert_non_diffusers_lora_to_diffusers,
37+
_convert_xlabs_flux_lora_to_diffusers,
38+
_maybe_map_sgm_blocks_to_diffusers,
39+
)
3540

3641

3742
if is_transformers_available():
@@ -1583,6 +1588,20 @@ def lora_state_dict(
15831588
allow_pickle=allow_pickle,
15841589
)
15851590

1591+
# TODO (sayakpaul): to a follow-up to clean and try to unify the conditions.
1592+
1593+
is_kohya = any(".lora_down.weight" in k for k in state_dict)
1594+
if is_kohya:
1595+
state_dict = _convert_kohya_flux_lora_to_diffusers(state_dict)
1596+
# Kohya already takes care of scaling the LoRA parameters with alpha.
1597+
return (state_dict, None) if return_alphas else state_dict
1598+
1599+
is_xlabs = any("processor" in k for k in state_dict)
1600+
if is_xlabs:
1601+
state_dict = _convert_xlabs_flux_lora_to_diffusers(state_dict)
1602+
# xlabs doesn't use `alpha`.
1603+
return (state_dict, None) if return_alphas else state_dict
1604+
15861605
# For state dicts like
15871606
# https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA
15881607
keys = list(state_dict.keys())

0 commit comments

Comments
 (0)