diff --git a/examples/dreambooth/README_flux2.md b/examples/dreambooth/README_flux2.md
index 1d1777811387..41a77c3bbcc8 100644
--- a/examples/dreambooth/README_flux2.md
+++ b/examples/dreambooth/README_flux2.md
@@ -98,6 +98,9 @@ Flux.2 uses Mistral Small 3.1 as text encoder which is quite large and can take
 This way, the text encoder model is not loaded into memory during training.
 > [!NOTE]
 > to enable remote text encoding you must either be logged in to your HuggingFace account (`hf auth login`) OR pass a token with `--hub_token`.
+### FSDP Text Encoder
+Flux.2 uses Mistral Small 3.1 as its text encoder, which is quite large and can take up a lot of memory. To mitigate this, use the `--fsdp_text_encoder` flag to shard the text encoder with FSDP and compute the prompt embeddings in a distributed fashion.
+This way, the text encoder's memory cost is spread across the participating GPUs instead of being paid in full on every rank.
 ### CPU Offloading
 To offload parts of the model to CPU memory, you can use `--offload` flag. This will offload the vae and text encoder to CPU memory and only move them to GPU when needed.
 ### Latent Caching
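A note on how the new flag interacts with Accelerate: both training scripts derive `is_fsdp` from `accelerator.state.fsdp_plugin`, which is only populated when FSDP is enabled for the run. The sketch below is illustrative only (not part of this PR) and assumes a multi-process `accelerate launch`; in practice the plugin normally comes from the Accelerate config file rather than being built in code.

```python
# Hedged sketch: how the scripts decide between the FSDP and non-FSDP code paths.
# Assumes a multi-process run started with `accelerate launch`; in a plain
# single-process run Accelerate leaves `state.fsdp_plugin` as None.
from accelerate import Accelerator, FullyShardedDataParallelPlugin

fsdp_plugin = FullyShardedDataParallelPlugin()      # FULL_SHARD-style defaults
accelerator = Accelerator(fsdp_plugin=fsdp_plugin)

is_fsdp = accelerator.state.fsdp_plugin is not None  # mirrors the check in both scripts
```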
ValueError(f"unexpected save model: {model.__class__}") + + if transformer_model is None: + raise ValueError("No transformer model found in 'models'") + + # 2) Optionally gather FSDP state dict once + state_dict = accelerator.get_state_dict(model) if is_fsdp else None + + # 3) Only main process materializes the LoRA state dict + transformer_lora_layers_to_save = None if accelerator.is_main_process: - transformer_lora_layers_to_save = None - modules_to_save = {} - for model in models: - if isinstance(model, type(unwrap_model(transformer))): - transformer_lora_layers_to_save = get_peft_model_state_dict(model) - modules_to_save["transformer"] = model - else: - raise ValueError(f"unexpected save model: {model.__class__}") + peft_kwargs = {} + if is_fsdp: + peft_kwargs["state_dict"] = state_dict + + transformer_lora_layers_to_save = get_peft_model_state_dict( + unwrap_model(transformer_model) if is_fsdp else transformer_model, + **peft_kwargs, + ) - # make sure to pop weight so that corresponding model is not saved again + if is_fsdp: + transformer_lora_layers_to_save = _to_cpu_contiguous(transformer_lora_layers_to_save) + + # make sure to pop weight so that corresponding model is not saved again + if weights: weights.pop() Flux2Pipeline.save_lora_weights( @@ -1285,13 +1322,20 @@ def save_model_hook(models, weights, output_dir): def load_model_hook(models, input_dir): transformer_ = None - while len(models) > 0: - model = models.pop() + if not is_fsdp: + while len(models) > 0: + model = models.pop() - if isinstance(model, type(unwrap_model(transformer))): - transformer_ = model - else: - raise ValueError(f"unexpected save model: {model.__class__}") + if isinstance(unwrap_model(model), type(unwrap_model(transformer))): + transformer_ = unwrap_model(model) + else: + raise ValueError(f"unexpected save model: {model.__class__}") + else: + transformer_ = Flux2Transformer2DModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="transformer", + ) + transformer_.add_adapter(transformer_lora_config) lora_state_dict = Flux2Pipeline.lora_state_dict(input_dir) @@ -1507,6 +1551,21 @@ def _encode_single(prompt: str): args.validation_prompt, text_encoding_pipeline ) + # Init FSDP for text encoder + if args.fsdp_text_encoder: + fsdp_kwargs = get_fsdp_kwargs_from_accelerator(accelerator) + text_encoder_fsdp = wrap_with_fsdp( + model=text_encoding_pipeline.text_encoder, + device=accelerator.device, + offload=args.offload, + limit_all_gathers=True, + use_orig_params=True, + fsdp_kwargs=fsdp_kwargs, + ) + + text_encoding_pipeline.text_encoder = text_encoder_fsdp + dist.barrier() + # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images), # pack the statically computed variables appropriately here. This is so that we don't # have to pass them to the dataloader. 
@@ -1536,6 +1595,8 @@ def _encode_single(prompt: str):
                 if train_dataset.custom_instance_prompts:
                     if args.remote_text_encoder:
                         prompt_embeds, text_ids = compute_remote_text_embeddings(batch["prompts"])
+                    elif args.fsdp_text_encoder:
+                        prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
                     else:
                         with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
                             prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
@@ -1777,7 +1838,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 progress_bar.update(1)
                 global_step += 1
 
-                if accelerator.is_main_process:
+                if accelerator.is_main_process or is_fsdp:
                     if global_step % args.checkpointing_steps == 0:
                         # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                         if args.checkpoints_total_limit is not None:
@@ -1836,15 +1897,41 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
     # Save the lora layers
     accelerator.wait_for_everyone()
+
+    if is_fsdp:
+        transformer = unwrap_model(transformer)
+        state_dict = accelerator.get_state_dict(transformer)
 
     if accelerator.is_main_process:
         modules_to_save = {}
-        transformer = unwrap_model(transformer)
-        if args.bnb_quantization_config_path is None:
-            if args.upcast_before_saving:
-                transformer.to(torch.float32)
-            else:
-                transformer = transformer.to(weight_dtype)
-        transformer_lora_layers = get_peft_model_state_dict(transformer)
+        if is_fsdp:
+            if args.bnb_quantization_config_path is None:
+                if args.upcast_before_saving:
+                    state_dict = {
+                        k: v.to(torch.float32) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
+                    }
+                else:
+                    state_dict = {
+                        k: v.to(weight_dtype) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
+                    }
+
+            transformer_lora_layers = get_peft_model_state_dict(
+                transformer,
+                state_dict=state_dict,
+            )
+            transformer_lora_layers = {
+                k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v
+                for k, v in transformer_lora_layers.items()
+            }
+
+        else:
+            transformer = unwrap_model(transformer)
+            if args.bnb_quantization_config_path is None:
+                if args.upcast_before_saving:
+                    transformer.to(torch.float32)
+                else:
+                    transformer = transformer.to(weight_dtype)
+            transformer_lora_layers = get_peft_model_state_dict(transformer)
+
         modules_to_save["transformer"] = transformer
 
         Flux2Pipeline.save_lora_weights(
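As context for the hook changes above: under FSDP, `accelerator.get_state_dict(...)` is a collective all-gather, so it has to run on every rank before any `is_main_process` branch. A minimal sketch of the pattern, assuming an FSDP-managed, LoRA-adapted `transformer` and the `accelerator` from the script:

```python
# Minimal sketch of the FSDP-aware LoRA save path used above (assumptions:
# `accelerator` has FSDP enabled and `transformer` carries a PEFT LoRA adapter).
from peft.utils import get_peft_model_state_dict

state_dict = accelerator.get_state_dict(transformer)  # collective: every rank must call this
if accelerator.is_main_process:
    # Filter the gathered full state dict down to the LoRA weights only.
    lora_state_dict = get_peft_model_state_dict(transformer, state_dict=state_dict)
    # Detach to CPU-contiguous tensors before serialization (what _to_cpu_contiguous does).
    lora_state_dict = {k: v.detach().cpu().contiguous() for k, v in lora_state_dict.items()}
```

The same ordering constraint is why the periodic checkpointing guard changes from `accelerator.is_main_process` to `accelerator.is_main_process or is_fsdp`: `accelerator.save_state` is also collective under FSDP, so every rank must reach it.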
diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py
index 0b9b9f993094..c22c48ecaeb6 100644
--- a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py
+++ b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py
@@ -43,6 +43,7 @@
 import shutil
 from contextlib import nullcontext
 from pathlib import Path
+from typing import Any
 
 import numpy as np
 import torch
@@ -74,13 +75,16 @@
 from diffusers.pipelines.flux2.image_processor import Flux2ImageProcessor
 from diffusers.training_utils import (
     _collate_lora_metadata,
+    _to_cpu_contiguous,
     cast_training_params,
     compute_density_for_timestep_sampling,
     compute_loss_weighting_for_sd3,
     find_nearest_bucket,
     free_memory,
+    get_fsdp_kwargs_from_accelerator,
     offload_models,
     parse_buckets_string,
+    wrap_with_fsdp,
 )
 from diffusers.utils import (
     check_min_version,
@@ -93,6 +97,9 @@
 from diffusers.utils.torch_utils import is_compiled_module
 
 
+if getattr(torch, "distributed", None) is not None:
+    import torch.distributed as dist
+
 if is_wandb_available():
     import wandb
 
@@ -691,6 +698,7 @@ def parse_args(input_args=None):
     )
     parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
     parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enabla Flash Attention for NPU")
+    parser.add_argument("--fsdp_text_encoder", action="store_true", help="Use FSDP for text encoder")
 
     if input_args is not None:
         args = parser.parse_args(input_args)
@@ -1156,7 +1164,11 @@ def main(args):
         if args.bnb_quantization_config_path is not None
         else {"device": accelerator.device, "dtype": weight_dtype}
     )
-    transformer.to(**transformer_to_kwargs)
+
+    is_fsdp = accelerator.state.fsdp_plugin is not None
+    if not is_fsdp:
+        transformer.to(**transformer_to_kwargs)
+
     if args.do_fp8_training:
         convert_to_float8_training(
             transformer, module_filter_fn=module_filter_fn, config=Float8LinearConfig(pad_inner_dim=True)
@@ -1200,17 +1212,42 @@ def unwrap_model(model):
 
     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
    def save_model_hook(models, weights, output_dir):
+        transformer_cls = type(unwrap_model(transformer))
+
+        # 1) Validate and pick the transformer model
+        modules_to_save: dict[str, Any] = {}
+        transformer_model = None
+
+        for model in models:
+            if isinstance(unwrap_model(model), transformer_cls):
+                transformer_model = model
+                modules_to_save["transformer"] = model
+            else:
+                raise ValueError(f"unexpected save model: {model.__class__}")
+
+        if transformer_model is None:
+            raise ValueError("No transformer model found in 'models'")
+
+        # 2) Optionally gather the full FSDP state dict once (collective call, all ranks run it)
+        state_dict = accelerator.get_state_dict(transformer_model) if is_fsdp else None
+
+        # 3) Only the main process materializes the LoRA state dict
+        transformer_lora_layers_to_save = None
         if accelerator.is_main_process:
-            transformer_lora_layers_to_save = None
-            modules_to_save = {}
-            for model in models:
-                if isinstance(model, type(unwrap_model(transformer))):
-                    transformer_lora_layers_to_save = get_peft_model_state_dict(model)
-                    modules_to_save["transformer"] = model
-                else:
-                    raise ValueError(f"unexpected save model: {model.__class__}")
+            peft_kwargs = {}
+            if is_fsdp:
+                peft_kwargs["state_dict"] = state_dict
+
+            transformer_lora_layers_to_save = get_peft_model_state_dict(
+                unwrap_model(transformer_model) if is_fsdp else transformer_model,
+                **peft_kwargs,
+            )
 
-                # make sure to pop weight so that corresponding model is not saved again
+            if is_fsdp:
+                transformer_lora_layers_to_save = _to_cpu_contiguous(transformer_lora_layers_to_save)
+
+            # make sure to pop weight so that corresponding model is not saved again
+            if weights:
                 weights.pop()
 
         Flux2Pipeline.save_lora_weights(
@@ -1222,13 +1259,20 @@ def save_model_hook(models, weights, output_dir):
 
     def load_model_hook(models, input_dir):
         transformer_ = None
 
-        while len(models) > 0:
-            model = models.pop()
+        if not is_fsdp:
+            while len(models) > 0:
+                model = models.pop()
 
-            if isinstance(model, type(unwrap_model(transformer))):
-                transformer_ = model
-            else:
-                raise ValueError(f"unexpected save model: {model.__class__}")
+                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+                    transformer_ = unwrap_model(model)
+                else:
+                    raise ValueError(f"unexpected save model: {model.__class__}")
+        else:
+            transformer_ = Flux2Transformer2DModel.from_pretrained(
+                args.pretrained_model_name_or_path,
+                subfolder="transformer",
+            )
+            transformer_.add_adapter(transformer_lora_config)
 
         lora_state_dict = Flux2Pipeline.lora_state_dict(input_dir)
@@ -1430,6 +1474,21 @@ def _encode_single(prompt: str):
             args.validation_prompt, text_encoding_pipeline
         )
 
+    # Init FSDP for text encoder
+    if args.fsdp_text_encoder:
+        fsdp_kwargs = get_fsdp_kwargs_from_accelerator(accelerator)
+        text_encoder_fsdp = wrap_with_fsdp(
+            model=text_encoding_pipeline.text_encoder,
+            device=accelerator.device,
+            offload=args.offload,
+            limit_all_gathers=True,
+            use_orig_params=True,
+            fsdp_kwargs=fsdp_kwargs,
+        )
+
+        text_encoding_pipeline.text_encoder = text_encoder_fsdp
+        dist.barrier()
+
     # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
     # pack the statically computed variables appropriately here. This is so that we don't
     # have to pass them to the dataloader.
@@ -1461,6 +1520,8 @@ def _encode_single(prompt: str):
                 if train_dataset.custom_instance_prompts:
                     if args.remote_text_encoder:
                         prompt_embeds, text_ids = compute_remote_text_embeddings(batch["prompts"])
+                    elif args.fsdp_text_encoder:
+                        prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
                     else:
                         with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
                             prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
@@ -1700,7 +1761,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 progress_bar.update(1)
                 global_step += 1
 
-                if accelerator.is_main_process:
+                if accelerator.is_main_process or is_fsdp:
                     if global_step % args.checkpointing_steps == 0:
                         # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                         if args.checkpoints_total_limit is not None:
@@ -1759,15 +1820,41 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
     # Save the lora layers
     accelerator.wait_for_everyone()
+
+    if is_fsdp:
+        transformer = unwrap_model(transformer)
+        state_dict = accelerator.get_state_dict(transformer)
 
     if accelerator.is_main_process:
         modules_to_save = {}
-        transformer = unwrap_model(transformer)
-        if args.bnb_quantization_config_path is None:
-            if args.upcast_before_saving:
-                transformer.to(torch.float32)
-            else:
-                transformer = transformer.to(weight_dtype)
-        transformer_lora_layers = get_peft_model_state_dict(transformer)
+        if is_fsdp:
+            if args.bnb_quantization_config_path is None:
+                if args.upcast_before_saving:
+                    state_dict = {
+                        k: v.to(torch.float32) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
+                    }
+                else:
+                    state_dict = {
+                        k: v.to(weight_dtype) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
+                    }
+
+            transformer_lora_layers = get_peft_model_state_dict(
+                transformer,
+                state_dict=state_dict,
+            )
+            transformer_lora_layers = {
+                k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v
+                for k, v in transformer_lora_layers.items()
+            }
+
+        else:
+            transformer = unwrap_model(transformer)
+            if args.bnb_quantization_config_path is None:
+                if args.upcast_before_saving:
+                    transformer.to(torch.float32)
+                else:
+                    transformer = transformer.to(weight_dtype)
+            transformer_lora_layers = get_peft_model_state_dict(transformer)
+
         modules_to_save["transformer"] = transformer
 
         Flux2Pipeline.save_lora_weights(
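Both scripts persist the adapter through `Flux2Pipeline.save_lora_weights`, so the result loads through the usual LoRA path at inference time. A hedged sketch follows; the base checkpoint id, dtype, and output directory are assumptions rather than values fixed by this PR.

```python
# Hedged sketch: loading a LoRA trained by either script above. Assumes
# `black-forest-labs/FLUX.2-dev` was the base model and "output_dir" matches
# the --output_dir passed during training.
import torch

from diffusers import Flux2Pipeline

pipe = Flux2Pipeline.from_pretrained("black-forest-labs/FLUX.2-dev", torch_dtype=torch.bfloat16)
pipe.load_lora_weights("output_dir")  # picks up the pytorch_lora_weights.safetensors written by save_lora_weights
pipe.to("cuda")
```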
diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py
index 7a98fa3da14a..2d2f26b266a1 100644
--- a/src/diffusers/training_utils.py
+++ b/src/diffusers/training_utils.py
@@ -6,10 +6,18 @@
 import re
 import warnings
 from contextlib import contextmanager
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+from functools import partial
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union
 
 import numpy as np
 import torch
+from accelerate.logging import get_logger
+
+
+if getattr(torch, "distributed", None) is not None:
+    from torch.distributed.fsdp import CPUOffload, ShardingStrategy
+    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+    from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
 
 from .models import UNet2DConditionModel
 from .pipelines import DiffusionPipeline
@@ -394,6 +402,86 @@ def find_nearest_bucket(h, w, bucket_options):
     return best_bucket_idx
 
 
+def _to_cpu_contiguous(state_dicts) -> dict:
+    return {k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v for k, v in state_dicts.items()}
+
+
+def get_fsdp_kwargs_from_accelerator(accelerator) -> dict:
+    """
+    Extract and convert the FSDP config from an Accelerator into PyTorch FSDP kwargs.
+    """
+
+    kwargs = {}
+    fsdp_plugin = getattr(accelerator.state, "fsdp_plugin", None)
+
+    if fsdp_plugin is None:
+        raise ValueError(
+            "Accelerate is not configured for FSDP. Enable FSDP in your Accelerate config to use this helper."
+        )
+
+    # Use the plugin's sharding strategy, falling back to FULL_SHARD when it is unset.
+    kwargs["sharding_strategy"] = fsdp_plugin.sharding_strategy or ShardingStrategy.FULL_SHARD
+
+    return kwargs
+
+
+def wrap_with_fsdp(
+    model: torch.nn.Module,
+    device: Union[str, torch.device],
+    offload: bool = True,
+    use_orig_params: bool = True,
+    limit_all_gathers: bool = True,
+    fsdp_kwargs: Optional[Dict[str, Any]] = None,
+    transformer_layer_cls: Optional[Type[torch.nn.Module]] = None,
+) -> FSDP:
+    """
+    Wrap a model with FSDP using common defaults and optional transformer auto-wrapping.
+
+    Args:
+        model: Model to wrap
+        device: Target device (e.g., accelerator.device)
+        offload: Whether to enable CPU parameter offloading
+        use_orig_params: Whether to use original parameters
+        limit_all_gathers: Whether to limit all gathers
+        fsdp_kwargs: Extra FSDP arguments (sharding_strategy, etc.), usually from the Accelerate config
+        transformer_layer_cls: Transformer block class to auto-wrap; inferred from the model when not provided
+
+    Returns:
+        FSDP-wrapped model
+    """
+
+    logger = get_logger(__name__)
+
+    if transformer_layer_cls is None:
+        # Default to the class of the model's first decoder block; assumes the Mistral-style
+        # layout of the Flux.2 text encoder (model.model.language_model.layers).
+        transformer_layer_cls = type(model.model.language_model.layers[0])
+        logger.info(f"transformer_layer_cls is not provided, auto-inferred as {transformer_layer_cls.__name__}")
+
+    # Auto-wrap every transformer block of that class as its own FSDP unit
+    auto_wrap_policy = partial(
+        transformer_auto_wrap_policy,
+        transformer_layer_cls={transformer_layer_cls},
+    )
+
+    config = {
+        "device_id": device,
+        "cpu_offload": CPUOffload(offload_params=offload) if offload else None,
+        "use_orig_params": use_orig_params,
+        "limit_all_gathers": limit_all_gathers,
+        "auto_wrap_policy": auto_wrap_policy,
+    }
+
+    if fsdp_kwargs:
+        config.update(fsdp_kwargs)
+
+    fsdp_model = FSDP(model, **config)
+    return fsdp_model
+
+
 # Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
 class EMAModel:
     """
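To make the new helpers easier to review in isolation, here is a hedged, self-contained usage sketch. The toy model, dimensions, and the explicit `transformer_layer_cls` argument are illustrative assumptions; the training scripts instead rely on the auto-inferred Mistral block class. Run it under `accelerate launch` with FSDP enabled so that a process group exists and `get_fsdp_kwargs_from_accelerator` finds a plugin.

```python
# Hedged sketch of wrap_with_fsdp / get_fsdp_kwargs_from_accelerator on a toy model.
# Assumes `accelerate launch` with FSDP enabled in the Accelerate config.
import torch
import torch.nn as nn
from accelerate import Accelerator

from diffusers.training_utils import get_fsdp_kwargs_from_accelerator, wrap_with_fsdp


class Block(nn.Module):
    def __init__(self, dim: int = 256):
        super().__init__()
        self.ff = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))

    def forward(self, x):
        return x + self.ff(x)


class ToyEncoder(nn.Module):
    def __init__(self, depth: int = 4, dim: int = 256):
        super().__init__()
        self.layers = nn.ModuleList(Block(dim) for _ in range(depth))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


accelerator = Accelerator()
model = wrap_with_fsdp(
    model=ToyEncoder(),
    device=accelerator.device,
    offload=False,                                    # True adds CPUOffload(offload_params=True)
    fsdp_kwargs=get_fsdp_kwargs_from_accelerator(accelerator),
    transformer_layer_cls=Block,                      # each Block becomes its own FSDP unit
)
out = model(torch.randn(2, 16, 256, device=accelerator.device))
```

The scripts themselves call `wrap_with_fsdp` on `text_encoding_pipeline.text_encoder` without `transformer_layer_cls`, relying on the Mistral `model.language_model.layers` layout to infer the block class.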