Commit 6e2cb75

make lora target modules configurable and change the default
1 parent 9373c0a commit 6e2cb75

1 file changed: +12 −2 lines changed

examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py

Lines changed: 12 additions & 2 deletions
@@ -657,7 +657,12 @@ def parse_args(input_args=None):
     parser.add_argument(
         "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
     )
-
+    parser.add_argument(
+        "--lora_blocks",
+        type=str,
+        default=None,
+        help=('The transformer modules to apply LoRA training on'),
+    )
     parser.add_argument(
         "--adam_epsilon",
         type=float,
@@ -1582,12 +1587,17 @@ def main(args):
     if args.train_text_encoder:
         text_encoder_one.gradient_checkpointing_enable()
 
+    if args.lora_blocks is not None:
+        target_modules = [block.strip() for block in args.lora_blocks.split(",")]
+    else:
+        target_modules = ["to_k", "to_q", "to_v", "to_out.0",
+                          "add_k_proj", "add_q_proj", "add_v_proj", "to_add_out", "ff.net.0.proj","ff.net.2", "ff_context.net.0.proj","ff_context.net.2"]
     # now we will add new LoRA weights to the attention layers
     transformer_lora_config = LoraConfig(
         r=args.rank,
         lora_alpha=args.rank,
         init_lora_weights="gaussian",
-        target_modules=["to_k", "to_q", "to_v", "to_out.0"],
+        target_modules=target_modules,
     )
     transformer.add_adapter(transformer_lora_config)
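
For context, here is a minimal standalone sketch of the behavior this commit introduces: an explicit --lora_blocks value restricts LoRA to the listed modules, otherwise the new, wider default list of attention and feed-forward projections is used. The argument values, the --rank stand-in with its default of 4, and the final print are illustrative only; LoraConfig is assumed to come from peft, as in other diffusers LoRA training scripts.

import argparse

from peft import LoraConfig

parser = argparse.ArgumentParser()
parser.add_argument(
    "--lora_blocks",
    type=str,
    default=None,
    help="The transformer modules to apply LoRA training on",
)
parser.add_argument("--rank", type=int, default=4)  # simplified stand-in for the script's rank argument

# Example CLI input (hypothetical): train LoRA only on the basic attention projections.
args = parser.parse_args(["--lora_blocks", "to_k,to_q,to_v,to_out.0"])

# Same branching as the commit: explicit --lora_blocks wins, otherwise the new default list.
if args.lora_blocks is not None:
    target_modules = [block.strip() for block in args.lora_blocks.split(",")]
else:
    target_modules = [
        "to_k", "to_q", "to_v", "to_out.0",
        "add_k_proj", "add_q_proj", "add_v_proj", "to_add_out",
        "ff.net.0.proj", "ff.net.2", "ff_context.net.0.proj", "ff_context.net.2",
    ]

transformer_lora_config = LoraConfig(
    r=args.rank,
    lora_alpha=args.rank,
    init_lora_weights="gaussian",
    target_modules=target_modules,
)
print(transformer_lora_config.target_modules)  # peft stores the module names as a set

The training script would then apply this config to the Flux transformer via transformer.add_adapter(transformer_lora_config), as shown in the diff above.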
