Skip to content

Commit 90b9479

Browse files
authored
[LoRA PEFT] fix LoRA loading so that correct alphas are parsed (#6225)
* initialize alpha too
* add: test
* remove config parsing
* store rank
* debug
* remove faulty test
1 parent df76a39 commit 90b9479

File tree

5 files changed

+31
-8
lines changed

5 files changed

+31
-8
lines changed

examples/dreambooth/train_dreambooth_lora.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -827,6 +827,7 @@ def main(args):
827827
# now we will add new LoRA weights to the attention layers
828828
unet_lora_config = LoraConfig(
829829
r=args.rank,
830+
lora_alpha=args.rank,
830831
init_lora_weights="gaussian",
831832
target_modules=["to_k", "to_q", "to_v", "to_out.0", "add_k_proj", "add_v_proj"],
832833
)
@@ -835,7 +836,10 @@ def main(args):
835836
# The text encoder comes from 🤗 transformers, we will also attach adapters to it.
836837
if args.train_text_encoder:
837838
text_lora_config = LoraConfig(
838-
r=args.rank, init_lora_weights="gaussian", target_modules=["q_proj", "k_proj", "v_proj", "out_proj"]
839+
r=args.rank,
840+
lora_alpha=args.rank,
841+
init_lora_weights="gaussian",
842+
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
839843
)
840844
text_encoder.add_adapter(text_lora_config)
841845

examples/dreambooth/train_dreambooth_lora_sdxl.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -978,15 +978,21 @@ def main(args):
978978

979979
# now we will add new LoRA weights to the attention layers
980980
unet_lora_config = LoraConfig(
981-
r=args.rank, init_lora_weights="gaussian", target_modules=["to_k", "to_q", "to_v", "to_out.0"]
981+
r=args.rank,
982+
lora_alpha=args.rank,
983+
init_lora_weights="gaussian",
984+
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
982985
)
983986
unet.add_adapter(unet_lora_config)
984987

985988
# The text encoder comes from 🤗 transformers, so we cannot directly modify it.
986989
# So, instead, we monkey-patch the forward calls of its attention-blocks.
987990
if args.train_text_encoder:
988991
text_lora_config = LoraConfig(
989-
r=args.rank, init_lora_weights="gaussian", target_modules=["q_proj", "k_proj", "v_proj", "out_proj"]
992+
r=args.rank,
993+
lora_alpha=args.rank,
994+
init_lora_weights="gaussian",
995+
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
990996
)
991997
text_encoder_one.add_adapter(text_lora_config)
992998
text_encoder_two.add_adapter(text_lora_config)

examples/text_to_image/train_text_to_image_lora.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -452,7 +452,10 @@ def main():
452452
param.requires_grad_(False)
453453

454454
unet_lora_config = LoraConfig(
455-
r=args.rank, init_lora_weights="gaussian", target_modules=["to_k", "to_q", "to_v", "to_out.0"]
455+
r=args.rank,
456+
lora_alpha=args.rank,
457+
init_lora_weights="gaussian",
458+
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
456459
)
457460

458461
# Move unet, vae and text_encoder to device and cast to weight_dtype

examples/text_to_image/train_text_to_image_lora_sdxl.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -609,7 +609,10 @@ def main(args):
609609
# now we will add new LoRA weights to the attention layers
610610
# Set correct lora layers
611611
unet_lora_config = LoraConfig(
612-
r=args.rank, init_lora_weights="gaussian", target_modules=["to_k", "to_q", "to_v", "to_out.0"]
612+
r=args.rank,
613+
lora_alpha=args.rank,
614+
init_lora_weights="gaussian",
615+
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
613616
)
614617

615618
unet.add_adapter(unet_lora_config)
@@ -618,7 +621,10 @@ def main(args):
618621
if args.train_text_encoder:
619622
# ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
620623
text_lora_config = LoraConfig(
621-
r=args.rank, init_lora_weights="gaussian", target_modules=["q_proj", "k_proj", "v_proj", "out_proj"]
624+
r=args.rank,
625+
lora_alpha=args.rank,
626+
init_lora_weights="gaussian",
627+
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
622628
)
623629
text_encoder_one.add_adapter(text_lora_config)
624630
text_encoder_two.add_adapter(text_lora_config)

tests/lora/test_lora_layers_peft.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ class PeftLoraLoaderMixinTests:
111111

112112
def get_dummy_components(self, scheduler_cls=None):
113113
scheduler_cls = self.scheduler_cls if scheduler_cls is None else LCMScheduler
114+
rank = 4
114115

115116
torch.manual_seed(0)
116117
unet = UNet2DConditionModel(**self.unet_kwargs)
@@ -125,11 +126,14 @@ def get_dummy_components(self, scheduler_cls=None):
125126
tokenizer_2 = CLIPTokenizer.from_pretrained("peft-internal-testing/tiny-clip-text-2")
126127

127128
text_lora_config = LoraConfig(
128-
r=4, lora_alpha=4, target_modules=["q_proj", "k_proj", "v_proj", "out_proj"], init_lora_weights=False
129+
r=rank,
130+
lora_alpha=rank,
131+
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
132+
init_lora_weights=False,
129133
)
130134

131135
unet_lora_config = LoraConfig(
132-
r=4, lora_alpha=4, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False
136+
r=rank, lora_alpha=rank, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False
133137
)
134138

135139
unet_lora_attn_procs, unet_lora_layers = create_unet_lora_layers(unet)

0 commit comments

Comments (0)