Add FSDP option for Flux2 #12860
base: main
Conversation
Force-pushed from 559a7a3 to 343b12a.
@sayakpaul Please take a look at this PR. Thank you for your help :)
sayakpaul left a comment:
Very cool work, thank you for this!
Just confirming -- this is FSDP2, right?
Also, could you provide an example command and your setup so that we can test?
Additionally, can we handle the denoiser similarly?
It is FSDP2, and the script is:
The accelerate config is:
sayakpaul left a comment:
Changes look neat to me!
Let's also update the README about this.
src/diffusers/training_utils.py
Outdated
import torch.distributed as dist
from torch.distributed.fsdp import CPUOffload, ShardingStrategy
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
We should guard this like so:
if getattr(torch, "distributed", None) is not None:
    import torch.distributed as dist
Same for FSDP.
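For reference, a minimal sketch of this guarded-import pattern (the extra `is_available()` check and the `dist = None` fallback are assumptions, not code from this PR):

```python
import torch

# Import distributed/FSDP symbols only when torch.distributed is actually usable,
# so importing training_utils does not fail on builds without distributed support.
if getattr(torch, "distributed", None) is not None and torch.distributed.is_available():
    import torch.distributed as dist
    from torch.distributed.fsdp import CPUOffload, ShardingStrategy
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
else:
    dist = None  # assumed fallback so callers can check `dist is not None` before use
```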
I've modified it, please take a look
Doesn't seem like the commits were pushed?
Please check it out
src/diffusers/training_utils.py
Outdated
if dist.is_initialized():
    dist.barrier()
Why is this needed?
I've modified it, please take a look
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
@bot /style
Style bot fixed some files and pushed the changes.
sayakpaul left a comment:
Just a few more comments which are mostly minor. As also mentioned earlier, let's make a note of this in the README_flux2.md.
import numpy as np
import torch
import torch.distributed as dist
This should be guarded as well.
if accelerator.is_main_process:
    transformer_lora_layers_to_save = None
    modules_to_save = {}
    transformer_lora_layers_to_save = None
Let's simplify this block of code a bit:
transformer_cls = type(unwrap_model(transformer))

def _to_cpu_contiguous(sd):
    return {
        k: (v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v)
        for k, v in sd.items()
    }

# 1) Validate and pick the transformer model
modules_to_save: dict[str, Any] = {}
transformer_model = None
for m in models:
    if isinstance(unwrap_model(m), transformer_cls):
        transformer_model = m
        modules_to_save["transformer"] = m
    else:
        raise ValueError(f"unexpected save model: {m.__class__}")

if transformer_model is None:
    raise ValueError("No transformer model found in `models`.")

# 2) Optionally gather the FSDP state dict once (collective call, so it runs on all ranks)
state_dict = accelerator.get_state_dict(transformer_model) if is_fsdp else None

# 3) Only the main process materializes the LoRA state dict
transformer_lora_layers_to_save = None
if accelerator.is_main_process:
    peft_kwargs = {}
    if is_fsdp:
        peft_kwargs["state_dict"] = state_dict
    transformer_lora_layers_to_save = get_peft_model_state_dict(
        unwrap_model(transformer_model) if is_fsdp else transformer_model,
        **peft_kwargs,
    )
    if is_fsdp:
        transformer_lora_layers_to_save = _to_cpu_contiguous(transformer_lora_layers_to_save)

# make sure to pop weight so that corresponding model is not saved again
if weights:
    weights.pop()

We can move `_to_cpu_contiguous()` to the training_utils.py module.
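If `_to_cpu_contiguous()` is moved into `training_utils.py` as suggested, a standalone version could look like the sketch below (placement and docstring are assumptions). Detaching to contiguous CPU tensors before saving matters because serializers such as safetensors reject non-contiguous tensors, and it releases the gathered FSDP copies from GPU memory.

```python
import torch


def _to_cpu_contiguous(state_dict: dict) -> dict:
    """Detach tensors, move them to CPU, and make them contiguous prior to serialization."""
    return {
        k: (v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v)
        for k, v in state_dict.items()
    }
```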
if accelerator.is_main_process:
    transformer_lora_layers_to_save = None
    modules_to_save = {}
    transformer_lora_layers_to_save = None
Same as above.
What does this PR do?
The text encoder in Flux2 is very large, and offloading it to CPU makes computing the prompt embeddings slow. This PR adds an FSDP option so the text encoder can be sharded across GPUs instead.
It is FSDP2, and the script is:
The accelerate config is:
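As a rough illustration of the FSDP2 approach only (the device mesh setup and attribute names such as `text_encoder.model.layers` are assumptions, not the PR's actual code), the large text encoder can be sharded block by block with `fully_shard`, so prompt embeddings are computed across GPUs rather than via CPU offload:

```python
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard  # public FSDP2 entry point in recent PyTorch


def shard_text_encoder(text_encoder: torch.nn.Module) -> torch.nn.Module:
    # One-dimensional mesh over all ranks (assumes the process group is already
    # initialized, e.g. by `accelerate launch`).
    mesh = init_device_mesh("cuda", (torch.distributed.get_world_size(),))

    # Shard each transformer block first, then the root module, so parameters are
    # gathered per block during forward instead of all at once.
    for block in text_encoder.model.layers:  # attribute path is an assumption
        fully_shard(block, mesh=mesh)
    fully_shard(text_encoder, mesh=mesh)
    return text_encoder
```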
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.