Improve FSDP loading with size-based filtering #853
Conversation
Summary of Changes
Hello @Ohm-Rishabh, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request enhances the FSDP module sharding process by adding an optional size-based filtering capability. It lets users shard modules selectively, avoiding the overhead of sharding very small components that may not benefit from FSDP. The feature is opt-in and configured through environment variables, preserving backward compatibility while offering extra control for distributed training setups.
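For anyone trying the feature out, opting in is done through environment variables set before the model is loaded. Below is a minimal sketch, assuming FASTVIDEO_FSDP2_AUTOWRAP is the opt-in flag (as referenced in the Codex review further down) and FASTVIDEO_FSDP2_MIN_PARAMS is the parameter-count threshold read by the diff below; the values shown are only illustrative.

import os

# Opt in to size-based FSDP auto-wrapping (flag name taken from the review
# discussion in this thread; it must be set before the sharding code runs).
os.environ["FASTVIDEO_FSDP2_AUTOWRAP"] = "1"

# Raw parameter-count threshold: modules below it are skipped. The diff's
# default is "10000000" (10M parameters); 25M here is just an example.
os.environ["FASTVIDEO_FSDP2_MIN_PARAMS"] = str(25_000_000)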
Code Review
This pull request introduces a size-based filtering mechanism for FSDP auto-wrapping, which is a valuable addition for optimizing memory and performance during distributed training. The implementation is controlled via environment variables, providing good flexibility. My review focuses on improving code maintainability by refactoring a section with duplicated logic. I've also included a minor comment on log message formatting for consistency.
if use_size_filtering:
    # Size-based filtering mode
    min_params = int(os.environ.get("FASTVIDEO_FSDP2_MIN_PARAMS", "10000000"))
    logger.info("Using size-based filtering with threshold: %.2fM", min_params / 1e6)

    for n, m in reversed(list(model.named_modules())):
        if any([shard_condition(n, m) for shard_condition in fsdp_shard_conditions]):
            # Count all parameters
            param_count = sum(p.numel() for p in m.parameters(recurse=True))

            logger.info( "Inspecting module: name = %s, type = %s, param_count = %.2fM", n, type(m).__name__, param_count / 1e6 )

            # Skip small modules
            if param_count < min_params:
                logger.info("Skipping module %s (%.2fM params < %.2fM threshold)",
                            n, param_count / 1e6, min_params / 1e6)
                continue

            # Shard this module
            logger.info("Sharding module %s (%.2fM params)", n, param_count / 1e6)
            fully_shard(m, **fsdp_kwargs)
            num_layers_sharded += 1
else:
    # Original logic: shard all modules matching conditions
    logger.info("Using original logic: shard all modules matching conditions")

    for n, m in reversed(list(model.named_modules())):
        if any([shard_condition(n, m) for shard_condition in fsdp_shard_conditions]):
            fully_shard(m, **fsdp_kwargs)
            num_layers_sharded += 1

    if num_layers_sharded == 0:
        raise ValueError(
            "No layer modules were sharded. Please check if shard conditions are working as expected."
        )
The if/else block for use_size_filtering contains duplicated code, specifically the loop over model.named_modules() and the check for shard_condition. This can be refactored to improve maintainability by having a single loop and conditionally applying the size-based filtering logic inside it. This will make the code easier to read and modify in the future.
if use_size_filtering:
    # Size-based filtering mode
    min_params = int(os.environ.get("FASTVIDEO_FSDP2_MIN_PARAMS", "10000000"))
    logger.info("Using size-based filtering with threshold: %.2fM", min_params / 1e6)
else:
    # Original logic: shard all modules matching conditions
    logger.info("Using original logic: shard all modules matching conditions")

for n, m in reversed(list(model.named_modules())):
    if not any(shard_condition(n, m) for shard_condition in fsdp_shard_conditions):
        continue
    if use_size_filtering:
        # Count all parameters
        param_count = sum(p.numel() for p in m.parameters(recurse=True))
        logger.info("Inspecting module: name = %s, type = %s, param_count = %.2fM", n, type(m).__name__, param_count / 1e6)
        # Skip small modules
        if param_count < min_params:
            logger.info("Skipping module %s (%.2fM params < %.2fM threshold)",
                        n, param_count / 1e6, min_params / 1e6)
            continue
        # Shard this module
        logger.info("Sharding module %s (%.2fM params)", n, param_count / 1e6)
    fully_shard(m, **fsdp_kwargs)
    num_layers_sharded += 1

if not use_size_filtering and num_layers_sharded == 0:
    raise ValueError(
        "No layer modules were sharded. Please check if shard conditions are working as expected."
    )

On these lines of the diff:

            # Count all parameters
            param_count = sum(p.numel() for p in m.parameters(recurse=True))

            logger.info( "Inspecting module: name = %s, type = %s, param_count = %.2fM", n, type(m).__name__, param_count / 1e6 )
There is a leading space in the log message format string, and a trailing space before the closing parenthesis. While this is a minor issue, maintaining consistent formatting in log messages is important for readability and to prevent issues with automated log parsing tools. Please remove the extra spaces.
logger.info("Inspecting module: name = %s, type = %s, param_count = %.2fM", n, type(m).__name__, param_count / 1e6)There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
(Quoting the size-based filtering branch of the diff shown above.)
Restore zero-shard guard when size filtering enabled
When FASTVIDEO_FSDP2_AUTOWRAP=1 all modules are filtered by size, but the new branch no longer raises an error if nothing was wrapped. num_layers_sharded can remain zero when the threshold is too high or the shard conditions never fire, yet the function still calls fully_shard(model) on the entire network. The previous implementation surfaced this misconfiguration with a ValueError; now it silently turns the whole model into one giant FSDP shard, which can cause unexpected OOMs and defeats the purpose of the shard conditions. Consider keeping the num_layers_sharded == 0 check in both code paths so users are warned when no layers match.
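A minimal sketch of that suggestion, hoisting the existing check so it runs after both code paths (variable names as in the diff above):

# Sketch only: run the zero-shard check regardless of whether size filtering
# was enabled, so a too-high FASTVIDEO_FSDP2_MIN_PARAMS or non-matching shard
# conditions fail loudly instead of silently leaving the whole model to be
# wrapped as a single FSDP unit.
if num_layers_sharded == 0:
    raise ValueError(
        "No layer modules were sharded. Please check if shard conditions are working as expected."
    )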
Could you share some loading and training speed comparisons?
Also I think you need
https://github.com/pytorch/pytorch/blob/c7eee495259a5ce2f2f5e8830bcec3b6eca84b31/test/distributed/_composable/fsdp/test_fully_shard_ignore_params.py#L305
to mix FSDP and non-FSDP params, which may add overhead.
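To make that concrete, here is a rough sketch rather than the PR's actual code: it assumes a PyTorch build whose fully_shard accepts an ignored_params set (the feature exercised by the linked test) and reuses the variable names from the diff.

import torch

# Collect parameters of modules that match the shard conditions but fall
# below the size threshold, so the root wrap can leave them out of FSDP
# instead of folding them into the root shard.
ignored_params: set[torch.nn.Parameter] = set()
for n, m in model.named_modules():
    if any(cond(n, m) for cond in fsdp_shard_conditions):
        if sum(p.numel() for p in m.parameters(recurse=True)) < min_params:
            ignored_params.update(m.parameters(recurse=True))

# Root wrap: the ignored (small-module) parameters stay unsharded per rank,
# which is the FSDP / non-FSDP parameter mix that may add some overhead.
fully_shard(model, ignored_params=ignored_params, **fsdp_kwargs)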
On these lines of the diff:

            # Count all parameters
            param_count = sum(p.numel() for p in m.parameters(recurse=True))

            logger.info( "Inspecting module: name = %s, type = %s, param_count = %.2fM", n, type(m).__name__, param_count / 1e6 )
no need to print verbose info
On these lines of the diff:

            num_layers_sharded += 1
else:
    # Original logic: shard all modules matching conditions
    logger.info("Using original logic: shard all modules matching conditions")
no need to print this?