From 6094ae49be2e0b770ba68ba1e01aacf5bd954c99 Mon Sep 17 00:00:00 2001
From: Ohm-Rishabh
Date: Thu, 23 Oct 2025 21:57:10 +0000
Subject: [PATCH] Improve FSDP loading with size-based filtering

Add an opt-in size-based filtering mode to shard_model(). When
FASTVIDEO_FSDP2_AUTOWRAP=1, only modules whose parameter count reaches
the FASTVIDEO_FSDP2_MIN_PARAMS threshold (default 10M) are wrapped with
fully_shard(); smaller matching modules are skipped and logged. Without
the flag, the original behavior of sharding every module that matches a
shard condition is kept.
---
 fastvideo/models/loader/fsdp_load.py | 61 ++++++++++++++++++++--------
 1 file changed, 43 insertions(+), 18 deletions(-)

diff --git a/fastvideo/models/loader/fsdp_load.py b/fastvideo/models/loader/fsdp_load.py
index f8002392a..67685cb4e 100644
--- a/fastvideo/models/loader/fsdp_load.py
+++ b/fastvideo/models/loader/fsdp_load.py
@@ -4,6 +4,8 @@
 # Copyright 2024 The TorchTune Authors.
 # Copyright 2025 The FastVideo Authors.
 
+from __future__ import annotations
 import contextlib
+import os
 from collections.abc import Callable, Generator
 from itertools import chain
@@ -197,10 +199,11 @@ def shard_model(
     Raises:
         ValueError: If no layer modules were sharded, indicating that no shard_condition was triggered.
     """
-    if fsdp_shard_conditions is None or len(fsdp_shard_conditions) == 0:
-        logger.warning(
-            "The FSDP shard condition list is empty or None. No modules will be sharded in %s",
-            type(model).__name__)
+    # Check if we should use size-based filtering of the modules to shard
+    use_size_filtering = os.environ.get("FASTVIDEO_FSDP2_AUTOWRAP", "0") == "1"
+
+    if not fsdp_shard_conditions:
+        logger.warning("No FSDP shard conditions provided; nothing will be sharded.")
         return
 
     fsdp_kwargs = {
@@ -215,20 +218,42 @@
     # iterating in reverse to start with
     # lowest-level modules first
     num_layers_sharded = 0
-    # TODO(will): don't reshard after forward for the last layer to save on the
-    # all-gather that will immediately happen Shard the model with FSDP,
-    for n, m in reversed(list(model.named_modules())):
-        if any([
-                shard_condition(n, m)
-                for shard_condition in fsdp_shard_conditions
-        ]):
-            fully_shard(m, **fsdp_kwargs)
-            num_layers_sharded += 1
-
-    if num_layers_sharded == 0:
-        raise ValueError(
-            "No layer modules were sharded. Please check if shard conditions are working as expected."
-        )
+
+    if use_size_filtering:
+        # Size-based filtering: only shard modules above the parameter threshold
+        min_params = int(os.environ.get("FASTVIDEO_FSDP2_MIN_PARAMS", "10000000"))
+        logger.info("Using size-based filtering with threshold: %.2fM params", min_params / 1e6)
+
+        for n, m in reversed(list(model.named_modules())):
+            if any([shard_condition(n, m) for shard_condition in fsdp_shard_conditions]):
+                # Count all parameters, including those of submodules
+                param_count = sum(p.numel() for p in m.parameters(recurse=True))
+
+                logger.info("Inspecting module: name = %s, type = %s, param_count = %.2fM", n, type(m).__name__, param_count / 1e6)
+
+                # Skip modules below the size threshold
+                if param_count < min_params:
+                    logger.info("Skipping module %s (%.2fM params < %.2fM threshold)",
+                                n, param_count / 1e6, min_params / 1e6)
+                    continue
+
+                # Shard this module
+                logger.info("Sharding module %s (%.2fM params)", n, param_count / 1e6)
+                fully_shard(m, **fsdp_kwargs)
+                num_layers_sharded += 1
+    else:
+        # Original logic: shard all modules matching the shard conditions
+        logger.info("Using original logic: shard all modules matching conditions")
+
+        for n, m in reversed(list(model.named_modules())):
+            if any([shard_condition(n, m) for shard_condition in fsdp_shard_conditions]):
+                fully_shard(m, **fsdp_kwargs)
+                num_layers_sharded += 1
+
+    if num_layers_sharded == 0:
+        raise ValueError(
+            "No layer modules were sharded. Please check if shard conditions are working as expected."
+        )
 
     # Finally shard the entire model to account for any stragglers
     fully_shard(model, **fsdp_kwargs)
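
A minimal usage sketch, not part of the patch itself: it only assumes the loader reads
these environment variables in the same process before shard_model() runs, and the 25M
threshold is an arbitrary example value.

    import os
    import torch.nn as nn

    # Opt in to size-based filtering before the FSDP loader runs.
    os.environ["FASTVIDEO_FSDP2_AUTOWRAP"] = "1"
    # Only wrap modules with at least 25M parameters (patch default is 10M).
    os.environ["FASTVIDEO_FSDP2_MIN_PARAMS"] = "25000000"

    # The same per-module size check the patch applies to each matching module:
    block = nn.TransformerEncoderLayer(d_model=1024, nhead=16)
    param_count = sum(p.numel() for p in block.parameters(recurse=True))
    print(f"{param_count / 1e6:.2f}M parameters")  # sharded only if >= threshold

With the flag unset (or set to anything other than "1"), the loader keeps the original
behavior of wrapping every module that matches a shard condition.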