Fix FP8 torchao default config with padding and FSDP2 all-gather support#3831
SunMarc merged 11 commits into huggingface:main
Conversation
SunMarc
left a comment
Thanks a lot for the changes, really appreciate it! Left a few minor comments to make it better!
The configuration for the FP8 training. If `None`, a default config will be created with sensible defaults for most use cases:
- `pad_inner_dim=True`: Pads matrix dimensions to be divisible by 16, required for `torch._scaled_mm` operations to prevent runtime errors.
- `enable_fsdp_float8_all_gather=True`: Enables FP8 all-gather for FSDP2. This provides memory bandwidth savings by casting parameters before the all-gather operation, saving 50% bandwidth compared to BF16.

You can override these defaults by providing your own `Float8LinearConfig` instance.
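For reference, a minimal sketch of such an override, assuming torchao's `Float8LinearConfig` accepts the two keyword arguments named above (the variable name here is illustrative):

```python
# Sketch of overriding the default config described above; assumes
# torchao exposes Float8LinearConfig with these keyword arguments.
from torchao.float8 import Float8LinearConfig

custom_config = Float8LinearConfig(
    pad_inner_dim=True,                  # pad matmul dims to multiples of 16 for torch._scaled_mm
    enable_fsdp_float8_all_gather=True,  # cast params to FP8 before FSDP2 all-gather (~50% bandwidth vs BF16)
)
```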
Nice, maybe we can also allow users to easily change that with an env var and update the cluster.py file, which is responsible for the behavior of `accelerate config`? Here's a PR that should help with the changes: #2983
env_prefix = "ACCELERATE_FP8_"
enable_fsdp_float8_all_gather = os.environ.get(env_prefix + "ENABLE_FSDP_FLOAT_ALL_GATHER", True)
pad_inner_dim = os.environ.get(env_prefix + "PAD_INNER_DIM", True)
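One thing to watch: `os.environ.get` returns strings, so the values would need to be parsed into booleans before being used. A rough sketch (the `_str_to_bool` helper is illustrative, not the merged implementation):

```python
import os


def _str_to_bool(value):
    # Illustrative helper: env vars arrive as strings, so map common spellings to bool.
    if isinstance(value, bool):
        return value
    return str(value).strip().lower() in ("1", "true", "yes")


env_prefix = "ACCELERATE_FP8_"
enable_fsdp_float8_all_gather = _str_to_bool(os.environ.get(env_prefix + "ENABLE_FSDP_FLOAT_ALL_GATHER", True))
pad_inner_dim = _str_to_bool(os.environ.get(env_prefix + "PAD_INNER_DIM", True))
```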
Makes sense, thanks for the reference. Will add that.
Hi @SunMarc, let me know if there's anything else needed.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Force-pushed from d8f018f to 937b6ea
SunMarc
left a comment
That's really nice, thanks for fixing this!
@SunMarc Sorry, had to fix a style issue. Can you re-approve?
What does this PR do?
When no config is provided, the torchao FP8 backend now creates a default `Float8LinearConfig` with `enable_fsdp_float8_all_gather=True` and `pad_inner_dim=True`.
Fixes #3830
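As a usage sketch (the `AORecipeKwargs` handler name and import path are taken from accelerate's torchao FP8 utilities as I understand them; treat them as assumptions rather than a verbatim excerpt from this PR):

```python
from accelerate import Accelerator
from accelerate.utils import AORecipeKwargs

# With no explicit Float8LinearConfig, the new defaults (pad_inner_dim=True,
# enable_fsdp_float8_all_gather=True) are applied for torchao FP8 training.
accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=[AORecipeKwargs()])
```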
Before submitting
- This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
- Did you write any new necessary tests?
Ran `pytest tests/test_fp8.py -v` successfully.

Who can review?