
Commit ff5511c

make lora target modules configurable and change the default

1 parent 31d8576

File tree

1 file changed (+12, -12 lines)

examples/dreambooth/train_dreambooth_lora_flux.py

Lines changed: 12 additions & 12 deletions
@@ -555,11 +555,11 @@ def parse_args(input_args=None):
     )
 
     parser.add_argument(
-        "--lora_blocks",
+        "--lora_layers",
         type=str,
         default=None,
         help=(
-            'The transformer modules to apply LoRA training on. Please specify the layers in a comma seperated. E.g. - "q_proj,k_proj,v_proj,out_proj" will result in lora training of attention layers only'
+            'The transformer modules to apply LoRA training on. Please specify the layers in a comma seperated. E.g. - "to_k,to_q,to_v,to_out.0" will result in lora training of attention layers only'
         ),
     )
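The renamed flag stays a plain comma-separated string, so parsing it is a one-liner. Below is a minimal, self-contained sketch of that behaviour; only the flag name, type, default, and the split mirror the diff, while the shortened help text and the example value are illustrative.

# Minimal sketch: parse --lora_layers and split it into module names.
# Only the flag name, type, default, and the split mirror the diff;
# the shortened help text and the example value are illustrative.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--lora_layers",
    type=str,
    default=None,
    help="Comma-separated transformer modules to apply LoRA training on.",
)

args = parser.parse_args(["--lora_layers", "to_k,to_q,to_v,to_out.0"])
layers = [layer.strip() for layer in args.lora_layers.split(",")]
print(layers)  # ['to_k', 'to_q', 'to_v', 'to_out.0']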

@@ -1197,18 +1197,18 @@ def main(args):
         if args.train_text_encoder:
             text_encoder_one.gradient_checkpointing_enable()
 
-    if args.lora_blocks is not None:
-        target_modules = [block.strip() for block in args.lora_blocks.split(",")]
+    if args.lora_layers is not None:
+        target_modules = [layer.strip() for layer in args.lora_layers.split(",")]
     else:
         target_modules = [
-            "to_k",
-            "to_q",
-            "to_v",
-            "to_out.0",
-            "add_k_proj",
-            "add_q_proj",
-            "add_v_proj",
-            "to_add_out",
+            "attn.to_k",
+            "attn.to_q",
+            "attn.to_v",
+            "attn.to_out.0",
+            "attn.add_k_proj",
+            "attn.add_q_proj",
+            "attn.add_v_proj",
+            "attn.to_add_out",
             "ff.net.0.proj",
             "ff.net.2",
             "ff_context.net.0.proj",
