
Conversation

@Ohm-Rishabh
Contributor

No description provided.

@gemini-code-assist
Contributor

Summary of Changes

Hello @Ohm-Rishabh, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request adds optional size-based filtering to the FSDP module sharding process. Users can skip sharding very small modules that are unlikely to benefit from FSDP, avoiding the wrapping overhead those modules would otherwise incur. The feature is opt-in and configured through environment variables, preserving backward compatibility while giving finer control over distributed training setups.

Highlights

  • Size-based FSDP Sharding: Introduced an optional mechanism to filter modules for FSDP sharding based on their parameter count, allowing for more granular control over the sharding process.
  • Environment Variable Configuration: The new size-based filtering is controlled by environment variables: FASTVIDEO_FSDP2_AUTOWRAP to enable it, and FASTVIDEO_FSDP2_MIN_PARAMS to set the minimum parameter threshold for sharding.
  • Enhanced Logging: Added detailed logging within the shard_model function to provide visibility into the size-based filtering process, including module inspection, parameter counts, and sharding decisions.
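
For illustration, opting in to the new filter might look like the sketch below. The environment variable names are taken from this summary; treating "1" as the enabling value is an assumption, since the exact parsing of FASTVIDEO_FSDP2_AUTOWRAP is not shown in this excerpt.

    import os

    # Opt in to size-based FSDP sharding (assumption: a value such as "1" enables the flag).
    os.environ["FASTVIDEO_FSDP2_AUTOWRAP"] = "1"

    # Raise the minimum parameter count for sharding from the 10M default to 50M.
    os.environ["FASTVIDEO_FSDP2_MIN_PARAMS"] = str(50_000_000)

    # Build the model and call shard_model(...) as usual afterwards; modules that match the
    # shard conditions but hold fewer than 50M parameters will then be skipped.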

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a size-based filtering mechanism for FSDP auto-wrapping, which is a valuable addition for optimizing memory and performance during distributed training. The implementation is controlled via environment variables, providing good flexibility. My review focuses on improving code maintainability by refactoring a section with duplicated logic. I've also included a minor comment on log message formatting for consistency.

Comment on lines +222 to +256
    if use_size_filtering:
        # Size-based filtering mode
        min_params = int(os.environ.get("FASTVIDEO_FSDP2_MIN_PARAMS", "10000000"))
        logger.info("Using size-based filtering with threshold: %.2fM", min_params / 1e6)

        for n, m in reversed(list(model.named_modules())):
            if any([shard_condition(n, m) for shard_condition in fsdp_shard_conditions]):
                # Count all parameters
                param_count = sum(p.numel() for p in m.parameters(recurse=True))

                logger.info( "Inspecting module: name = %s, type = %s, param_count = %.2fM", n, type(m).__name__, param_count / 1e6 )

                # Skip small modules
                if param_count < min_params:
                    logger.info("Skipping module %s (%.2fM params < %.2fM threshold)",
                                n, param_count / 1e6, min_params / 1e6)
                    continue

                # Shard this module
                logger.info("Sharding module %s (%.2fM params)", n, param_count / 1e6)
                fully_shard(m, **fsdp_kwargs)
                num_layers_sharded += 1
    else:
        # Original logic: shard all modules matching conditions
        logger.info("Using original logic: shard all modules matching conditions")

        for n, m in reversed(list(model.named_modules())):
            if any([shard_condition(n, m) for shard_condition in fsdp_shard_conditions]):
                fully_shard(m, **fsdp_kwargs)
                num_layers_sharded += 1

        if num_layers_sharded == 0:
            raise ValueError(
                "No layer modules were sharded. Please check if shard conditions are working as expected."
            )
Contributor


medium

The if/else block for use_size_filtering contains duplicated code, specifically the loop over model.named_modules() and the check for shard_condition. This can be refactored to improve maintainability by having a single loop and conditionally applying the size-based filtering logic inside it. This will make the code easier to read and modify in the future.

    if use_size_filtering:
        # Size-based filtering mode
        min_params = int(os.environ.get("FASTVIDEO_FSDP2_MIN_PARAMS", "10000000"))
        logger.info("Using size-based filtering with threshold: %.2fM", min_params / 1e6)
    else:
        # Original logic: shard all modules matching conditions
        logger.info("Using original logic: shard all modules matching conditions")

    for n, m in reversed(list(model.named_modules())):
        if not any(shard_condition(n, m) for shard_condition in fsdp_shard_conditions):
            continue

        if use_size_filtering:
            # Count all parameters
            param_count = sum(p.numel() for p in m.parameters(recurse=True))
            logger.info("Inspecting module: name = %s, type = %s, param_count = %.2fM", n, type(m).__name__, param_count / 1e6)

            # Skip small modules
            if param_count < min_params:
                logger.info("Skipping module %s (%.2fM params < %.2fM threshold)",
                            n, param_count / 1e6, min_params / 1e6)
                continue

            # Shard this module
            logger.info("Sharding module %s (%.2fM params)", n, param_count / 1e6)

        fully_shard(m, **fsdp_kwargs)
        num_layers_sharded += 1

    if not use_size_filtering and num_layers_sharded == 0:
        raise ValueError(
            "No layer modules were sharded. Please check if shard conditions are working as expected."
        )

                # Count all parameters
                param_count = sum(p.numel() for p in m.parameters(recurse=True))

                logger.info( "Inspecting module: name = %s, type = %s, param_count = %.2fM", n, type(m).__name__, param_count / 1e6 )
Contributor


medium

There is a leading space in the log message format string, and a trailing space before the closing parenthesis. While this is a minor issue, maintaining consistent formatting in log messages is important for readability and to prevent issues with automated log parsing tools. Please remove the extra spaces.

                logger.info("Inspecting module: name = %s, type = %s, param_count = %.2fM", n, type(m).__name__, param_count / 1e6)


@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.


Comment on lines +222 to +243
    if use_size_filtering:
        # Size-based filtering mode
        min_params = int(os.environ.get("FASTVIDEO_FSDP2_MIN_PARAMS", "10000000"))
        logger.info("Using size-based filtering with threshold: %.2fM", min_params / 1e6)

        for n, m in reversed(list(model.named_modules())):
            if any([shard_condition(n, m) for shard_condition in fsdp_shard_conditions]):
                # Count all parameters
                param_count = sum(p.numel() for p in m.parameters(recurse=True))

                logger.info( "Inspecting module: name = %s, type = %s, param_count = %.2fM", n, type(m).__name__, param_count / 1e6 )

                # Skip small modules
                if param_count < min_params:
                    logger.info("Skipping module %s (%.2fM params < %.2fM threshold)",
                                n, param_count / 1e6, min_params / 1e6)
                    continue

                # Shard this module
                logger.info("Sharding module %s (%.2fM params)", n, param_count / 1e6)
                fully_shard(m, **fsdp_kwargs)
                num_layers_sharded += 1


P1: Restore zero-shard guard when size filtering enabled

When FASTVIDEO_FSDP2_AUTOWRAP=1 all modules are filtered by size, but the new branch no longer raises an error if nothing was wrapped. num_layers_sharded can remain zero when the threshold is too high or the shard conditions never fire, yet the function still calls fully_shard(model) on the entire network. The previous implementation surfaced this misconfiguration with a ValueError; now it silently turns the whole model into one giant FSDP shard, which can cause unexpected OOMs and defeats the purpose of the shard conditions. Consider keeping the num_layers_sharded == 0 check in both code paths so users are warned when no layers match.

Useful? React with 👍 / 👎.
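
A minimal sketch of the guard described above, reusing the names from the diff (this is not code from the PR, and the more detailed message in the filtering path is only a suggestion):

    # Keep the zero-shard guard in both code paths (sketch, not PR code).
    if num_layers_sharded == 0:
        if use_size_filtering:
            raise ValueError(
                "No layer modules were sharded: either no module matched the shard "
                "conditions or every match fell below FASTVIDEO_FSDP2_MIN_PARAMS. "
                "Lower the threshold or check the shard conditions.")
        raise ValueError(
            "No layer modules were sharded. Please check if shard conditions are "
            "working as expected.")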

Collaborator

@Edenzzzz Edenzzzz left a comment


Could you show the loading and training speed diff?
Also, I think you need
https://github.com/pytorch/pytorch/blob/c7eee495259a5ce2f2f5e8830bcec3b6eca84b31/test/distributed/_composable/fsdp/test_fully_shard_ignore_params.py#L305
to mix FSDP and non-FSDP params, which may add overhead.
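
A rough sketch of what that mixing might look like, following the linked test. The ignored_params keyword and the import path are assumptions inferred from that test; verify them against your PyTorch version before relying on this.

    import torch.nn as nn
    from torch.distributed.fsdp import fully_shard  # import location varies across PyTorch versions

    def collect_small_params(model: nn.Module, min_params: int) -> set[nn.Parameter]:
        # Illustrative helper (not in the PR): gather parameters owned by modules whose
        # total parameter count falls below the sharding threshold.
        ignored: set[nn.Parameter] = set()
        for _, module in model.named_modules():
            total = sum(p.numel() for p in module.parameters(recurse=True))
            if 0 < total < min_params:
                ignored.update(module.parameters(recurse=False))
        return ignored

    # Assumption: `ignored_params` is the keyword exercised by the linked test; passing the
    # small modules' parameters there keeps them out of FSDP management while the rest of
    # the model is sharded, i.e. the FSDP / non-FSDP mix mentioned above.
    fully_shard(model, ignored_params=collect_small_params(model, min_params), **fsdp_kwargs)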

                # Count all parameters
                param_count = sum(p.numel() for p in m.parameters(recurse=True))

                logger.info( "Inspecting module: name = %s, type = %s, param_count = %.2fM", n, type(m).__name__, param_count / 1e6 )
Collaborator

@Edenzzzz Edenzzzz Oct 26, 2025


no need to print verbose info

                num_layers_sharded += 1
    else:
        # Original logic: shard all modules matching conditions
        logger.info("Using original logic: shard all modules matching conditions")
Collaborator

@Edenzzzz Edenzzzz Oct 26, 2025


no need to print this?
