Commit e4fe609

make lora target modules configurable and change the default, add notes to readme
1 parent 366a35e commit e4fe609

2 files changed: +29 -13 lines changed


examples/advanced_diffusion_training/README_flux.md

Lines changed: 16 additions & 1 deletion
@@ -5,7 +5,7 @@
 > 💡 This example follows some of the techniques and recommended practices covered in the community-derived guide we made for SDXL training: [LoRA training scripts of the world, unite!](https://huggingface.co/blog/sdxl_lora_advanced_script).
 > As many of these are architecture agnostic & generally relevant to fine-tuning of diffusion models, we suggest taking a look 🤗
 
-[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text2image models like flux, stable diffusion given just a few(3~5) images of a subject.
+[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text-to-image models like Flux and Stable Diffusion given just a few (3-5) images of a subject.
 
 LoRA - Low-Rank Adaptation of Large Language Models - was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.
 In a nutshell, LoRA allows adapting pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights. This has a couple of advantages:
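
As a quick illustration of the rank-decomposition idea described in the context lines above, here is a minimal PyTorch sketch; the shapes, rank, and initialization are made up for illustration and are not taken from this commit.

```python
import torch

# The pretrained weight W stays frozen; only the low-rank pair (A, B) is trained.
d_out, d_in, rank = 768, 768, 16

W = torch.randn(d_out, d_in)                     # frozen pretrained weight
A = torch.randn(rank, d_in, requires_grad=True)  # trainable down-projection
B = torch.zeros(d_out, rank, requires_grad=True) # trainable up-projection (zero init)

x = torch.randn(4, d_in)                         # a batch of inputs
# Forward pass: original output plus the low-rank update (B @ A);
# LoRA implementations usually also scale the update by alpha / rank.
y = x @ W.T + x @ (B @ A).T
```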
@@ -65,6 +65,21 @@ write_basic_config()
 When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
 Note also that we use the PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.
 
+### Target Modules
+When LoRA was first adapted from language models to diffusion models, it was applied to the cross-attention layers in the UNet that relate the image representations with the prompts that describe them.
+More recently, SOTA text-to-image diffusion models replaced the UNet with a diffusion transformer (DiT). With this change, we may also want to explore
+applying LoRA training to different types of layers and blocks. To allow more flexibility and control over the targeted modules, we added `--lora_layers`, in which you can specify, as a comma-separated string,
+the exact modules for LoRA training. Here are some examples of target modules you can provide:
+- for attention-only layers: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0"`
+- to train the same modules as in the fal trainer: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2"`
+- to train the same modules as in the ostris ai-toolkit / replicate trainer: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2,norm1_context.linear,norm1.linear,norm.linear,proj_mlp,proj_out"`
+> [!NOTE]
+> `--lora_layers` can also be used to specify which **blocks** to apply LoRA training to. To do so, simply add a block prefix to each layer in the comma-separated string:
+> **single DiT blocks**: to target the ith single transformer block, add the prefix `single_transformer_blocks.i`, e.g. `single_transformer_blocks.i.attn.to_k`
+> **MMDiT blocks**: to target the ith MMDiT block, add the prefix `transformer_blocks.i`, e.g. `transformer_blocks.i.attn.to_k`
+> [!NOTE]
+> Keep in mind that while training more layers can improve quality and expressiveness, it also increases the size of the output LoRA weights.
+
 ### Pivotal Tuning (and more)
 **Training with text encoder(s)**
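
To make the block-prefix convention above concrete, here is a small, self-contained illustration of how a `--lora_layers` value is turned into a list of target module names. The specific blocks and layers chosen below are made up for illustration; the parsing mirrors the `split(",")` logic visible in the script diff further down.

```python
# Hypothetical --lora_layers value: attention projections of MMDiT block 12
# and single DiT block 7 only (an arbitrary choice, purely for illustration).
lora_layers = (
    "transformer_blocks.12.attn.to_k,transformer_blocks.12.attn.to_q,"
    "single_transformer_blocks.7.attn.to_k,single_transformer_blocks.7.attn.to_q"
)

# Same parsing as in train_dreambooth_lora_flux_advanced.py:
target_modules = [layer.strip() for layer in lora_layers.split(",")]
print(target_modules)
# ['transformer_blocks.12.attn.to_k', 'transformer_blocks.12.attn.to_q',
#  'single_transformer_blocks.7.attn.to_k', 'single_transformer_blocks.7.attn.to_q']
```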

examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py

Lines changed: 13 additions & 12 deletions
@@ -658,11 +658,12 @@ def parse_args(input_args=None):
         "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
     )
     parser.add_argument(
-        "--lora_blocks",
+        "--lora_layers",
         type=str,
         default=None,
         help=(
-            'The transformer modules to apply LoRA training on. Please specify the layers in a comma seperated. E.g. - "q_proj,k_proj,v_proj,out_proj" will result in lora training of attention layers only'
+            'The transformer modules to apply LoRA training on. Please specify the layers in a comma-separated string. '
+            'E.g. - "to_k,to_q,to_v,to_out.0" will result in LoRA training of attention layers only. For more examples refer to https://github.com/huggingface/diffusers/blob/main/examples/advanced_diffusion_training/README_flux.md'
         ),
     )
     parser.add_argument(
@@ -1589,18 +1590,18 @@ def main(args):
         if args.train_text_encoder:
             text_encoder_one.gradient_checkpointing_enable()
 
-    if args.lora_blocks is not None:
-        target_modules = [block.strip() for block in args.lora_blocks.split(",")]
+    if args.lora_layers is not None:
+        target_modules = [layer.strip() for layer in args.lora_layers.split(",")]
     else:
         target_modules = [
-            "to_k",
-            "to_q",
-            "to_v",
-            "to_out.0",
-            "add_k_proj",
-            "add_q_proj",
-            "add_v_proj",
-            "to_add_out",
+            "attn.to_k",
+            "attn.to_q",
+            "attn.to_v",
+            "attn.to_out.0",
+            "attn.add_k_proj",
+            "attn.add_q_proj",
+            "attn.add_v_proj",
+            "attn.to_add_out",
             "ff.net.0.proj",
             "ff.net.2",
             "ff_context.net.0.proj",

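For context on where `target_modules` ends up: in this family of training scripts the list is typically passed to a PEFT `LoraConfig` and attached to the transformer. The snippet below is a simplified sketch of that pattern with illustrative values (rank, alpha, module list), not a verbatim excerpt from the script.

```python
from peft import LoraConfig

# Simplified sketch (illustrative values, not a verbatim excerpt): only the
# submodules whose names match an entry in target_modules receive LoRA weights.
target_modules = ["attn.to_k", "attn.to_q", "attn.to_v", "attn.to_out.0"]

transformer_lora_config = LoraConfig(
    r=16,                           # LoRA rank (args.rank in the script)
    lora_alpha=16,
    init_lora_weights="gaussian",
    target_modules=target_modules,
)
# The config is then attached to the Flux transformer, roughly like this:
# transformer.add_adapter(transformer_lora_config)
```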