3 changes: 3 additions & 0 deletions examples/dreambooth/README_flux2.md
@@ -98,6 +98,9 @@ Flux.2 uses Mistral Small 3.1 as text encoder which is quite large and can take
This way, the text encoder model is not loaded into memory during training.
> [!NOTE]
> To enable remote text encoding, you must either be logged in to your Hugging Face account (`hf auth login`) or pass a token with `--hub_token`.
### FSDP Text Encoder
Flux.2 uses Mistral Small 3.1 as its text encoder, which is quite large and can take up a lot of memory. To mitigate this, use the `--fsdp_text_encoder` flag to shard the text encoder with FSDP and compute the prompt embeddings in a distributed fashion.
This way, the memory cost is distributed across the participating devices, as sketched below.
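For intuition, here is a minimal, hypothetical sketch of the idea in plain PyTorch; `load_text_encoder` and `input_ids` are placeholders, and the actual script wires this up through diffusers' `wrap_with_fsdp` and `get_fsdp_kwargs_from_accelerator` training utilities instead.

```python
# Hypothetical sketch: shard a large text encoder with PyTorch FSDP so that each
# rank permanently holds only a fraction of the weights.
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

text_encoder = load_text_encoder()  # placeholder for loading the Mistral Small 3.1 encoder
text_encoder = FSDP(
    text_encoder,
    device_id=torch.cuda.current_device(),
    use_orig_params=True,    # keep original parameter names for downstream state-dict code
    limit_all_gathers=True,  # throttle all-gathers to cap peak memory
)

with torch.no_grad():
    prompt_embeds = text_encoder(input_ids)  # parameters are all-gathered on demand, then freed
```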
### CPU Offloading
To offload parts of the model to CPU memory, you can use the `--offload` flag. This offloads the VAE and text encoder to CPU memory and moves them to the GPU only when needed.
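The offloading pattern itself is simple; below is a minimal, hypothetical sketch of the idea (the script actually uses the `offload_models` context manager from `diffusers.training_utils`).

```python
from contextlib import contextmanager

import torch


@contextmanager
def on_gpu(module: torch.nn.Module, device: torch.device):
    """Keep `module` on CPU, moving it to `device` only while the block runs."""
    module.to(device)
    try:
        yield module
    finally:
        module.to("cpu")  # move the weights back to CPU memory
        torch.cuda.empty_cache()  # release the cached GPU allocation


# For example, prompts are encoded on the GPU while the encoder otherwise stays on CPU:
# with on_gpu(text_encoder, torch.device("cuda")):
#     prompt_embeds = text_encoder(input_ids)
```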
### Latent Caching
135 changes: 111 additions & 24 deletions examples/dreambooth/train_dreambooth_lora_flux2.py
@@ -44,6 +44,7 @@
import warnings
from contextlib import nullcontext
from pathlib import Path
from typing import Any

import numpy as np
import torch
@@ -75,13 +76,16 @@
from diffusers.optimization import get_scheduler
from diffusers.training_utils import (
_collate_lora_metadata,
_to_cpu_contiguous,
cast_training_params,
compute_density_for_timestep_sampling,
compute_loss_weighting_for_sd3,
find_nearest_bucket,
free_memory,
get_fsdp_kwargs_from_accelerator,
offload_models,
parse_buckets_string,
wrap_with_fsdp,
)
from diffusers.utils import (
check_min_version,
@@ -93,6 +97,9 @@
from diffusers.utils.torch_utils import is_compiled_module


if torch.distributed.is_available():
    import torch.distributed as dist

if is_wandb_available():
import wandb

@@ -722,6 +729,7 @@ def parse_args(input_args=None):
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enabla Flash Attention for NPU")
parser.add_argument("--fsdp_text_encoder", action="store_true", help="Use FSDP for text encoder")

if input_args is not None:
args = parser.parse_args(input_args)
@@ -1219,7 +1227,11 @@ def main(args):
if args.bnb_quantization_config_path is not None
else {"device": accelerator.device, "dtype": weight_dtype}
)

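    # when FSDP is enabled, Accelerate shards and places the transformer during `prepare`,
    # so the manual device/dtype move is only needed in the non-FSDP path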
is_fsdp = accelerator.state.fsdp_plugin is not None
if not is_fsdp:
transformer.to(**transformer_to_kwargs)

if args.do_fp8_training:
convert_to_float8_training(
transformer, module_filter_fn=module_filter_fn, config=Float8LinearConfig(pad_inner_dim=True)
@@ -1263,17 +1275,42 @@ def unwrap_model(model):

# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
transformer_cls = type(unwrap_model(transformer))

# 1) Validate and pick the transformer model
modules_to_save: dict[str, Any] = {}
transformer_model = None

for model in models:
if isinstance(unwrap_model(model), transformer_cls):
transformer_model = model
modules_to_save["transformer"] = model
else:
raise ValueError(f"unexpected save model: {model.__class__}")

if transformer_model is None:
raise ValueError("No transformer model found in 'models'")

        # 2) Optionally gather the full FSDP state dict once, on every rank (collective call)
        state_dict = accelerator.get_state_dict(transformer_model) if is_fsdp else None

# 3) Only main process materializes the LoRA state dict
transformer_lora_layers_to_save = None
if accelerator.is_main_process:
peft_kwargs = {}
if is_fsdp:
peft_kwargs["state_dict"] = state_dict

transformer_lora_layers_to_save = get_peft_model_state_dict(
unwrap_model(transformer_model) if is_fsdp else transformer_model,
**peft_kwargs,
)

if is_fsdp:
transformer_lora_layers_to_save = _to_cpu_contiguous(transformer_lora_layers_to_save)

# make sure to pop weight so that corresponding model is not saved again
if weights:
weights.pop()

Flux2Pipeline.save_lora_weights(
@@ -1285,13 +1322,20 @@ def save_model_hook(models, weights, output_dir):
def load_model_hook(models, input_dir):
transformer_ = None

        if not is_fsdp:
            while len(models) > 0:
                model = models.pop()

                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
                    transformer_ = unwrap_model(model)
                else:
                    raise ValueError(f"unexpected save model: {model.__class__}")
        else:
            # under FSDP, rebuild the base transformer and re-attach the LoRA adapter
            # instead of reusing a model popped off `models`
            transformer_ = Flux2Transformer2DModel.from_pretrained(
                args.pretrained_model_name_or_path,
                subfolder="transformer",
            )
            transformer_.add_adapter(transformer_lora_config)

lora_state_dict = Flux2Pipeline.lora_state_dict(input_dir)

@@ -1507,6 +1551,21 @@ def _encode_single(prompt: str):
args.validation_prompt, text_encoding_pipeline
)

# Init FSDP for text encoder
if args.fsdp_text_encoder:
fsdp_kwargs = get_fsdp_kwargs_from_accelerator(accelerator)
text_encoder_fsdp = wrap_with_fsdp(
model=text_encoding_pipeline.text_encoder,
device=accelerator.device,
offload=args.offload,
limit_all_gathers=True,
use_orig_params=True,
fsdp_kwargs=fsdp_kwargs,
)

text_encoding_pipeline.text_encoder = text_encoder_fsdp
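        # synchronize so every rank finishes wrapping before prompts are encoded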
dist.barrier()

# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
# pack the statically computed variables appropriately here. This is so that we don't
# have to pass them to the dataloader.
@@ -1536,6 +1595,8 @@ def _encode_single(prompt: str):
if train_dataset.custom_instance_prompts:
if args.remote_text_encoder:
prompt_embeds, text_ids = compute_remote_text_embeddings(batch["prompts"])
elif args.fsdp_text_encoder:
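                        # the FSDP-sharded encoder already lives on the GPUs, so encode directly without the offload context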
prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
else:
with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
@@ -1777,7 +1838,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
progress_bar.update(1)
global_step += 1

if accelerator.is_main_process or is_fsdp:
if global_step % args.checkpointing_steps == 0:
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
if args.checkpoints_total_limit is not None:
@@ -1836,15 +1897,41 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):

# Save the lora layers
accelerator.wait_for_everyone()

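    # under FSDP, `accelerator.get_state_dict` is a collective gather of the full state
    # dict, so it must run on every rank before gating on the main process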
if is_fsdp:
transformer = unwrap_model(transformer)
state_dict = accelerator.get_state_dict(transformer)
if accelerator.is_main_process:
modules_to_save = {}
if is_fsdp:
if args.bnb_quantization_config_path is None:
if args.upcast_before_saving:
state_dict = {
k: v.to(torch.float32) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
}
else:
state_dict = {
k: v.to(weight_dtype) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
}

transformer_lora_layers = get_peft_model_state_dict(
transformer,
state_dict=state_dict,
)
transformer_lora_layers = {
k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v
for k, v in transformer_lora_layers.items()
}

else:
transformer = unwrap_model(transformer)
if args.bnb_quantization_config_path is None:
if args.upcast_before_saving:
transformer.to(torch.float32)
else:
transformer = transformer.to(weight_dtype)
transformer_lora_layers = get_peft_model_state_dict(transformer)

modules_to_save["transformer"] = transformer

Flux2Pipeline.save_lora_weights(