Improve FSDP loading with size-based filtering #853
This PR adds an opt-in, size-based filtering mode to `shard_model`: when `FASTVIDEO_FSDP2_AUTOWRAP=1` is set, only modules whose parameter count meets a threshold (`FASTVIDEO_FSDP2_MIN_PARAMS`, default 10M) are wrapped with `fully_shard`; otherwise the original condition-based sharding is used.

```diff
@@ -4,6 +4,8 @@
 # Copyright 2024 The TorchTune Authors.
 # Copyright 2025 The FastVideo Authors.
 
+from __future__ import annotations
+import os
 import contextlib
 from collections.abc import Callable, Generator
 from itertools import chain
```
```diff
@@ -197,10 +199,11 @@ def shard_model(
     Raises:
         ValueError: If no layer modules were sharded, indicating that no shard_condition was triggered.
     """
-    if fsdp_shard_conditions is None or len(fsdp_shard_conditions) == 0:
-        logger.warning(
-            "The FSDP shard condition list is empty or None. No modules will be sharded in %s",
-            type(model).__name__)
+    # Check if we should use size-based filtering
+    use_size_filtering = os.environ.get("FASTVIDEO_FSDP2_AUTOWRAP", "0") == "1"
+
+    if not fsdp_shard_conditions:
+        logger.warning("No FSDP shard conditions provided; nothing will be sharded.")
         return
 
     fsdp_kwargs = {
```
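For reference, a minimal sketch (not part of the PR) of how a training script might opt in; the environment variable names and the 10M default come from the diff, everything else is illustrative:

```python
import os

# Enable size-based filtering before shard_model() is called.
os.environ["FASTVIDEO_FSDP2_AUTOWRAP"] = "1"
# Optionally raise the threshold from the 10M-parameter default to 50M.
os.environ["FASTVIDEO_FSDP2_MIN_PARAMS"] = "50000000"
```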
```diff
@@ -215,20 +218,42 @@ def shard_model(
     # iterating in reverse to start with
     # lowest-level modules first
     num_layers_sharded = 0
-    # TODO(will): don't reshard after forward for the last layer to save on the
-    # all-gather that will immediately happen Shard the model with FSDP,
-    for n, m in reversed(list(model.named_modules())):
-        if any([
-                shard_condition(n, m)
-                for shard_condition in fsdp_shard_conditions
-        ]):
-            fully_shard(m, **fsdp_kwargs)
-            num_layers_sharded += 1
-
-    if num_layers_sharded == 0:
-        raise ValueError(
-            "No layer modules were sharded. Please check if shard conditions are working as expected."
-        )
+
+    if use_size_filtering:
+        # Size-based filtering mode
+        min_params = int(os.environ.get("FASTVIDEO_FSDP2_MIN_PARAMS", "10000000"))
+        logger.info("Using size-based filtering with threshold: %.2fM", min_params / 1e6)
+
+        for n, m in reversed(list(model.named_modules())):
+            if any([shard_condition(n, m) for shard_condition in fsdp_shard_conditions]):
+                # Count all parameters
+                param_count = sum(p.numel() for p in m.parameters(recurse=True))
+
+                logger.info( "Inspecting module: name = %s, type = %s, param_count = %.2fM", n, type(m).__name__, param_count / 1e6 )
```
**Review comment** (on the `Inspecting module` log line): no need to print verbose info
```diff
+
+                # Skip small modules
+                if param_count < min_params:
+                    logger.info("Skipping module %s (%.2fM params < %.2fM threshold)",
+                                n, param_count / 1e6, min_params / 1e6)
+                    continue
+
+                # Shard this module
+                logger.info("Sharding module %s (%.2fM params)", n, param_count / 1e6)
+                fully_shard(m, **fsdp_kwargs)
+                num_layers_sharded += 1
```
```diff
+    else:
+        # Original logic: shard all modules matching conditions
+        logger.info("Using original logic: shard all modules matching conditions")
```
**Review comment** (on the `Using original logic` log line): no need to print this?
```diff
+
+        for n, m in reversed(list(model.named_modules())):
+            if any([shard_condition(n, m) for shard_condition in fsdp_shard_conditions]):
+                fully_shard(m, **fsdp_kwargs)
+                num_layers_sharded += 1
+
+        if num_layers_sharded == 0:
+            raise ValueError(
+                "No layer modules were sharded. Please check if shard conditions are working as expected."
+            )
```
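To make the 10M-parameter threshold concrete, here is a small standalone sketch (not part of the PR) of the same counting logic; the module shapes are illustrative:

```python
import torch.nn as nn

# A typical transformer block at this width clears the default threshold...
layer = nn.TransformerEncoderLayer(d_model=1024, nhead=16, dim_feedforward=4096)
count = sum(p.numel() for p in layer.parameters(recurse=True))
print(f"{count / 1e6:.2f}M")  # ~12.60M -> above 10M, would be wrapped

# ...while a bare norm layer falls below it and is skipped.
norm = nn.LayerNorm(1024)
count = sum(p.numel() for p in norm.parameters(recurse=True))
print(f"{count / 1e6:.2f}M")  # ~0.00M (2K params) -> left unwrapped
```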
|
**Review comment** (on lines +222 to +256): The `if use_size_filtering:` and `else:` branches duplicate the `named_modules()` iteration; the two paths could be folded into a single loop, e.g.:

```python
if use_size_filtering:
    # Size-based filtering mode
    min_params = int(os.environ.get("FASTVIDEO_FSDP2_MIN_PARAMS", "10000000"))
    logger.info("Using size-based filtering with threshold: %.2fM", min_params / 1e6)
else:
    # Original logic: shard all modules matching conditions
    logger.info("Using original logic: shard all modules matching conditions")

for n, m in reversed(list(model.named_modules())):
    if not any(shard_condition(n, m) for shard_condition in fsdp_shard_conditions):
        continue
    if use_size_filtering:
        # Count all parameters
        param_count = sum(p.numel() for p in m.parameters(recurse=True))
        logger.info("Inspecting module: name = %s, type = %s, param_count = %.2fM",
                    n, type(m).__name__, param_count / 1e6)
        # Skip small modules
        if param_count < min_params:
            logger.info("Skipping module %s (%.2fM params < %.2fM threshold)",
                        n, param_count / 1e6, min_params / 1e6)
            continue
        # Shard this module
        logger.info("Sharding module %s (%.2fM params)", n, param_count / 1e6)
    fully_shard(m, **fsdp_kwargs)
    num_layers_sharded += 1

if not use_size_filtering and num_layers_sharded == 0:
    raise ValueError(
        "No layer modules were sharded. Please check if shard conditions are working as expected."
    )
```
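One detail of this suggestion worth noting (my reading, not stated in the thread): the `ValueError` is raised only when `not use_size_filtering`. In filtering mode the threshold can legitimately skip every matching module, so zero sharded layers is not treated as an error there.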
```diff
 
     # Finally shard the entire model to account for any stragglers
     fully_shard(model, **fsdp_kwargs)
```
**Review comment** (on the `Inspecting module` log line): There is a leading space in the log message format string, and a trailing space before the closing parenthesis. While this is a minor issue, maintaining consistent formatting in log messages is important for readability and to prevent issues with automated log parsing tools. Please remove the extra spaces.
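A sketch of that call with the extra spaces removed (wrapped here only for line length):

```python
logger.info("Inspecting module: name = %s, type = %s, param_count = %.2fM",
            n, type(m).__name__, param_count / 1e6)
```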