Commit 78b87dc

[LoRA] make LoRAs trained with peft loadable when peft isn't installed (#6306)
* spit out the diffusers-native format from the get-go.
* rejig the peft_to_diffusers mapping.
1 parent 0af12f1 commit 78b87dc
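
The substance of the change: the training script previously saved LoRA state dicts in PEFT's native key format (`lora_A`/`lora_B`), which required peft to be installed at load time; it now converts them to diffusers-native keys (`lora.down`/`lora.up`) before saving. A minimal sketch of what the conversion does, assuming a diffusers build that includes this commit (the key below and the tensor shapes are illustrative):

```python
import torch
from diffusers.utils import convert_state_dict_to_diffusers

# A toy state dict in PEFT's key format, as produced by
# peft.get_peft_model_state_dict() for a UNet attention layer.
peft_format = {
    "mid_block.attentions.0.transformer_blocks.0.attn1.to_q.lora_A.weight": torch.zeros(4, 320),
    "mid_block.attentions.0.transformer_blocks.0.attn1.to_q.lora_B.weight": torch.zeros(320, 4),
}

diffusers_format = convert_state_dict_to_diffusers(peft_format)
for key in diffusers_format:
    print(key)
# ...attn1.to_q.lora.down.weight
# ...attn1.to_q.lora.up.weight
```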

2 files changed: +23 −7 lines

examples/dreambooth/train_dreambooth_lora_sdxl.py

Lines changed: 15 additions & 7 deletions
@@ -54,7 +54,7 @@
 from diffusers.loaders import LoraLoaderMixin
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import compute_snr
-from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
 
 
@@ -1019,11 +1019,15 @@ def save_model_hook(models, weights, output_dir):
 
             for model in models:
                 if isinstance(model, type(accelerator.unwrap_model(unet))):
-                    unet_lora_layers_to_save = get_peft_model_state_dict(model)
+                    unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
                 elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
-                    text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
+                    text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
+                        get_peft_model_state_dict(model)
+                    )
                 elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
-                    text_encoder_two_lora_layers_to_save = get_peft_model_state_dict(model)
+                    text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
+                        get_peft_model_state_dict(model)
+                    )
                 else:
                     raise ValueError(f"unexpected save model: {model.__class__}")
 
@@ -1615,13 +1619,17 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
     if accelerator.is_main_process:
         unet = accelerator.unwrap_model(unet)
         unet = unet.to(torch.float32)
-        unet_lora_layers = get_peft_model_state_dict(unet)
+        unet_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
 
         if args.train_text_encoder:
             text_encoder_one = accelerator.unwrap_model(text_encoder_one)
-            text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one.to(torch.float32))
+            text_encoder_lora_layers = convert_state_dict_to_diffusers(
+                get_peft_model_state_dict(text_encoder_one.to(torch.float32))
+            )
             text_encoder_two = accelerator.unwrap_model(text_encoder_two)
-            text_encoder_2_lora_layers = get_peft_model_state_dict(text_encoder_two.to(torch.float32))
+            text_encoder_2_lora_layers = convert_state_dict_to_diffusers(
+                get_peft_model_state_dict(text_encoder_two.to(torch.float32))
+            )
         else:
             text_encoder_lora_layers = None
             text_encoder_2_lora_layers = None
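
Because the dicts are converted before serialization, the artifact on disk is already diffusers-native, so it loads in an environment where peft is not installed. A quick check of that property, assuming the training run wrote its LoRA to a hypothetical `lora-output` directory:

```python
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)
# Loads the diffusers-native pytorch_lora_weights.safetensors written by
# the training script; no peft import is needed on this side.
pipe.load_lora_weights("lora-output")
```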

src/diffusers/utils/state_dict_utils.py

Lines changed: 8 additions & 0 deletions
@@ -79,6 +79,14 @@ class StateDictType(enum.Enum):
     ".v_proj.lora_A": ".v_proj.lora_linear_layer.down",
     ".out_proj.lora_B": ".out_proj.lora_linear_layer.up",
     ".out_proj.lora_A": ".out_proj.lora_linear_layer.down",
+    "to_k.lora_A": "to_k.lora.down",
+    "to_k.lora_B": "to_k.lora.up",
+    "to_q.lora_A": "to_q.lora.down",
+    "to_q.lora_B": "to_q.lora.up",
+    "to_v.lora_A": "to_v.lora.down",
+    "to_v.lora_B": "to_v.lora.up",
+    "to_out.0.lora_A": "to_out.0.lora.down",
+    "to_out.0.lora_B": "to_out.0.lora.up",
 }
 
 DIFFUSERS_OLD_TO_DIFFUSERS = {
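
The new entries extend the peft_to_diffusers mapping to the UNet attention projections (`to_q`, `to_k`, `to_v`, `to_out.0`); the surrounding context lines only cover text-encoder-style projections (`v_proj`, `out_proj`), so UNet keys previously had no matching pattern. Conceptually the converter is a substring rewrite over each key; an illustrative toy version, not the library's exact code:

```python
# Toy version of the key rewrite driven by the mapping above: the first
# mapping pattern found in a key is substituted, then iteration stops.
MAPPING = {
    "to_q.lora_A": "to_q.lora.down",
    "to_q.lora_B": "to_q.lora.up",
    # ... remaining entries as in state_dict_utils.py
}

def convert_keys(state_dict, mapping=MAPPING):
    converted = {}
    for key, value in state_dict.items():
        for pattern, replacement in mapping.items():
            if pattern in key:
                key = key.replace(pattern, replacement)
                break  # first match wins
        converted[key] = value
    return converted

print(convert_keys({"attn1.to_q.lora_A.weight": 0}))
# {'attn1.to_q.lora.down.weight': 0}
```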
