diff --git a/examples/controlnet/README_flux.md b/examples/controlnet/README_flux.md
index d8be36a6e17a..aa5fa251409e 100644
--- a/examples/controlnet/README_flux.md
+++ b/examples/controlnet/README_flux.md
@@ -6,7 +6,19 @@ Training script provided by LibAI, which is an institution dedicated to the prog
 > [!NOTE]
 > **Memory consumption**
 >
-> Flux can be quite expensive to run on consumer hardware devices and as a result, ControlNet training of it comes with higher memory requirements than usual.
+> Flux can be quite expensive to run on consumer hardware devices and, as a result, ControlNet training of it comes with higher memory requirements than usual.
+
+Here is the GPU memory consumption at each stage for reference, tested on a single 80 GB A100:
+
+| stage | GPU memory |
+| - | - |
+| load models in float32 | ~70 GB |
+| move transformer and VAE to bfloat16 | ~48 GB |
+| precompute text embeddings | ~62 GB |
+| **offload text encoders to CPU** | ~30 GB |
+| training | ~58 GB |
+| validation | ~71 GB |
+

 > **Gated access**
 >
@@ -98,8 +110,9 @@ accelerate launch train_controlnet_flux.py \
     --validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
     --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
     --train_batch_size=1 \
-    --gradient_accumulation_steps=4 \
+    --gradient_accumulation_steps=16 \
     --report_to="wandb" \
+    --lr_scheduler="cosine" \
     --num_double_layers=4 \
     --num_single_layers=0 \
     --seed=42 \
diff --git a/examples/controlnet/train_controlnet_flux.py b/examples/controlnet/train_controlnet_flux.py
index 83965a73286d..edee3fbe557c 100644
--- a/examples/controlnet/train_controlnet_flux.py
+++ b/examples/controlnet/train_controlnet_flux.py
@@ -148,7 +148,7 @@ def log_validation(
             pooled_prompt_embeds=pooled_prompt_embeds,
             control_image=validation_image,
             num_inference_steps=28,
-            controlnet_conditioning_scale=0.7,
+            controlnet_conditioning_scale=1,
             guidance_scale=3.5,
             generator=generator,
         ).images[0]
@@ -1085,8 +1085,6 @@ def compute_embeddings(batch, proportion_empty_prompts, flux_controlnet_pipeline
         return {"prompt_embeds": prompt_embeds, "pooled_prompt_embeds": pooled_prompt_embeds, "text_ids": text_ids}

     train_dataset = get_train_dataset(args, accelerator)
-    text_encoders = [text_encoder_one, text_encoder_two]
-    tokenizers = [tokenizer_one, tokenizer_two]
     compute_embeddings_fn = functools.partial(
         compute_embeddings,
         flux_controlnet_pipeline=flux_controlnet_pipeline,
@@ -1103,7 +1101,8 @@ def compute_embeddings(batch, proportion_empty_prompts, flux_controlnet_pipeline
         compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint, batch_size=50
     )

-    del text_encoders, tokenizers, text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two
+    text_encoder_one.to("cpu")
+    text_encoder_two.to("cpu")
     free_memory()

     # Then get the training dataset ready to be passed to the dataloader.
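The last hunk is the source of the "offload text encoders to CPU" row in the memory table: after the prompt embeddings are precomputed and cached, the text encoders are moved to CPU instead of deleted. A minimal sketch of that offload pattern outside the training script is below; it assumes the FLUX.1-dev checkpoint layout (gated, see the README note) and elides the embedding precompute step, so treat the specifics as illustrative rather than as the script's exact code.

```python
import torch
from transformers import CLIPTextModel, T5EncoderModel

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load both Flux text encoders in bfloat16 and move them to the GPU.
# Subfolder names follow the black-forest-labs/FLUX.1-dev repository layout.
text_encoder_one = CLIPTextModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="text_encoder", torch_dtype=torch.bfloat16
).to(device)
text_encoder_two = T5EncoderModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="text_encoder_2", torch_dtype=torch.bfloat16
).to(device)

# ... precompute and cache the prompt embeddings here (elided) ...

# Offload instead of deleting: the weights leave GPU memory (the ~30 GB row
# in the table above), but the modules remain usable if needed again later.
text_encoder_one.to("cpu")
text_encoder_two.to("cpu")
torch.cuda.empty_cache()  # return the freed blocks to the CUDA allocator
```

The trade-off versus the old `del` approach is that the encoder weights now occupy host RAM, in exchange for not having to reload them from disk if they are needed again after the precompute step.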