
Commit a6158d7

fixes
1 parent eac6fd1 commit a6158d7

File tree

2 files changed: +40 -1 lines changed


examples/control-lora/README.md

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@

# Training Control LoRA with Flux

This example shows how to train a Control LoRA with Flux to condition it on additional structural controls (like depth maps, poses, etc.).

This is still an experimental version, and the following differences exist:

* No use of bias on `lora_B`.
* No updates on the norm scales.

We simply expand the input channels of Flux.1 Dev from 64 to 128 to allow for the additional inputs and then train a regular LoRA on top of it. To account for the newly added input channels, we additionally add a LoRA on the underlying input layer (`x_embedder`). Inference, however, is performed with the `FluxControlPipeline`.
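
The channel expansion itself is only a few lines of PyTorch. Below is a minimal sketch of the idea, not the exact training script; zero-initializing the new half of the weights is an assumption about how the extra channels are handled:

```python
# Sketch: double the input channels of the Flux x_embedder so control latents
# can be concatenated with the image latents along the channel dimension.
import torch
from diffusers import FluxTransformer2DModel

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)

initial_input_channels = transformer.config.in_channels  # 64 for Flux.1 Dev
new_linear = torch.nn.Linear(
    initial_input_channels * 2,
    transformer.x_embedder.out_features,
    bias=True,
    dtype=transformer.dtype,
)

with torch.no_grad():
    new_linear.weight.zero_()
    # Copy the pretrained weights into the first half; the second half (for the
    # control latents) starts at zero so the model initially ignores them
    # (the zero init is an assumption, not a line from this script).
    new_linear.weight[:, :initial_input_channels].copy_(transformer.x_embedder.weight)
    new_linear.bias.copy_(transformer.x_embedder.bias)

transformer.x_embedder = new_linear
transformer.register_to_config(in_channels=initial_input_channels * 2)
```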

Example command:

```bash
accelerate launch train_control_lora_flux.py \
--pretrained_model_name_or_path="black-forest-labs/FLUX.1-dev" \
--dataset_name="raulc0399/open_pose_controlnet" \
--output_dir="pose-control-lora" \
--mixed_precision="bf16" \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--gradient_checkpointing \
--use_8bit_adam \
--learning_rate=1e-4 \
--report_to="wandb" \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--max_train_steps=5000 \
--validation_image="openpose.png" \
--validation_prompt="A couple, 4k photo, highly detailed" \
--seed="0" \
--push_to_hub
```

You need to install `diffusers` from the branch of [this PR](https://github.com/huggingface/diffusers/pull/9999).
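
For reference, inference with the trained LoRA could look roughly like the following; the repo id `your-username/pose-control-lora` and the sampling parameters are placeholders, not values from this example:

```python
# Rough inference sketch with FluxControlPipeline (repo id and sampling
# parameters are placeholders).
import torch
from diffusers import FluxControlPipeline
from diffusers.utils import load_image

pipe = FluxControlPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")
pipe.load_lora_weights("your-username/pose-control-lora")

control_image = load_image("openpose.png")
image = pipe(
    prompt="A couple, 4k photo, highly detailed",
    control_image=control_image,
    num_inference_steps=50,
    guidance_scale=3.5,
).images[0]
image.save("output.png")
```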

examples/control-lora/train_control_lora_flux.py

Lines changed: 5 additions & 1 deletion
```diff
@@ -817,12 +817,16 @@ def main(args):
     new_linear.weight[:, :initial_input_channels].copy_(flux_transformer.x_embedder.weight)
     new_linear.bias.copy_(flux_transformer.x_embedder.bias)
     flux_transformer.x_embedder = new_linear
-    flux_transformer.register_config(in_channels=initial_input_channels * 2)
+    flux_transformer.register_to_config(in_channels=initial_input_channels * 2)

     if args.lora_layers is not None:
         target_modules = [layer.strip() for layer in args.lora_layers.split(",")]
+        # add the input layer to the mix.
+        if "x_embedder" not in target_modules:
+            target_modules.append("x_embedder")
     else:
         target_modules = [
+            "x_embedder",
             "attn.to_k",
             "attn.to_q",
             "attn.to_v",
```
