61 changes: 43 additions & 18 deletions fastvideo/models/loader/fsdp_load.py
@@ -4,6 +4,8 @@
# Copyright 2024 The TorchTune Authors.
# Copyright 2025 The FastVideo Authors.

from __future__ import annotations
import os
import contextlib
from collections.abc import Callable, Generator
from itertools import chain
@@ -197,10 +199,11 @@ def shard_model(
    Raises:
        ValueError: If no layer modules were sharded, indicating that no shard_condition was triggered.
    """
    if fsdp_shard_conditions is None or len(fsdp_shard_conditions) == 0:
        logger.warning(
            "The FSDP shard condition list is empty or None. No modules will be sharded in %s",
            type(model).__name__)
    # Check if we should use size-based filtering
    use_size_filtering = os.environ.get("FASTVIDEO_FSDP2_AUTOWRAP", "0") == "1"

    if not fsdp_shard_conditions:
        logger.warning("No FSDP shard conditions provided; nothing will be sharded.")
        return

    fsdp_kwargs = {
@@ -215,20 +218,42 @@
    # Shard the model with FSDP, iterating in reverse to start with
    # lowest-level modules first
    num_layers_sharded = 0
    # TODO(will): don't reshard after forward for the last layer to save on the
    # all-gather that will immediately happen
    for n, m in reversed(list(model.named_modules())):
        if any([
                shard_condition(n, m)
                for shard_condition in fsdp_shard_conditions
        ]):
            fully_shard(m, **fsdp_kwargs)
            num_layers_sharded += 1

    if num_layers_sharded == 0:
        raise ValueError(
            "No layer modules were sharded. Please check if shard conditions are working as expected."
        )

    if use_size_filtering:
        # Size-based filtering mode
        min_params = int(os.environ.get("FASTVIDEO_FSDP2_MIN_PARAMS", "10000000"))
        logger.info("Using size-based filtering with threshold: %.2fM", min_params / 1e6)

        for n, m in reversed(list(model.named_modules())):
            if any([shard_condition(n, m) for shard_condition in fsdp_shard_conditions]):
                # Count all parameters
                param_count = sum(p.numel() for p in m.parameters(recurse=True))

                logger.info( "Inspecting module: name = %s, type = %s, param_count = %.2fM", n, type(m).__name__, param_count / 1e6 )
Contributor review comment (medium):

There is a leading space in the log message format string, and a trailing space before the closing parenthesis. While this is a minor issue, maintaining consistent formatting in log messages is important for readability and to prevent issues with automated log parsing tools. Please remove the extra spaces.

Suggested change:

                logger.info("Inspecting module: name = %s, type = %s, param_count = %.2fM", n, type(m).__name__, param_count / 1e6)

Collaborator @Edenzzzz, Oct 26, 2025:

no need to print verbose info


                # Skip small modules
                if param_count < min_params:
                    logger.info("Skipping module %s (%.2fM params < %.2fM threshold)",
                                n, param_count / 1e6, min_params / 1e6)
                    continue

                # Shard this module
                logger.info("Sharding module %s (%.2fM params)", n, param_count / 1e6)
                fully_shard(m, **fsdp_kwargs)
                num_layers_sharded += 1
Comment on lines +222 to +243

Review comment, P1: Restore zero-shard guard when size filtering enabled

When FASTVIDEO_FSDP2_AUTOWRAP=1 all modules are filtered by size, but the new branch no longer raises an error if nothing was wrapped. num_layers_sharded can remain zero when the threshold is too high or the shard conditions never fire, yet the function still calls fully_shard(model) on the entire network. The previous implementation surfaced this misconfiguration with a ValueError; now it silently turns the whole model into one giant FSDP shard, which can cause unexpected OOMs and defeats the purpose of the shard conditions. Consider keeping the num_layers_sharded == 0 check in both code paths so users are warned when no layers match.
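One way to carry out that suggestion (a minimal sketch only, not the PR's actual fix; it assumes the surrounding shard_model body from this diff, where use_size_filtering, num_layers_sharded and min_params are in scope) is to hoist the guard out of the else branch so it runs after either path:

    # Sketch: apply the zero-shard guard to both code paths.
    if num_layers_sharded == 0:
        if use_size_filtering:
            raise ValueError(
                "No layer modules were sharded under size-based filtering. "
                "Lower FASTVIDEO_FSDP2_MIN_PARAMS or check the shard conditions.")
        raise ValueError(
            "No layer modules were sharded. Please check if shard conditions are "
            "working as expected.")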

    else:
        # Original logic: shard all modules matching conditions
        logger.info("Using original logic: shard all modules matching conditions")
Collaborator @Edenzzzz, Oct 26, 2025:

no need to print this?


        for n, m in reversed(list(model.named_modules())):
            if any([shard_condition(n, m) for shard_condition in fsdp_shard_conditions]):
                fully_shard(m, **fsdp_kwargs)
                num_layers_sharded += 1

        if num_layers_sharded == 0:
            raise ValueError(
                "No layer modules were sharded. Please check if shard conditions are working as expected."
            )
Comment on lines +222 to +256

Contributor review comment (medium):

The if/else block for use_size_filtering contains duplicated code, specifically the loop over model.named_modules() and the check for shard_condition. This can be refactored to improve maintainability by having a single loop and conditionally applying the size-based filtering logic inside it. This will make the code easier to read and modify in the future.

Suggested refactor:

    if use_size_filtering:
        # Size-based filtering mode
        min_params = int(os.environ.get("FASTVIDEO_FSDP2_MIN_PARAMS", "10000000"))
        logger.info("Using size-based filtering with threshold: %.2fM", min_params / 1e6)
    else:
        # Original logic: shard all modules matching conditions
        logger.info("Using original logic: shard all modules matching conditions")

    for n, m in reversed(list(model.named_modules())):
        if not any(shard_condition(n, m) for shard_condition in fsdp_shard_conditions):
            continue

        if use_size_filtering:
            # Count all parameters
            param_count = sum(p.numel() for p in m.parameters(recurse=True))
            logger.info("Inspecting module: name = %s, type = %s, param_count = %.2fM", n, type(m).__name__, param_count / 1e6)

            # Skip small modules
            if param_count < min_params:
                logger.info("Skipping module %s (%.2fM params < %.2fM threshold)",
                            n, param_count / 1e6, min_params / 1e6)
                continue

            # Shard this module
            logger.info("Sharding module %s (%.2fM params)", n, param_count / 1e6)

        fully_shard(m, **fsdp_kwargs)
        num_layers_sharded += 1

    if not use_size_filtering and num_layers_sharded == 0:
        raise ValueError(
            "No layer modules were sharded. Please check if shard conditions are working as expected."
        )


    # Finally shard the entire model to account for any stragglers
    fully_shard(model, **fsdp_kwargs)
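For reference, the new code path is opt-in via environment variables. A minimal usage sketch follows; it assumes shard_model is imported from fastvideo/models/loader/fsdp_load.py, uses a stand-in torch.nn model and an isinstance-based shard condition purely for illustration, and elides any further shard_model arguments (cpu offload, device mesh, dtypes) that the real call sites pass:

    import os

    import torch.nn as nn

    # Opt in to size-based filtering; unset or "0" keeps the original wrap-everything behavior.
    os.environ["FASTVIDEO_FSDP2_AUTOWRAP"] = "1"
    # Only wrap matched sub-modules with at least 50M parameters (the default threshold is 10M).
    os.environ["FASTVIDEO_FSDP2_MIN_PARAMS"] = str(50_000_000)

    from fastvideo.models.loader.fsdp_load import shard_model

    # Stand-in model; in FastVideo this would be the video transformer being loaded.
    model = nn.TransformerEncoder(
        nn.TransformerEncoderLayer(d_model=1024, nhead=16), num_layers=4)

    shard_model(
        model,
        fsdp_shard_conditions=[
            # A shard condition is a callable taking (module_name, module) and returning bool.
            lambda name, module: isinstance(module, nn.TransformerEncoderLayer),
        ],
        # ...plus whatever other arguments shard_model requires (cpu offload,
        # device mesh, mixed-precision settings) at the existing call sites.
    )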