17 changes: 15 additions & 2 deletions examples/controlnet/README_flux.md
@@ -6,7 +6,19 @@ Training script provided by LibAI, which is an institution dedicated to the prog
> [!NOTE]
> **Memory consumption**
>
> Flux can be quite expensive to run on consumer hardware devices and as a result, ControlNet training of it comes with higher memory requirements than usual.

Here is the GPU memory consumption for reference, measured on a single 80 GB A100:

| Stage | GPU memory |
| - | - |
| load as float32 | ~70 GB |
| move transformer and VAE to bf16 | ~48 GB |
| pre-compute text embeddings | ~62 GB |
| **offload text encoders to CPU** | ~30 GB |
| training | ~58 GB |
| validation | ~71 GB |
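
The snippet below is a minimal sketch of how these stages map onto code, assuming the component layout of the `black-forest-labs/FLUX.1-dev` checkpoint used by this example (the repo is gated, see below); it is illustrative, not a drop-in replacement for the training script.

```python
# Illustrative sketch of the memory-saving stages in the table above.
import torch
from diffusers import AutoencoderKL, FluxTransformer2DModel
from transformers import CLIPTextModel, T5EncoderModel

repo = "black-forest-labs/FLUX.1-dev"  # gated; requires accepting the license and logging in
device = "cuda"

# Loading everything as float32 is the ~70 GB stage; casting the transformer and VAE
# to bf16 brings it down to roughly ~48 GB.
transformer = FluxTransformer2DModel.from_pretrained(repo, subfolder="transformer")
vae = AutoencoderKL.from_pretrained(repo, subfolder="vae")
transformer.to(device, dtype=torch.bfloat16)
vae.to(device, dtype=torch.bfloat16)

# The text encoders are only needed to pre-compute prompt embeddings; once that is done
# they can be moved back to the CPU (~62 GB -> ~30 GB) instead of staying on the GPU.
text_encoder_one = CLIPTextModel.from_pretrained(repo, subfolder="text_encoder").to(device)
text_encoder_two = T5EncoderModel.from_pretrained(repo, subfolder="text_encoder_2").to(device)
# ... pre-compute the text embeddings here ...
text_encoder_one.to("cpu")
text_encoder_two.to("cpu")
torch.cuda.empty_cache()
```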


> **Gated access**
>
@@ -98,8 +110,9 @@ accelerate launch train_controlnet_flux.py \
--validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
--validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--gradient_accumulation_steps=16 \
--report_to="wandb" \
--lr_scheduler="cosine" \
--num_double_layers=4 \
--num_single_layers=0 \
--seed=42 \
6 changes: 4 additions & 2 deletions examples/controlnet/train_controlnet_flux.py
@@ -148,7 +148,7 @@ def log_validation(
pooled_prompt_embeds=pooled_prompt_embeds,
control_image=validation_image,
num_inference_steps=28,
controlnet_conditioning_scale=0.7,
controlnet_conditioning_scale=1,
guidance_scale=3.5,
generator=generator,
).images[0]
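For reference, a minimal standalone inference sketch with the same validation settings (`controlnet_conditioning_scale=1`, 28 steps, guidance 3.5); the path to the trained ControlNet weights is a placeholder.

```python
# Hypothetical inference example mirroring the validation call above.
import torch
from diffusers import FluxControlNetModel, FluxControlNetPipeline
from diffusers.utils import load_image

controlnet = FluxControlNetModel.from_pretrained("path/to/trained-flux-controlnet", torch_dtype=torch.bfloat16)
pipe = FluxControlNetPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", controlnet=controlnet, torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload()  # keeps peak GPU memory lower

image = pipe(
    prompt="red circle with blue background",
    control_image=load_image("./conditioning_image_1.png"),
    controlnet_conditioning_scale=1.0,  # matches the new validation setting
    num_inference_steps=28,
    guidance_scale=3.5,
    generator=torch.Generator(device="cpu").manual_seed(42),
).images[0]
image.save("validation_sample.png")
```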
@@ -1103,7 +1103,9 @@ def compute_embeddings(batch, proportion_empty_prompts, flux_controlnet_pipeline
compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint, batch_size=50
)

del text_encoders, tokenizers, text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two
# del text_encoders, tokenizers, text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two
text_encoder_one.to("cpu")
text_encoder_two.to("cpu")
free_memory()

# Then get the training dataset ready to be passed to the dataloader.
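The offload-instead-of-delete pattern above boils down to the following self-contained sketch; `free_memory()` in diffusers performs a similar garbage-collect and cache flush.

```python
# Minimal sketch: keep the text encoder weights on the CPU so they can be reused later,
# while releasing the cached GPU memory they occupied.
import gc
import torch
from transformers import CLIPTextModel

text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to("cuda")

# ... pre-compute prompt embeddings on the GPU here ...

text_encoder.to("cpu")    # weights stay available for later reuse
gc.collect()
torch.cuda.empty_cache()  # hand the freed memory back to the CUDA allocator
```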