132 changes: 107 additions & 25 deletions examples/dreambooth/train_dreambooth_lora_flux2.py
@@ -47,6 +47,7 @@

import numpy as np
import torch
import torch.distributed as dist
Member: This should be guarded as well.

Contributor Author: I've modified it; please take a look.
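The guard itself isn't visible in this hunk. Purely as an illustration of one shape it could take (the PR's actual guard may differ), the idea is to only issue collective calls when a process group is really available and initialized; the import itself is generally harmless when torch is installed:

import torch

# Illustrative guard only; not necessarily what the PR ended up doing.
# Skip collectives such as barrier() when no process group exists, so that
# single-GPU / non-distributed runs don't crash.
if torch.distributed.is_available() and torch.distributed.is_initialized():
    torch.distributed.barrier()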

import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
@@ -80,8 +81,10 @@
compute_loss_weighting_for_sd3,
find_nearest_bucket,
free_memory,
get_fsdp_kwargs_from_accelerator,
offload_models,
parse_buckets_string,
wrap_with_fsdp,
)
from diffusers.utils import (
check_min_version,
@@ -722,6 +725,7 @@ def parse_args(input_args=None):
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enabla Flash Attention for NPU")
parser.add_argument("--fsdp_text_encoder", action="store_true", help="Use FSDP for text encoder")

if input_args is not None:
args = parser.parse_args(input_args)
@@ -1219,7 +1223,11 @@ def main(args):
if args.bnb_quantization_config_path is not None
else {"device": accelerator.device, "dtype": weight_dtype}
)
transformer.to(**transformer_to_kwargs)

is_fsdp = accelerator.state.fsdp_plugin is not None
if not is_fsdp:
transformer.to(**transformer_to_kwargs)

if args.do_fp8_training:
convert_to_float8_training(
transformer, module_filter_fn=module_filter_fn, config=Float8LinearConfig(pad_inner_dim=True)
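Stepping outside the diff for a moment: the is_fsdp flag above is driven by Accelerate's FSDP plugin. A minimal, assumed sketch of how that state gets populated (in practice it normally comes from accelerate launch with an FSDP config rather than explicit code):

from accelerate import Accelerator, FullyShardedDataParallelPlugin

# Assumed, simplified illustration; real runs configure this via `accelerate config`.
fsdp_plugin = FullyShardedDataParallelPlugin()
accelerator = Accelerator(fsdp_plugin=fsdp_plugin)

# Under a distributed launch with FSDP enabled, state.fsdp_plugin is non-None,
# which is exactly the check the training script performs.
is_fsdp = accelerator.state.fsdp_plugin is not None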
@@ -1263,19 +1271,43 @@ def unwrap_model(model):

# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
if accelerator.is_main_process:
transformer_lora_layers_to_save = None
modules_to_save = {}
transformer_lora_layers_to_save = None
Member: Let's simplify this block of code a bit:

transformer_cls = type(unwrap_model(transformer))

def _to_cpu_contiguous(sd):
    return {
        k: (v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v)
        for k, v in sd.items()
    }

# 1) Validate and pick the transformer model
modules_to_save: dict[str, Any] = {}
transformer_model = None

for m in models:
    if isinstance(unwrap_model(m), transformer_cls):
        transformer_model = m
        modules_to_save["transformer"] = m
    else:
        raise ValueError(f"unexpected save model: {m.__class__}")

if transformer_model is None:
    raise ValueError("No transformer model found in `models`.")

# 2) Optionally gather FSDP state dict once
state_dict = accelerator.get_state_dict(models) if is_fsdp else None

# 3) Only main process materializes the LoRA state dict
transformer_lora_layers_to_save = None
if accelerator.is_main_process:
    peft_kwargs = {}
    if is_fsdp:
        peft_kwargs["state_dict"] = state_dict

    transformer_lora_layers_to_save = get_peft_model_state_dict(
        unwrap_model(transformer_model) if is_fsdp else transformer_model,
        **peft_kwargs,
    )

    if is_fsdp:
        transformer_lora_layers_to_save = _to_cpu_contiguous(transformer_lora_layers_to_save)

        # make sure to pop weight so that corresponding model is not saved again
        if weights:
            weights.pop()

We can move _to_cpu_contiguous() to the training_utils.py module.

Contributor Author: Great suggestion! I've modified it.
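For reference, a hypothetical version of that helper as it could sit in training_utils.py; the name and placement follow the reviewer's suggestion and are not confirmed by this diff:

import torch

# Hypothetical helper for diffusers' training_utils.py, following the reviewer's
# suggestion; the final PR may name or place it differently.
def _to_cpu_contiguous(state_dict):
    """Detach tensors, move them to CPU, and make them contiguous before serialization."""
    return {
        k: (v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v)
        for k, v in state_dict.items()
    }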

modules_to_save = {}

if is_fsdp:
for model in models:
if isinstance(model, type(unwrap_model(transformer))):
transformer_lora_layers_to_save = get_peft_model_state_dict(model)
modules_to_save["transformer"] = model
else:
raise ValueError(f"unexpected save model: {model.__class__}")
if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
state_dict = accelerator.get_state_dict(models)

if accelerator.is_main_process:
transformer_lora_layers_to_save = get_peft_model_state_dict(
unwrap_model(model),
state_dict=state_dict,
)
transformer_lora_layers_to_save = {
k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v
for k, v in transformer_lora_layers_to_save.items()
}
modules_to_save["transformer"] = model

# make sure to pop weight so that corresponding model is not saved again
if weights:
weights.pop()
else:
if accelerator.is_main_process:
transformer_lora_layers_to_save = None
modules_to_save = {}
for model in models:
if isinstance(model, type(unwrap_model(transformer))):
transformer_lora_layers_to_save = get_peft_model_state_dict(model)
modules_to_save["transformer"] = model
else:
raise ValueError(f"unexpected save model: {model.__class__}")

# make sure to pop weight so that corresponding model is not saved again
weights.pop()
# make sure to pop weight so that corresponding model is not saved again
weights.pop()

if accelerator.is_main_process:
Flux2Pipeline.save_lora_weights(
output_dir,
transformer_lora_layers=transformer_lora_layers_to_save,
@@ -1285,13 +1317,20 @@ def save_model_hook(models, weights, output_dir):
def load_model_hook(models, input_dir):
transformer_ = None

while len(models) > 0:
model = models.pop()
if not is_fsdp:
while len(models) > 0:
model = models.pop()

if isinstance(model, type(unwrap_model(transformer))):
transformer_ = model
else:
raise ValueError(f"unexpected save model: {model.__class__}")
if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
transformer_ = unwrap_model(model)
else:
raise ValueError(f"unexpected save model: {model.__class__}")
else:
transformer_ = Flux2Transformer2DModel.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="transformer",
)
transformer_.add_adapter(transformer_lora_config)

lora_state_dict = Flux2Pipeline.lora_state_dict(input_dir)

@@ -1507,6 +1546,21 @@ def _encode_single(prompt: str):
args.validation_prompt, text_encoding_pipeline
)

# Init FSDP for text encoder
if args.fsdp_text_encoder:
fsdp_kwargs = get_fsdp_kwargs_from_accelerator(accelerator)
text_encoder_fsdp = wrap_with_fsdp(
model=text_encoding_pipeline.text_encoder,
device=accelerator.device,
offload=args.offload,
limit_all_gathers=True,
use_orig_params=True,
fsdp_kwargs=fsdp_kwargs,
)

text_encoding_pipeline.text_encoder = text_encoder_fsdp
dist.barrier()
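As a rough sketch of what a helper of this shape can do with PyTorch's native FSDP API (an assumption for illustration; the actual wrap_with_fsdp imported from diffusers.training_utils may take different arguments or defaults):

from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel as FSDP

# Assumed sketch only; the real helper may differ.
def wrap_with_fsdp_sketch(model, device, offload=False, limit_all_gathers=True,
                          use_orig_params=True, fsdp_kwargs=None):
    return FSDP(
        model,
        device_id=device,                                # shard and materialize on this device
        cpu_offload=CPUOffload(offload_params=offload),  # optionally park parameters on CPU
        limit_all_gathers=limit_all_gathers,
        use_orig_params=use_orig_params,                 # keeps LoRA/PEFT parameters individually addressable
        **(fsdp_kwargs or {}),
    )

The dist.barrier() right above then simply keeps all ranks in step once wrapping has finished.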

# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
# pack the statically computed variables appropriately here. This is so that we don't
# have to pass them to the dataloader.
@@ -1536,6 +1590,8 @@ def _encode_single(prompt: str):
if train_dataset.custom_instance_prompts:
if args.remote_text_encoder:
prompt_embeds, text_ids = compute_remote_text_embeddings(batch["prompts"])
elif args.fsdp_text_encoder:
prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
else:
with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
@@ -1777,7 +1833,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
progress_bar.update(1)
global_step += 1

if accelerator.is_main_process:
if accelerator.is_main_process or is_fsdp:
if global_step % args.checkpointing_steps == 0:
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
if args.checkpoints_total_limit is not None:
@@ -1836,15 +1892,41 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):

# Save the lora layers
accelerator.wait_for_everyone()

if is_fsdp:
transformer = unwrap_model(transformer)
state_dict = accelerator.get_state_dict(transformer)
if accelerator.is_main_process:
modules_to_save = {}
transformer = unwrap_model(transformer)
if args.bnb_quantization_config_path is None:
if args.upcast_before_saving:
transformer.to(torch.float32)
else:
transformer = transformer.to(weight_dtype)
transformer_lora_layers = get_peft_model_state_dict(transformer)
if is_fsdp:
if args.bnb_quantization_config_path is None:
if args.upcast_before_saving:
state_dict = {
k: v.to(torch.float32) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
}
else:
state_dict = {
k: v.to(weight_dtype) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
}

transformer_lora_layers = get_peft_model_state_dict(
transformer,
state_dict=state_dict,
)
transformer_lora_layers = {
k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v
for k, v in transformer_lora_layers.items()
}

else:
transformer = unwrap_model(transformer)
if args.bnb_quantization_config_path is None:
if args.upcast_before_saving:
transformer.to(torch.float32)
else:
transformer = transformer.to(weight_dtype)
transformer_lora_layers = get_peft_model_state_dict(transformer)

modules_to_save["transformer"] = transformer

Flux2Pipeline.save_lora_weights(