
Commit b4be29b

leisuzz, sayakpaul, and github-actions[bot] authored
Add FSDP option for Flux2 (#12860)
* Add FSDP option for Flux2
* Apply style fixes
* Add FSDP option for Flux2
* Add FSDP option for Flux2
* Add FSDP option for Flux2
* Add FSDP option for Flux2
* Add FSDP option for Flux2
* Update examples/dreambooth/README_flux2.md
* guard accelerate import.

Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 98479a9 commit b4be29b

File tree: 4 files changed (+337, -49 lines changed)

examples/dreambooth/README_flux2.md

Lines changed: 23 additions & 0 deletions
@@ -98,6 +98,9 @@ Flux.2 uses Mistral Small 3.1 as text encoder which is quite large and can take
 This way, the text encoder model is not loaded into memory during training.
 > [!NOTE]
 > to enable remote text encoding you must either be logged in to your HuggingFace account (`hf auth login`) OR pass a token with `--hub_token`.
+### FSDP Text Encoder
+Flux.2 uses Mistral Small 3.1 as its text encoder, which is quite large and can take up a lot of memory. To mitigate this, we can use the `--fsdp_text_encoder` flag to shard the text encoder and compute the prompt embeddings in a distributed fashion.
+This way, the memory cost of the text encoder is distributed across multiple devices.
 ### CPU Offloading
 To offload parts of the model to CPU memory, you can use `--offload` flag. This will offload the vae and text encoder to CPU memory and only move them to GPU when needed.
 ### Latent Caching
@@ -166,6 +169,26 @@ To better track our training experiments, we're using the following flags in the
 > [!NOTE]
 > If you want to train using long prompts with the T5 text encoder, you can use `--max_sequence_length` to set the token limit. The default is 77, but it can be increased to as high as 512. Note that this will use more resources and may slow down the training in some cases.

+### FSDP on the transformer
+By configuring accelerate with FSDP, the transformer blocks will be wrapped automatically. E.g. set the configuration to:
+
+```shell
+distributed_type: FSDP
+fsdp_config:
+  fsdp_version: 2
+  fsdp_offload_params: false
+  fsdp_sharding_strategy: HYBRID_SHARD
+  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
+  fsdp_transformer_layer_cls_to_wrap: Flux2TransformerBlock, Flux2SingleTransformerBlock
+  fsdp_forward_prefetch: true
+  fsdp_sync_module_states: false
+  fsdp_state_dict_type: FULL_STATE_DICT
+  fsdp_use_orig_params: false
+  fsdp_activation_checkpointing: true
+  fsdp_reshard_after_forward: true
+  fsdp_cpu_ram_efficient_loading: false
+```
+
 ## LoRA + DreamBooth

 [LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a popular parameter-efficient fine-tuning technique that allows you to achieve full-finetuning like performance but with a fraction of learnable parameters.
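For orientation, the sketch below shows roughly what the accelerate FSDP config added to the README above does under the hood: it builds a transformer-based auto-wrap policy over the Flux2 block classes and shards them across ranks. This is not part of the commit; it is a minimal illustration that uses torch's FSDP1 wrapper API for brevity (the config itself requests `fsdp_version: 2`), resolves the block classes by name so no import path has to be assumed, and presumes a `torch.distributed` process group is already initialized.

```python
# Illustrative sketch only (not from this commit): an approximate hand-rolled
# equivalent of the TRANSFORMER_BASED_WRAP / HYBRID_SHARD settings above,
# expressed with the FSDP1 wrapper API.
import functools

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy


def shard_flux2_transformer(transformer: torch.nn.Module) -> FSDP:
    # Resolve the Flux2 block classes by name from the instantiated model so we
    # do not have to assume their exact import path inside diffusers.
    block_classes = {
        module.__class__
        for module in transformer.modules()
        if module.__class__.__name__ in {"Flux2TransformerBlock", "Flux2SingleTransformerBlock"}
    }
    auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy, transformer_layer_cls=block_classes
    )
    return FSDP(
        transformer,
        auto_wrap_policy=auto_wrap_policy,                # fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
        sharding_strategy=ShardingStrategy.HYBRID_SHARD,  # fsdp_sharding_strategy: HYBRID_SHARD
        forward_prefetch=True,                            # fsdp_forward_prefetch: true
        use_orig_params=False,                            # fsdp_use_orig_params: false
        sync_module_states=False,                         # fsdp_sync_module_states: false
        device_id=torch.cuda.current_device(),
    )
```

In the training script itself none of this is written by hand: accelerate reads the YAML and applies the wrapping when the model is prepared with `accelerator.prepare(...)`.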

examples/dreambooth/train_dreambooth_lora_flux2.py

Lines changed: 111 additions & 24 deletions
@@ -44,6 +44,7 @@
 import warnings
 from contextlib import nullcontext
 from pathlib import Path
+from typing import Any

 import numpy as np
 import torch
@@ -75,13 +76,16 @@
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import (
     _collate_lora_metadata,
+    _to_cpu_contiguous,
     cast_training_params,
     compute_density_for_timestep_sampling,
     compute_loss_weighting_for_sd3,
     find_nearest_bucket,
     free_memory,
+    get_fsdp_kwargs_from_accelerator,
     offload_models,
     parse_buckets_string,
+    wrap_with_fsdp,
 )
 from diffusers.utils import (
     check_min_version,
@@ -93,6 +97,9 @@
 from diffusers.utils.torch_utils import is_compiled_module


+if getattr(torch, "distributed", None) is not None:
+    import torch.distributed as dist
+
 if is_wandb_available():
     import wandb

@@ -722,6 +729,7 @@ def parse_args(input_args=None):
     )
     parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
     parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enabla Flash Attention for NPU")
+    parser.add_argument("--fsdp_text_encoder", action="store_true", help="Use FSDP for text encoder")

     if input_args is not None:
         args = parser.parse_args(input_args)
@@ -1219,7 +1227,11 @@ def main(args):
         if args.bnb_quantization_config_path is not None
         else {"device": accelerator.device, "dtype": weight_dtype}
     )
-    transformer.to(**transformer_to_kwargs)
+
+    is_fsdp = accelerator.state.fsdp_plugin is not None
+    if not is_fsdp:
+        transformer.to(**transformer_to_kwargs)
+
     if args.do_fp8_training:
         convert_to_float8_training(
             transformer, module_filter_fn=module_filter_fn, config=Float8LinearConfig(pad_inner_dim=True)
@@ -1263,17 +1275,42 @@ def unwrap_model(model):

     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
     def save_model_hook(models, weights, output_dir):
+        transformer_cls = type(unwrap_model(transformer))
+
+        # 1) Validate and pick the transformer model
+        modules_to_save: dict[str, Any] = {}
+        transformer_model = None
+
+        for model in models:
+            if isinstance(unwrap_model(model), transformer_cls):
+                transformer_model = model
+                modules_to_save["transformer"] = model
+            else:
+                raise ValueError(f"unexpected save model: {model.__class__}")
+
+        if transformer_model is None:
+            raise ValueError("No transformer model found in 'models'")
+
+        # 2) Optionally gather FSDP state dict once
+        state_dict = accelerator.get_state_dict(model) if is_fsdp else None
+
+        # 3) Only main process materializes the LoRA state dict
+        transformer_lora_layers_to_save = None
         if accelerator.is_main_process:
-            transformer_lora_layers_to_save = None
-            modules_to_save = {}
-            for model in models:
-                if isinstance(model, type(unwrap_model(transformer))):
-                    transformer_lora_layers_to_save = get_peft_model_state_dict(model)
-                    modules_to_save["transformer"] = model
-                else:
-                    raise ValueError(f"unexpected save model: {model.__class__}")
+            peft_kwargs = {}
+            if is_fsdp:
+                peft_kwargs["state_dict"] = state_dict
+
+            transformer_lora_layers_to_save = get_peft_model_state_dict(
+                unwrap_model(transformer_model) if is_fsdp else transformer_model,
+                **peft_kwargs,
+            )

-        # make sure to pop weight so that corresponding model is not saved again
+        if is_fsdp:
+            transformer_lora_layers_to_save = _to_cpu_contiguous(transformer_lora_layers_to_save)
+
+        # make sure to pop weight so that corresponding model is not saved again
+        if weights:
             weights.pop()

         Flux2Pipeline.save_lora_weights(
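The rewritten `save_model_hook` above boils down to one FSDP-specific pattern: every rank participates in gathering the sharded weights, and only the main process filters them down to the LoRA keys and moves them to CPU. A condensed sketch of that pattern follows; the helper name and signature are assumptions for illustration, not part of the diff.

```python
# Condensed sketch of the FSDP-safe LoRA gathering pattern used by the hook.
# `gather_lora_state_dict` is an assumed helper name for illustration only.
import torch
from accelerate import Accelerator
from peft.utils import get_peft_model_state_dict


def gather_lora_state_dict(accelerator: Accelerator, transformer) -> dict | None:
    # All ranks must call get_state_dict so the sharded parameters can be all-gathered.
    full_state_dict = accelerator.get_state_dict(transformer)

    if not accelerator.is_main_process:
        return None

    # Only rank 0 filters the gathered weights down to the LoRA adapter keys ...
    lora_state_dict = get_peft_model_state_dict(
        accelerator.unwrap_model(transformer), state_dict=full_state_dict
    )
    # ... and detaches them to contiguous CPU tensors before serialization.
    return {
        k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v
        for k, v in lora_state_dict.items()
    }
```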
@@ -1285,13 +1322,20 @@ def save_model_hook(models, weights, output_dir):
     def load_model_hook(models, input_dir):
         transformer_ = None

-        while len(models) > 0:
-            model = models.pop()
+        if not is_fsdp:
+            while len(models) > 0:
+                model = models.pop()

-            if isinstance(model, type(unwrap_model(transformer))):
-                transformer_ = model
-            else:
-                raise ValueError(f"unexpected save model: {model.__class__}")
+                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+                    transformer_ = unwrap_model(model)
+                else:
+                    raise ValueError(f"unexpected save model: {model.__class__}")
+        else:
+            transformer_ = Flux2Transformer2DModel.from_pretrained(
+                args.pretrained_model_name_or_path,
+                subfolder="transformer",
+            )
+            transformer_.add_adapter(transformer_lora_config)

         lora_state_dict = Flux2Pipeline.lora_state_dict(input_dir)

@@ -1507,6 +1551,21 @@ def _encode_single(prompt: str):
             args.validation_prompt, text_encoding_pipeline
         )

+    # Init FSDP for text encoder
+    if args.fsdp_text_encoder:
+        fsdp_kwargs = get_fsdp_kwargs_from_accelerator(accelerator)
+        text_encoder_fsdp = wrap_with_fsdp(
+            model=text_encoding_pipeline.text_encoder,
+            device=accelerator.device,
+            offload=args.offload,
+            limit_all_gathers=True,
+            use_orig_params=True,
+            fsdp_kwargs=fsdp_kwargs,
+        )
+
+        text_encoding_pipeline.text_encoder = text_encoder_fsdp
+        dist.barrier()
+
     # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
     # pack the statically computed variables appropriately here. This is so that we don't
     # have to pass them to the dataloader.
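`wrap_with_fsdp` and `get_fsdp_kwargs_from_accelerator` come from `diffusers.training_utils`; their internals are not shown in this diff. As a rough mental model, sharding a large text encoder with plain torch FSDP (v1 API) using the same knobs passed above might look like the sketch below. This is an assumption for illustration, not the helper's actual implementation; each rank then computes the prompt embeddings with the sharded encoder while holding only its own shard of the weights.

```python
# Rough sketch (assumed, not the diffusers helper): shard a large text encoder
# with torch FSDP, mirroring the device/offload/limit_all_gathers/use_orig_params
# arguments passed to wrap_with_fsdp above.
import torch
import torch.distributed as dist
from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel as FSDP


def shard_text_encoder(text_encoder: torch.nn.Module, device: torch.device, offload: bool) -> FSDP:
    assert dist.is_initialized(), "FSDP needs an initialized torch.distributed process group"
    sharded = FSDP(
        text_encoder,
        device_id=device,
        cpu_offload=CPUOffload(offload_params=offload),  # park shards in CPU RAM when requested
        limit_all_gathers=True,                          # throttle all-gathers to bound peak GPU memory
        use_orig_params=True,                            # keep original parameter objects addressable
    )
    # Keep ranks in lockstep before the first forward pass, as the script does.
    dist.barrier()
    return sharded
```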
@@ -1536,6 +1595,8 @@ def _encode_single(prompt: str):
                 if train_dataset.custom_instance_prompts:
                     if args.remote_text_encoder:
                         prompt_embeds, text_ids = compute_remote_text_embeddings(batch["prompts"])
+                    elif args.fsdp_text_encoder:
+                        prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
                     else:
                         with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
                             prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
@@ -1777,7 +1838,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 progress_bar.update(1)
                 global_step += 1

-                if accelerator.is_main_process:
+                if accelerator.is_main_process or is_fsdp:
                     if global_step % args.checkpointing_steps == 0:
                         # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                         if args.checkpoints_total_limit is not None:
@@ -1836,15 +1897,41 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):

     # Save the lora layers
     accelerator.wait_for_everyone()
+
+    if is_fsdp:
+        transformer = unwrap_model(transformer)
+        state_dict = accelerator.get_state_dict(transformer)
     if accelerator.is_main_process:
         modules_to_save = {}
-        transformer = unwrap_model(transformer)
-        if args.bnb_quantization_config_path is None:
-            if args.upcast_before_saving:
-                transformer.to(torch.float32)
-            else:
-                transformer = transformer.to(weight_dtype)
-        transformer_lora_layers = get_peft_model_state_dict(transformer)
+        if is_fsdp:
+            if args.bnb_quantization_config_path is None:
+                if args.upcast_before_saving:
+                    state_dict = {
+                        k: v.to(torch.float32) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
+                    }
+                else:
+                    state_dict = {
+                        k: v.to(weight_dtype) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
+                    }
+
+            transformer_lora_layers = get_peft_model_state_dict(
+                transformer,
+                state_dict=state_dict,
+            )
+            transformer_lora_layers = {
+                k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v
+                for k, v in transformer_lora_layers.items()
+            }
+
+        else:
+            transformer = unwrap_model(transformer)
+            if args.bnb_quantization_config_path is None:
+                if args.upcast_before_saving:
+                    transformer.to(torch.float32)
+                else:
+                    transformer = transformer.to(weight_dtype)
+            transformer_lora_layers = get_peft_model_state_dict(transformer)
+
         modules_to_save["transformer"] = transformer

         Flux2Pipeline.save_lora_weights(
