From 9f537c1c1e196c8d5e56148cf243a9e975b4afe0 Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Mon, 8 Sep 2025 21:59:08 +0530 Subject: [PATCH 01/17] update --- src/lightning/pytorch/strategies/fsdp2.py | 662 ++++++++++++++++++++++ 1 file changed, 662 insertions(+) create mode 100644 src/lightning/pytorch/strategies/fsdp2.py diff --git a/src/lightning/pytorch/strategies/fsdp2.py b/src/lightning/pytorch/strategies/fsdp2.py new file mode 100644 index 0000000000000..88f4b3790805b --- /dev/null +++ b/src/lightning/pytorch/strategies/fsdp2.py @@ -0,0 +1,662 @@ +# Copyright The Lightning AI team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import shutil +from collections.abc import Generator, Mapping +from contextlib import contextmanager, nullcontext +from datetime import timedelta +from pathlib import Path +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Literal, + Optional, + Union, +) + +import torch +from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only +from torch import Tensor +from torch.nn import Module +from torch.optim import Optimizer +from typing_extensions import override + +import lightning.pytorch as pl +from lightning.fabric.plugins import CheckpointIO, ClusterEnvironment +from lightning.fabric.plugins.collectives.torch_collective import default_pg_timeout +from lightning.fabric.strategies import _StrategyRegistry +from lightning.fabric.strategies.fsdp import ( + _METADATA_FILENAME, + _distributed_checkpoint_load, + _distributed_checkpoint_save, + _get_full_state_dict_context, + _get_sharded_state_dict_context, + _is_full_checkpoint, + _is_sharded_checkpoint, + _move_torchmetrics_to_device, + _optimizer_has_flat_params, + _setup_activation_checkpointing, +) +from lightning.fabric.strategies.model_parallel import _load_raw_module_state +from lightning.fabric.utilities.distributed import ( + _distributed_is_initialized, + _get_default_process_group_backend_for_device, + _init_dist_connection, + _sync_ddp_if_available, +) +from lightning.fabric.utilities.distributed import group as _group +from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3 +from lightning.fabric.utilities.init import _has_meta_device_parameters_or_buffers +from lightning.fabric.utilities.load import _lazy_load, _materialize_tensors +from lightning.fabric.utilities.optimizer import _optimizers_to_device +from lightning.fabric.utilities.seed import reset_seed +from lightning.fabric.utilities.types import _PATH, ReduceOp +from lightning.pytorch.core.optimizer import LightningOptimizer +from lightning.pytorch.plugins.precision import Precision +from lightning.pytorch.plugins.precision.fsdp import FSDPPrecision +from lightning.pytorch.strategies.launchers.subprocess_script import _SubprocessScriptLauncher +from lightning.pytorch.strategies.parallel import ParallelStrategy +from lightning.pytorch.strategies.strategy import TBroadcast +from lightning.pytorch.trainer.states import TrainerFn +from lightning.pytorch.utilities.model_helpers 
import is_overridden
+from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_only, rank_zero_warn
+
+if TYPE_CHECKING:
+    from torch.distributed.device_mesh import DeviceMesh
+    from torch.distributed.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy
+    from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy
+    from torch.distributed.fsdp.wrap import ModuleWrapPolicy
+
+    _POLICY = Union[set[type[Module]], Callable[[Module, bool, int], bool], ModuleWrapPolicy]
+    _SHARDING_STRATEGY = Union[ShardingStrategy, Literal["FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD", "HYBRID_SHARD"]]
+
+
+log = logging.getLogger(__name__)
+
+
+class FSDP2Strategy(ParallelStrategy):
+    r"""Strategy for Fully Sharded Data Parallel v2 (FSDP2) provided by ``torch.distributed``.
+
+    FSDP2 is the next-generation implementation of Fully Sharded Data Parallel, built on top of the
+    ``DeviceMesh`` and ``DTensor`` abstractions. It provides a more robust and extensible way of
+    scaling models across devices, addressing many of the limitations and inconsistencies of the
+    original FSDP (referred to here as FSDP1).
+
+    Compared to FSDP1, FSDP2 offers:
+    - Deterministic and composable sharding plans via ``DeviceMesh``
+    - A unified tensor abstraction (``DTensor``) that enables interoperability between FSDP,
+      tensor parallelism, and pipeline parallelism
+    - Cleaner checkpointing semantics, reducing many of the loading/saving issues seen in FSDP1
+    - Forward compatibility, as PyTorch maintainers are actively deprecating FSDP1 in favor of FSDP2
+
+    For background, see the RFC:
+    https://github.com/pytorch/pytorch/issues/114299
+
+    Arguments:
+        device_mesh: A :class:`torch.distributed.device_mesh.DeviceMesh` (or a tuple of mesh dimensions
+            from which one is created) that specifies how devices are arranged and how tensors are
+            sharded/replicated.
+        cpu_offload: Enables CPU offloading of sharded parameters. Accepts a bool or a
+            :class:`torch.distributed.fsdp.CPUOffloadPolicy`; ``True`` enables offloading with pinned memory.
+        mp_policy: Optional :class:`torch.distributed.fsdp.MixedPrecisionPolicy` specifying the dtypes used
+            for parameters, gradient reduction, and module outputs.
+        state_dict_type: Whether to save checkpoints as a single consolidated file (``"full"``) or as a
+            sharded, per-process checkpoint directory (``"sharded"``).
+        \**kwargs: Additional keyword arguments passed to the underlying FSDP2 APIs.
+
+    .. note::
+        FSDP2 is still marked as "not fully stable" in PyTorch, but it is the recommended path
+        forward. FSDP1 will eventually be deprecated. Users are encouraged to migrate to FSDP2
+        for new projects, but should test thoroughly before deploying in production-critical
+        environments.
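+
+    The following is a minimal usage sketch (illustrative only; the device count, precision flag,
+    and the user-defined ``LightningModule`` are assumptions, not requirements):
+
+    Example::
+
+        from lightning.pytorch import Trainer
+        from lightning.pytorch.strategies import FSDP2Strategy
+
+        # Illustrative sketch: assumes a single node with 4 CUDA devices.
+        trainer = Trainer(accelerator="gpu", devices=4, strategy=FSDP2Strategy(), precision="bf16-true")
+        # trainer.fit(MyLightningModule())  # ``MyLightningModule`` is a hypothetical user model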
+ + """ + + strategy_name = "fsdp2" + _registered_strategies: list[str] = [] + + def __init__( + self, + device_mesh: Union[tuple[int], "DeviceMesh"] = None, + accelerator: Optional["pl.accelerators.Accelerator"] = None, + parallel_devices: Optional[list[torch.device]] = None, + cluster_environment: Optional[ClusterEnvironment] = None, + checkpoint_io: Optional[CheckpointIO] = None, + precision_plugin: Optional[Precision] = None, + process_group_backend: Optional[str] = None, + timeout: Optional[timedelta] = default_pg_timeout, + cpu_offload: Union[bool, "CPUOffloadPolicy", None] = None, + mp_policy: Optional["MixedPrecisionPolicy"] = None, + state_dict_type: Literal["full", "sharded"] = "full", + **kwargs: Any, + ) -> None: + super().__init__( + accelerator=accelerator, + parallel_devices=parallel_devices, + cluster_environment=cluster_environment, + checkpoint_io=checkpoint_io, + precision_plugin=precision_plugin, + ) + self.num_nodes = 1 + self._process_group_backend = process_group_backend + self._timeout: Optional[timedelta] = timeout + self.cpu_offload = _init_fsdp2_cpu_offload(cpu_offload) + self.mp_policy = _init_fsdp2_mp_policy(mp_policy) + + self.device_mesh = device_mesh + self._state_dict_type = state_dict_type + + @property + @override + def root_device(self) -> torch.device: + assert self.parallel_devices is not None + return self.parallel_devices[self.local_rank] + + @property + def num_processes(self) -> int: + return len(self.parallel_devices) if self.parallel_devices is not None else 0 + + @property + def process_group_backend(self) -> Optional[str]: + return self._process_group_backend + + @property + @override + def precision_plugin(self) -> FSDPPrecision: + plugin = self._precision_plugin + if plugin is not None: + assert isinstance(plugin, FSDPPrecision) + return plugin + return FSDPPrecision("32-true") + + @precision_plugin.setter + @override + def precision_plugin(self, precision_plugin: Optional[Precision]) -> None: + if precision_plugin is not None and not isinstance(precision_plugin, FSDPPrecision): + raise TypeError( + f"The FSDP strategy can only work with the `FSDPPrecision` plugin, found {precision_plugin}" + ) + self._precision_plugin = precision_plugin + + @property + @override + def distributed_sampler_kwargs(self) -> dict: + return {"num_replicas": (self.num_nodes * self.num_processes), "rank": self.global_rank} + + @property + @override + def restore_checkpoint_after_setup(self) -> bool: + return True + + @property + @override + def lightning_restore_optimizer(self) -> bool: + return False + + @override + def setup_environment(self) -> None: + super().setup_environment() + log.debug(f"{self.__class__.__name__}: setting up distributed...") + reset_seed() + + # determine which process we are and world size + self.set_world_ranks() + + self._process_group_backend = self._get_process_group_backend() + assert self.cluster_environment is not None + kwargs: dict[str, Any] = {"timeout": self._timeout} + if _TORCH_GREATER_EQUAL_2_3: + kwargs["device_id"] = self.root_device if self.root_device.type != "cpu" else None + _init_dist_connection(self.cluster_environment, self._process_group_backend, **kwargs) + + # if device_mesh is None, get the world_size and create a 1D device mesh + if self.device_mesh is None: + world_size = self.cluster_environment.world_size() + self.device_mesh = (world_size,) # a 1-D tuple + # if 'device_mesh' is provided as a tuple, update it into the `DeviceMesh` object here + if isinstance(self.device_mesh, tuple): + from 
torch.distributed.device_mesh import init_device_mesh + + self.device_mesh = init_device_mesh("cuda", self.device_mesh) + + def _get_process_group_backend(self) -> str: + return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device) + + def set_world_ranks(self) -> None: + if self.cluster_environment is not None: + self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank) + self.cluster_environment.set_world_size(self.num_nodes * self.num_processes) + # `LightningEnvironment.set_global_rank` will do this too, but we cannot rely on that implementation detail + # additionally, for some implementations, the setter is a no-op, so it's safer to access the getter + rank_zero_only.rank = utils_rank_zero_only.rank = self.global_rank + + @override + def _configure_launcher(self) -> None: + assert self.cluster_environment is not None + if not self.cluster_environment.creates_processes_externally: + self._launcher = _SubprocessScriptLauncher(self.cluster_environment, self.num_processes, self.num_nodes) + + @override + def _setup_model(self, model: Module) -> Module: + """Wraps the model into a :class:`~torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel` + module.""" + from torch.distributed.fsdp import FullyShardedDataParallel + + if any(isinstance(mod, FullyShardedDataParallel) for mod in model.modules()): + if _has_meta_device_parameters_or_buffers(model): + rank_zero_warn( + "The model is already wrapped in `FSDP` but there are still parameters on the meta device." + ) + if "auto_wrap_policy" in self.kwargs: + # The user has wrapped their submodules manually, don't apply the auto wrap policy. + rank_zero_warn( + "A FSDP `auto_wrap_policy` is set, but the model is already wrapped. The policy will be ignored." + ) + del self.kwargs["auto_wrap_policy"] + else: + log.debug(f"setting up FSDP model with device id: {self.root_device.index}, kwargs: {self.kwargs}") + model = FullyShardedDataParallel( + module=model, + cpu_offload=self.cpu_offload, + mixed_precision=self.mixed_precision_config, + sharding_strategy=self.sharding_strategy, + device_id=self.root_device.index, + **self.kwargs, + ) + + _move_torchmetrics_to_device(model, self.root_device) + + # activation checkpointing needs to be set up after wrapping the model + _setup_activation_checkpointing(model, self._activation_checkpointing_kwargs) + + return model + + @override + def setup(self, trainer: "pl.Trainer") -> None: + assert self.accelerator is not None + self.accelerator.setup(trainer) + + assert self.model is not None + if trainer.state.fn == TrainerFn.FITTING and self._layer_sync: + self.model = self._layer_sync.apply(self.model) + + self.model = self.precision_plugin.convert_module(self.model) + + if is_overridden("configure_sharded_model", self.lightning_module): + # legacy: we don't skip setup with the `configure_model` alternative + rank_zero_info( + "You have overridden `LightningModule.configure_sharded_model` hook. It will assume that all the layers" + " are already wrapped for sharding and won't wrap the entire model using `FullyShardedDataParallel`." 
+ ) + else: + self.model = self._setup_model(self.model) + self.barrier() + + if trainer.state.fn == TrainerFn.FITTING: + self.setup_optimizers(trainer) + self.setup_precision_plugin() + if trainer.state.fn == TrainerFn.FITTING: + _optimizers_to_device(self.optimizers, self.root_device) + + @override + def setup_optimizers(self, trainer: "pl.Trainer") -> None: + # If we're setting up for evaluation after fitting, we need to discard the optimizers + # since we're rewrapping the model, otherwise optimizer param references are no longer valid + # and subsequent checkpoint saving can fail + self._reset_optimizers_and_schedulers() + + if self.kwargs.get("use_orig_params"): + return super().setup_optimizers(trainer) + + invalid_params_error = False + try: + # If `use_orig_params=False` the user needs to do access `self.trainer.model.parameters()` in + # `configure_optimizers()` + super().setup_optimizers(trainer) + except ValueError as ex: + if "optimizer got an empty parameter list" not in str(ex): + raise + invalid_params_error = True + + if invalid_params_error or any(not _optimizer_has_flat_params(optimizer) for optimizer in self.optimizers): + # We avoid this limitation by setting `use_orig_params=True` + raise ValueError( + "The optimizer does not seem to reference any FSDP parameters. HINT: Make sure to create the" + " optimizer after setting up the model by referencing `self.trainer.model.parameters()` in the" + " `configure_optimizers()` hook." + ) + return None + + @override + def model_to_device(self) -> None: + # FSDP takes care of moving the model to device + pass + + @contextmanager + @override + def tensor_init_context(self, empty_init: Optional[bool] = None) -> Generator[None, None, None]: + # Materialization happens in `setup`. When modules get wrapped by FSDP, the sequence of operations is: + # 1) materialize module 2) call `reset_parameters()` 3) shard the module. + # These operations are applied to each submodule 'bottom up' in the module hierarchy. + empty_init_context = torch.device("meta") if empty_init else nullcontext() + with empty_init_context, self.precision_plugin.tensor_init_context(): + yield + + @contextmanager + @override + def model_sharded_context(self) -> Generator[None, None, None]: + log.debug(f"{self.__class__.__name__}: entered model_sharded_context.") + from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel + from torch.distributed.fsdp.wrap import enable_wrap + + with enable_wrap( + wrapper_cls=FullyShardedDataParallel, + cpu_offload=self.cpu_offload, + mixed_precision=self.mixed_precision_config, + sharding_strategy=self.sharding_strategy, + device_id=self.root_device.index, + **self.kwargs, + ): + yield + + @override + def barrier(self, name: Optional[str] = None) -> None: + if not _distributed_is_initialized(): + return + if torch.distributed.get_backend() == "nccl": + torch.distributed.barrier(device_ids=self._determine_device_ids()) + else: + torch.distributed.barrier() + + @override + def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: + if not _distributed_is_initialized(): + return obj + + obj = [obj] + torch.distributed.broadcast_object_list(obj, src, group=_group.WORLD) + return obj[0] + + @override + def reduce( + self, + tensor: Union[Tensor, Any], + group: Optional[Any] = None, + reduce_op: Optional[Union[ReduceOp, str]] = "mean", + ) -> Tensor: + """Reduces a tensor from several distributed processes to one aggregated tensor. 
+ + Args: + tensor: the tensor to sync and reduce + group: the process group to gather results from. Defaults to all processes (world) + reduce_op: the reduction operation. Defaults to 'mean'/'avg'. + Can also be a string 'sum' to calculate the sum during reduction. + + Return: + reduced value, except when the input was not a tensor the output remains is unchanged + + """ + if isinstance(tensor, Tensor): + return _sync_ddp_if_available(tensor, group, reduce_op=reduce_op) + return tensor + + def _determine_device_ids(self) -> list[int]: + return [self.root_device.index] + + @override + def teardown(self) -> None: + log.debug(f"{self.__class__.__name__}: tearing down strategy...") + + pl_module = self.lightning_module + if ( + pl_module is not None + # `self.lightning_module._trainer` can be None if teardown gets called on an exception before + # the trainer gets set on the LightningModule + and pl_module._trainer is not None + and pl_module._trainer.state.fn == TrainerFn.FITTING + and self._layer_sync + ): + assert self.model is not None + self.model = self._layer_sync.revert(self.model) + + assert self.cluster_environment is not None + assert self.accelerator is not None + self.cluster_environment.teardown() + self.precision_plugin.teardown() + self.accelerator.teardown() + + @classmethod + def get_registered_strategies(cls) -> list[str]: + return cls._registered_strategies + + @classmethod + @override + def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None: + if not torch.distributed.is_available(): + return + strategy_registry.register( + "fsdp2", + cls, + description="FSDP2 training", + ) + cls._registered_strategies.append("fsdp2") + + strategy_registry.register( + "fsdp_cpu_offload", + cls, + description="FSDP2 training with Full Sharding and CPU Offloading", + cpu_offload=True, + ) + cls._registered_strategies.append("fsdp2_cpu_offload") + + @override + def lightning_module_state_dict(self) -> dict[str, Any]: + assert self.model is not None + if self._state_dict_type == "sharded": + state_dict_ctx = _get_sharded_state_dict_context(self.model) + elif self._state_dict_type == "full": + state_dict_ctx = _get_full_state_dict_context(self.model, world_size=self.world_size) + else: + raise ValueError(f"Unknown state_dict_type: {self._state_dict_type}") + with state_dict_ctx: + return self.model.state_dict() + + @override + def load_model_state_dict(self, checkpoint: Mapping[str, Any], strict: bool = True) -> None: + # Override to do nothing, FSDP already loaded the states in `load_checkpoint()` + pass + + @override + def optimizer_state(self, optimizer: Optimizer) -> dict[str, Tensor]: + from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + from torch.distributed.fsdp import OptimStateKeyType + + if isinstance(optimizer, LightningOptimizer): + optimizer = optimizer._optimizer + + assert self.model is not None + if self._state_dict_type == "sharded": + with _get_sharded_state_dict_context(self.model): + return FSDP.optim_state_dict(self.model, optimizer) + + elif self._state_dict_type == "full": + with _get_full_state_dict_context(self.model, world_size=self.world_size): + state_dict = FSDP.optim_state_dict(self.model, optimizer) + if self.global_rank == 0: + # Store the optimizer state dict in standard format + state_dict = FSDP.rekey_optim_state_dict(state_dict, OptimStateKeyType.PARAM_ID, self.model) + return state_dict + + raise ValueError(f"Unknown state_dict_type: {self._state_dict_type}") + + @override + def load_optimizer_state_dict(self, 
checkpoint: Mapping[str, Any]) -> None: + # Override to do nothing, the FSDP already loaded the states in `load_checkpoint()` + pass + + @override + def save_checkpoint( + self, checkpoint: dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None + ) -> None: + if storage_options is not None: + raise TypeError( + "`FSDPStrategy.save_checkpoint(..., storage_options=...)` is not supported because" + " `FSDPStrategy` does not use the `CheckpointIO`." + ) + + path = Path(self.broadcast(filepath)) + if path.is_dir() and self._state_dict_type == "full" and not _is_sharded_checkpoint(path): + raise IsADirectoryError(f"The checkpoint path exists and is a directory: {path}") + + if self._state_dict_type == "sharded": + if path.is_file(): + path.unlink() + path.mkdir(parents=True, exist_ok=True) + + converted_state = {"model": checkpoint.pop("state_dict")} + converted_state.update({ + f"optimizer_{idx}": optim_state + for idx, optim_state in enumerate(checkpoint.pop("optimizer_states", [])) + }) + + _distributed_checkpoint_save(converted_state, path) + + if self.global_rank == 0: + torch.save(checkpoint, path / _METADATA_FILENAME) + elif self._state_dict_type == "full": + if _is_sharded_checkpoint(path): + shutil.rmtree(path) + return super().save_checkpoint(checkpoint=checkpoint, filepath=path) + else: + raise ValueError(f"Unknown state_dict_type: {self._state_dict_type}") + + @override + def load_checkpoint(self, checkpoint_path: _PATH) -> dict[str, Any]: + # broadcast the path from rank 0 to ensure all the states are loaded from a common path + path = Path(self.broadcast(checkpoint_path)) + + from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + + assert self.model is not None + assert self.lightning_module is not None + + if _is_sharded_checkpoint(path): + from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict + + state_dict_ctx = _get_sharded_state_dict_context(self.model) + + with state_dict_ctx: + module_state = {"model": self.model.state_dict()} + _distributed_checkpoint_load(module_state, path) + self.model.load_state_dict(module_state["model"], strict=self.lightning_module.strict_loading) + + if self.lightning_module.trainer.state.fn == TrainerFn.FITTING and self.optimizers: + from torch.distributed.checkpoint import FileSystemReader + + # TODO: replace with newer APIs + # https://github.com/pytorch/pytorch/issues/119800#issuecomment-1942156271 + reader = FileSystemReader(path=path) + # the optimizer states must be loaded separately + for idx, optim in enumerate(self.optimizers): + optim_key = f"optimizer_{idx}" + optim_state = load_sharded_optimizer_state_dict( + model_state_dict=module_state["model"], + optimizer_key=optim_key, + storage_reader=reader, + ) + flattened_osd = FSDP.optim_state_dict_to_load( + optim_state_dict=optim_state[optim_key], + model=self.model, + optim=optim, + ) + optim.load_state_dict(flattened_osd) + + # Load metadata (anything not a module or optimizer) + metadata = torch.load(path / _METADATA_FILENAME) + return metadata + + if _is_full_checkpoint(path): + checkpoint = _lazy_load(path) + _load_raw_module_state( + checkpoint.pop("state_dict"), + module=self.model, + world_size=self.world_size, + strict=self.lightning_module.strict_loading, + ) + + # Materialize lazy tensors if there are any left in the checkpoint + # The `torch.Optimizer.load_state_dict` method can't load lazy tensors because of deepcopy pickle issues + checkpoint = _materialize_tensors(checkpoint) + + from torch.distributed.fsdp import 
FullyShardedDataParallel as FSDP + from torch.distributed.fsdp import OptimStateKeyType + + optimizer_states = checkpoint.get("optimizer_states") + if optimizer_states is None or self.lightning_module.trainer.state.fn != TrainerFn.FITTING: + # If the optimizer states are not present, we don't need to do anything (backward compatibility) + return checkpoint + if len(self.optimizers) != len(optimizer_states): + raise RuntimeError( + f"You have configured {len(self.optimizers)} optimizers but the checkpoint contains" + f" {len(optimizer_states)} optimizers to load. Please resume training with the same number" + " of optimizers or edit the checkpoint manually to remove states." + ) + + # rank0_only should be false because we need to load the optimizer state on all ranks + with _get_full_state_dict_context(self.model, world_size=self.world_size, rank0_only=False): + for optimizer, opt_state in zip(self.optimizers, optimizer_states): + if isinstance(list(opt_state["state"].keys())[0], int): + # Handling the case where the optimizer state is saved from a normal optimizer + opt_state = FSDP.rekey_optim_state_dict(opt_state, OptimStateKeyType.PARAM_NAME, self.model) + + opt_state = FSDP.optim_state_dict_to_load( + optim_state_dict=opt_state, + model=self.model, + optim=optimizer, + ) + optimizer.load_state_dict(opt_state) + + return checkpoint + + raise ValueError( + f"The path {str(path)!r} does not point to a valid checkpoint. Make sure the path points to either a" + " directory with FSDP checkpoint shards, or a single file with a full checkpoint." + ) + + +def _init_fsdp2_cpu_offload(cpu_offload: Optional[Union[bool, "CPUOffloadPolicy"]]) -> "CPUOffloadPolicy": + from torch.distributed.fsdp import CPUOffloadPolicy, OffloadPolicy + + if cpu_offload is None or cpu_offload is False: + return OffloadPolicy() + + if cpu_offload is True: + return CPUOffloadPolicy(pin_memory=True) + + if isinstance(cpu_offload, CPUOffloadPolicy): + return cpu_offload + + raise TypeError(f"`cpu_offload` should be of type `bool` or `CPUOffloadPolicy`, got {type(cpu_offload)}") + + +def _init_fsdp2_mp_policy(mp_policy: Optional["MixedPrecisionPolicy"]) -> Optional["MixedPrecisionPolicy"]: + from torch.distributed.fsdp import MixedPrecisionPolicy + + if mp_policy is None: + return MixedPrecisionPolicy(param_dtype=None, reduce_dtype=None, output_dtype=None, cast_forward_inputs=True) + + if isinstance(mp_policy, MixedPrecisionPolicy): + return mp_policy + + raise TypeError(f"`mp_policy` should be of type `MixedPrecisionPolicy`, got {type(mp_policy)}") From d8a4d846b4d21ace64a7be2ee6af8a5b7fc638c1 Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Tue, 9 Sep 2025 10:52:49 +0530 Subject: [PATCH 02/17] add fsdp2 precision plugin --- .../pytorch/plugins/precision/fsdp2.py | 110 ++++++++++++++++++ src/lightning/pytorch/strategies/fsdp2.py | 12 +- 2 files changed, 116 insertions(+), 6 deletions(-) create mode 100644 src/lightning/pytorch/plugins/precision/fsdp2.py diff --git a/src/lightning/pytorch/plugins/precision/fsdp2.py b/src/lightning/pytorch/plugins/precision/fsdp2.py new file mode 100644 index 0000000000000..2bda780a97282 --- /dev/null +++ b/src/lightning/pytorch/plugins/precision/fsdp2.py @@ -0,0 +1,110 @@ +# Copyright The Lightning AI team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from contextlib import AbstractContextManager
+from typing import Any
+
+import torch
+from lightning_utilities import apply_to_collection
+from torch import Tensor
+from torch.nn import Module
+from typing_extensions import get_args, override
+
+from lightning.fabric.plugins.precision.fsdp import _PRECISION_INPUT
+from lightning.fabric.plugins.precision.utils import _convert_fp_tensor, _DtypeContextManager
+from lightning.pytorch.plugins.precision.precision import Precision
+from lightning.pytorch.utilities.exceptions import MisconfigurationException
+
+
+class FSDP2Precision(Precision):
+    """Precision plugin for training with FSDP2 (Fully Sharded Data Parallel v2).
+
+    .. warning:: This is an :ref:`experimental <versioning:Experimental API>` feature.
+
+    Args:
+        precision: Full precision (32-true) or half precision (16-true, bf16-true). Mixed precision
+            (16-mixed, bf16-mixed) is not supported by this plugin; configure it through the strategy's
+            ``mp_policy`` instead.
+        scaler: Not supported and must be ``None``. Gradient scaling is configured through the FSDP2
+            mixed-precision policy.
+
+    Raises:
+        ValueError:
+            If an unsupported ``precision`` or a ``scaler`` is provided.
+
+    """
+
+    def __init__(self, precision: _PRECISION_INPUT, scaler: Any = None) -> None:
+        supported_precision = get_args(_PRECISION_INPUT)
+        if precision not in supported_precision:
+            raise ValueError(
+                f"`precision={precision!r}` is not supported in FSDP2."
+                f" `precision` must be one of: {supported_precision}."
+            )
+
+        if scaler is not None:
+            raise ValueError(
+                f"`scaler` is not supported in `{self.__class__.__name__}`, found {scaler}."
+                " Use the mixed-precision policy (`mp_policy`) instead to configure gradient scaling."
+            )
+
+        if "mixed" in precision:
+            raise ValueError(
+                f"`precision={precision!r}` is not supported in `{self.__class__.__name__}`."
+                " Only `true` precision is supported."
+                " Use the mixed-precision policy (`mp_policy`) instead to configure mixed precision."
+            )
+
+        self.precision = precision
+
+        precision_to_type = {
+            "bf16-true": torch.bfloat16,
+            "16-true": torch.float16,
+            "32-true": torch.float32,
+        }
+        self._desired_input_dtype = precision_to_type[self.precision]
+
+    @override
+    def convert_module(self, module: Module) -> Module:
+        if "true" in self.precision:
+            return module.to(dtype=self._desired_input_dtype)
+        return module
+
+    @override
+    def clip_grad_by_norm(self, *_: Any, **__: Any) -> None:
+        # see https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.clip_grad_norm_
+        # section `Gradient Clipping`, using `torch.nn.utils.clip_grad_norm_` is incorrect with FSDP.
+ # To overcome this we need to call root_sharded_module.clip_grad_norm(clip_val), but we don't have a reference + # to the root module + raise MisconfigurationException( + f"`gradient_clip_algorithm='norm'` is currently not supported for `{self.__class__.__name__}`" + ) + + @override + def tensor_init_context(self) -> AbstractContextManager: + return _DtypeContextManager(self._desired_input_dtype) + + @override + def module_init_context(self) -> AbstractContextManager: + # Use float32 for module parameter initialization to ensure numerical stability + return _DtypeContextManager(self._desired_input_dtype) + + @override + def forward_context(self) -> AbstractContextManager: + return _DtypeContextManager(self._desired_input_dtype) + + @override + def convert_input(self, data: Any) -> Any: + return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self._desired_input_dtype) + + @override + def convert_output(self, data: Any) -> Any: + return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype()) diff --git a/src/lightning/pytorch/strategies/fsdp2.py b/src/lightning/pytorch/strategies/fsdp2.py index 88f4b3790805b..927679e453866 100644 --- a/src/lightning/pytorch/strategies/fsdp2.py +++ b/src/lightning/pytorch/strategies/fsdp2.py @@ -65,7 +65,7 @@ from lightning.fabric.utilities.types import _PATH, ReduceOp from lightning.pytorch.core.optimizer import LightningOptimizer from lightning.pytorch.plugins.precision import Precision -from lightning.pytorch.plugins.precision.fsdp import FSDPPrecision +from lightning.pytorch.plugins.precision.fsdp2 import FSDP2Precision from lightning.pytorch.strategies.launchers.subprocess_script import _SubprocessScriptLauncher from lightning.pytorch.strategies.parallel import ParallelStrategy from lightning.pytorch.strategies.strategy import TBroadcast @@ -173,19 +173,19 @@ def process_group_backend(self) -> Optional[str]: @property @override - def precision_plugin(self) -> FSDPPrecision: + def precision_plugin(self) -> FSDP2Precision: plugin = self._precision_plugin if plugin is not None: - assert isinstance(plugin, FSDPPrecision) + assert isinstance(plugin, FSDP2Precision) return plugin - return FSDPPrecision("32-true") + return FSDP2Precision("32-true") @precision_plugin.setter @override def precision_plugin(self, precision_plugin: Optional[Precision]) -> None: - if precision_plugin is not None and not isinstance(precision_plugin, FSDPPrecision): + if precision_plugin is not None and not isinstance(precision_plugin, FSDP2Precision): raise TypeError( - f"The FSDP strategy can only work with the `FSDPPrecision` plugin, found {precision_plugin}" + f"The FSDP2 strategy can only work with the `FSDP2Precision` plugin, found {precision_plugin}" ) self._precision_plugin = precision_plugin From 35bf1a2633e8e05204735dc71c2d96eff322eeda Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Tue, 9 Sep 2025 17:48:29 +0530 Subject: [PATCH 03/17] time to test fsdp2 --- src/lightning/fabric/utilities/init.py | 18 ++ src/lightning/pytorch/strategies/fsdp2.py | 276 +++++++--------------- 2 files changed, 99 insertions(+), 195 deletions(-) diff --git a/src/lightning/fabric/utilities/init.py b/src/lightning/fabric/utilities/init.py index 4f8519eec9610..6243c91a7a402 100644 --- a/src/lightning/fabric/utilities/init.py +++ b/src/lightning/fabric/utilities/init.py @@ -113,3 +113,21 @@ def _has_meta_device_parameters_or_buffers(obj: Union[Module, Optimizer], recurs if isinstance(obj, Module): return any(t.is_meta 
for t in itertools.chain(obj.parameters(recurse=recurse), obj.buffers(recurse=recurse))) raise TypeError(f"Expected `torch.nn.Module` or `torch.optim.Optimizer`, got: {type(obj).__name__}") + + +def _has_all_dtensor_params_or_buffers(obj: Union[Module, Optimizer], recurse: bool = True) -> bool: + from torch.distributed.tensor import DTensor + + if isinstance(obj, Optimizer): + return all( + isinstance(t, DTensor) + for param_group in obj.param_groups + for t in param_group["params"] + if isinstance(t, Parameter) + ) + if isinstance(obj, Module): + return all( + isinstance(t, DTensor) + for t in itertools.chain(obj.parameters(recurse=recurse), obj.buffers(recurse=recurse)) + ) + raise TypeError(f"Expected `torch.nn.Module` or `torch.optim.Optimizer`, got: {type(obj).__name__}") diff --git a/src/lightning/pytorch/strategies/fsdp2.py b/src/lightning/pytorch/strategies/fsdp2.py index 927679e453866..0b86d2da633ca 100644 --- a/src/lightning/pytorch/strategies/fsdp2.py +++ b/src/lightning/pytorch/strategies/fsdp2.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -import shutil from collections.abc import Generator, Mapping from contextlib import contextmanager, nullcontext from datetime import timedelta @@ -20,8 +19,6 @@ from typing import ( TYPE_CHECKING, Any, - Callable, - Literal, Optional, Union, ) @@ -29,6 +26,8 @@ import torch from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only from torch import Tensor +from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict +from torch.distributed.checkpoint.stateful import Stateful from torch.nn import Module from torch.optim import Optimizer from typing_extensions import override @@ -41,15 +40,9 @@ _METADATA_FILENAME, _distributed_checkpoint_load, _distributed_checkpoint_save, - _get_full_state_dict_context, - _get_sharded_state_dict_context, - _is_full_checkpoint, - _is_sharded_checkpoint, _move_torchmetrics_to_device, _optimizer_has_flat_params, - _setup_activation_checkpointing, ) -from lightning.fabric.strategies.model_parallel import _load_raw_module_state from lightning.fabric.utilities.distributed import ( _distributed_is_initialized, _get_default_process_group_backend_for_device, @@ -58,12 +51,10 @@ ) from lightning.fabric.utilities.distributed import group as _group from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3 -from lightning.fabric.utilities.init import _has_meta_device_parameters_or_buffers -from lightning.fabric.utilities.load import _lazy_load, _materialize_tensors +from lightning.fabric.utilities.init import _has_all_dtensor_params_or_buffers, _has_meta_device_parameters_or_buffers from lightning.fabric.utilities.optimizer import _optimizers_to_device from lightning.fabric.utilities.seed import reset_seed from lightning.fabric.utilities.types import _PATH, ReduceOp -from lightning.pytorch.core.optimizer import LightningOptimizer from lightning.pytorch.plugins.precision import Precision from lightning.pytorch.plugins.precision.fsdp2 import FSDP2Precision from lightning.pytorch.strategies.launchers.subprocess_script import _SubprocessScriptLauncher @@ -76,12 +67,6 @@ if TYPE_CHECKING: from torch.distributed.device_mesh import DeviceMesh from torch.distributed.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy - from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy - from torch.distributed.fsdp.wrap import ModuleWrapPolicy - - _POLICY = 
Union[set[type[Module]], Callable[[Module, bool, int], bool], ModuleWrapPolicy] - _SHARDING_STRATEGY = Union[ShardingStrategy, Literal["FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD", "HYBRID_SHARD"]] - log = logging.getLogger(__name__) @@ -138,7 +123,6 @@ def __init__( timeout: Optional[timedelta] = default_pg_timeout, cpu_offload: Union[bool, "CPUOffloadPolicy", None] = None, mp_policy: Optional["MixedPrecisionPolicy"] = None, - state_dict_type: Literal["full", "sharded"] = "full", **kwargs: Any, ) -> None: super().__init__( @@ -155,7 +139,6 @@ def __init__( self.mp_policy = _init_fsdp2_mp_policy(mp_policy) self.device_mesh = device_mesh - self._state_dict_type = state_dict_type @property @override @@ -251,34 +234,43 @@ def _configure_launcher(self) -> None: def _setup_model(self, model: Module) -> Module: """Wraps the model into a :class:`~torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel` module.""" - from torch.distributed.fsdp import FullyShardedDataParallel + from torch.distributed.fsdp import fully_shard - if any(isinstance(mod, FullyShardedDataParallel) for mod in model.modules()): - if _has_meta_device_parameters_or_buffers(model): - rank_zero_warn( - "The model is already wrapped in `FSDP` but there are still parameters on the meta device." - ) - if "auto_wrap_policy" in self.kwargs: - # The user has wrapped their submodules manually, don't apply the auto wrap policy. + if _has_all_dtensor_params_or_buffers(model): + rank_zero_warn("Model is already in DTensor format. FSDP2Strategy will not re-wrap the model.") + + else: + is_on_meta_device = True + if not _has_meta_device_parameters_or_buffers(model): + is_on_meta_device = False rank_zero_warn( - "A FSDP `auto_wrap_policy` is set, but the model is already wrapped. The policy will be ignored." + "To make the best use of FSDP2 strategy, and to avoid unnecessary memory copies, we recommend to" + " initialize your model's parameters and buffers on the `meta` device." ) - del self.kwargs["auto_wrap_policy"] - else: + log.debug(f"setting up FSDP model with device id: {self.root_device.index}, kwargs: {self.kwargs}") - model = FullyShardedDataParallel( + fully_shard( module=model, + mesh=self.device_mesh, + mp_policy=self.mp_policy, + offload_policy=self.cpu_offload, cpu_offload=self.cpu_offload, - mixed_precision=self.mixed_precision_config, - sharding_strategy=self.sharding_strategy, - device_id=self.root_device.index, - **self.kwargs, ) - _move_torchmetrics_to_device(model, self.root_device) + if is_on_meta_device: + # Allocate buffers and sharded parameters on device + model.to_empty(device=self.root_device) - # activation checkpointing needs to be set up after wrapping the model - _setup_activation_checkpointing(model, self._activation_checkpointing_kwargs) + # Run your custom initialization + def init_weights(m): + if isinstance(m, torch.nn.Linear): + torch.nn.init.kaiming_uniform_(m.weight) + if m.bias is not None: + torch.nn.init.zeros_(m.bias) + + model.apply(init_weights) + + _move_torchmetrics_to_device(model, self.root_device) return model @@ -297,7 +289,7 @@ def setup(self, trainer: "pl.Trainer") -> None: # legacy: we don't skip setup with the `configure_model` alternative rank_zero_info( "You have overridden `LightningModule.configure_sharded_model` hook. It will assume that all the layers" - " are already wrapped for sharding and won't wrap the entire model using `FullyShardedDataParallel`." + " are already wrapped for sharding and won't wrap the entire model using `fully_shard`." 
) else: self.model = self._setup_model(self.model) @@ -349,27 +341,10 @@ def tensor_init_context(self, empty_init: Optional[bool] = None) -> Generator[No # Materialization happens in `setup`. When modules get wrapped by FSDP, the sequence of operations is: # 1) materialize module 2) call `reset_parameters()` 3) shard the module. # These operations are applied to each submodule 'bottom up' in the module hierarchy. - empty_init_context = torch.device("meta") if empty_init else nullcontext() + empty_init_context = torch.device("meta") if empty_init in (True, None) else nullcontext() with empty_init_context, self.precision_plugin.tensor_init_context(): yield - @contextmanager - @override - def model_sharded_context(self) -> Generator[None, None, None]: - log.debug(f"{self.__class__.__name__}: entered model_sharded_context.") - from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel - from torch.distributed.fsdp.wrap import enable_wrap - - with enable_wrap( - wrapper_cls=FullyShardedDataParallel, - cpu_offload=self.cpu_offload, - mixed_precision=self.mixed_precision_config, - sharding_strategy=self.sharding_strategy, - device_id=self.root_device.index, - **self.kwargs, - ): - yield - @override def barrier(self, name: Optional[str] = None) -> None: if not _distributed_is_initialized(): @@ -463,14 +438,8 @@ def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None: @override def lightning_module_state_dict(self) -> dict[str, Any]: assert self.model is not None - if self._state_dict_type == "sharded": - state_dict_ctx = _get_sharded_state_dict_context(self.model) - elif self._state_dict_type == "full": - state_dict_ctx = _get_full_state_dict_context(self.model, world_size=self.world_size) - else: - raise ValueError(f"Unknown state_dict_type: {self._state_dict_type}") - with state_dict_ctx: - return self.model.state_dict() + # Override to do nothing, `save_checkpoint()` method will take care of saving the model state + return {} @override def load_model_state_dict(self, checkpoint: Mapping[str, Any], strict: bool = True) -> None: @@ -479,26 +448,8 @@ def load_model_state_dict(self, checkpoint: Mapping[str, Any], strict: bool = Tr @override def optimizer_state(self, optimizer: Optimizer) -> dict[str, Tensor]: - from torch.distributed.fsdp import FullyShardedDataParallel as FSDP - from torch.distributed.fsdp import OptimStateKeyType - - if isinstance(optimizer, LightningOptimizer): - optimizer = optimizer._optimizer - - assert self.model is not None - if self._state_dict_type == "sharded": - with _get_sharded_state_dict_context(self.model): - return FSDP.optim_state_dict(self.model, optimizer) - - elif self._state_dict_type == "full": - with _get_full_state_dict_context(self.model, world_size=self.world_size): - state_dict = FSDP.optim_state_dict(self.model, optimizer) - if self.global_rank == 0: - # Store the optimizer state dict in standard format - state_dict = FSDP.rekey_optim_state_dict(state_dict, OptimStateKeyType.PARAM_ID, self.model) - return state_dict - - raise ValueError(f"Unknown state_dict_type: {self._state_dict_type}") + # Override to do nothing, `save_checkpoint()` method will take care of saving the optimizer state + return {} @override def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None: @@ -511,128 +462,36 @@ def save_checkpoint( ) -> None: if storage_options is not None: raise TypeError( - "`FSDPStrategy.save_checkpoint(..., storage_options=...)` is not supported because" - " `FSDPStrategy` does not use the 
`CheckpointIO`." + "`FSDP2Strategy.save_checkpoint(..., storage_options=...)` is not supported because" + " `FSDP2Strategy` does not use the `CheckpointIO`." ) path = Path(self.broadcast(filepath)) - if path.is_dir() and self._state_dict_type == "full" and not _is_sharded_checkpoint(path): - raise IsADirectoryError(f"The checkpoint path exists and is a directory: {path}") - - if self._state_dict_type == "sharded": - if path.is_file(): - path.unlink() - path.mkdir(parents=True, exist_ok=True) - - converted_state = {"model": checkpoint.pop("state_dict")} - converted_state.update({ - f"optimizer_{idx}": optim_state - for idx, optim_state in enumerate(checkpoint.pop("optimizer_states", [])) - }) - - _distributed_checkpoint_save(converted_state, path) - - if self.global_rank == 0: - torch.save(checkpoint, path / _METADATA_FILENAME) - elif self._state_dict_type == "full": - if _is_sharded_checkpoint(path): - shutil.rmtree(path) - return super().save_checkpoint(checkpoint=checkpoint, filepath=path) - else: - raise ValueError(f"Unknown state_dict_type: {self._state_dict_type}") + + if path.is_file(): + path.unlink() + path.mkdir(parents=True, exist_ok=True) + + state_dict = {"fsdp2_checkpoint_state_dict": AppState(self.model, self.optimizers)} + _distributed_checkpoint_save(state_dict, path) + + if self.global_rank == 0: + torch.save(checkpoint, path / _METADATA_FILENAME) @override def load_checkpoint(self, checkpoint_path: _PATH) -> dict[str, Any]: # broadcast the path from rank 0 to ensure all the states are loaded from a common path path = Path(self.broadcast(checkpoint_path)) - from torch.distributed.fsdp import FullyShardedDataParallel as FSDP - assert self.model is not None assert self.lightning_module is not None - if _is_sharded_checkpoint(path): - from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict - - state_dict_ctx = _get_sharded_state_dict_context(self.model) - - with state_dict_ctx: - module_state = {"model": self.model.state_dict()} - _distributed_checkpoint_load(module_state, path) - self.model.load_state_dict(module_state["model"], strict=self.lightning_module.strict_loading) - - if self.lightning_module.trainer.state.fn == TrainerFn.FITTING and self.optimizers: - from torch.distributed.checkpoint import FileSystemReader - - # TODO: replace with newer APIs - # https://github.com/pytorch/pytorch/issues/119800#issuecomment-1942156271 - reader = FileSystemReader(path=path) - # the optimizer states must be loaded separately - for idx, optim in enumerate(self.optimizers): - optim_key = f"optimizer_{idx}" - optim_state = load_sharded_optimizer_state_dict( - model_state_dict=module_state["model"], - optimizer_key=optim_key, - storage_reader=reader, - ) - flattened_osd = FSDP.optim_state_dict_to_load( - optim_state_dict=optim_state[optim_key], - model=self.model, - optim=optim, - ) - optim.load_state_dict(flattened_osd) - - # Load metadata (anything not a module or optimizer) - metadata = torch.load(path / _METADATA_FILENAME) - return metadata - - if _is_full_checkpoint(path): - checkpoint = _lazy_load(path) - _load_raw_module_state( - checkpoint.pop("state_dict"), - module=self.model, - world_size=self.world_size, - strict=self.lightning_module.strict_loading, - ) - - # Materialize lazy tensors if there are any left in the checkpoint - # The `torch.Optimizer.load_state_dict` method can't load lazy tensors because of deepcopy pickle issues - checkpoint = _materialize_tensors(checkpoint) - - from torch.distributed.fsdp import FullyShardedDataParallel as FSDP 
- from torch.distributed.fsdp import OptimStateKeyType - - optimizer_states = checkpoint.get("optimizer_states") - if optimizer_states is None or self.lightning_module.trainer.state.fn != TrainerFn.FITTING: - # If the optimizer states are not present, we don't need to do anything (backward compatibility) - return checkpoint - if len(self.optimizers) != len(optimizer_states): - raise RuntimeError( - f"You have configured {len(self.optimizers)} optimizers but the checkpoint contains" - f" {len(optimizer_states)} optimizers to load. Please resume training with the same number" - " of optimizers or edit the checkpoint manually to remove states." - ) + state_dict = {"fsdp2_checkpoint_state_dict": AppState(self.model, self.optimizers)} + _distributed_checkpoint_load(state_dict, path) - # rank0_only should be false because we need to load the optimizer state on all ranks - with _get_full_state_dict_context(self.model, world_size=self.world_size, rank0_only=False): - for optimizer, opt_state in zip(self.optimizers, optimizer_states): - if isinstance(list(opt_state["state"].keys())[0], int): - # Handling the case where the optimizer state is saved from a normal optimizer - opt_state = FSDP.rekey_optim_state_dict(opt_state, OptimStateKeyType.PARAM_NAME, self.model) - - opt_state = FSDP.optim_state_dict_to_load( - optim_state_dict=opt_state, - model=self.model, - optim=optimizer, - ) - optimizer.load_state_dict(opt_state) - - return checkpoint - - raise ValueError( - f"The path {str(path)!r} does not point to a valid checkpoint. Make sure the path points to either a" - " directory with FSDP checkpoint shards, or a single file with a full checkpoint." - ) + # Load metadata (anything not a module or optimizer) + metadata = torch.load(path / _METADATA_FILENAME) + return metadata def _init_fsdp2_cpu_offload(cpu_offload: Optional[Union[bool, "CPUOffloadPolicy"]]) -> "CPUOffloadPolicy": @@ -660,3 +519,30 @@ def _init_fsdp2_mp_policy(mp_policy: Optional["MixedPrecisionPolicy"]) -> Option return mp_policy raise TypeError(f"`mp_policy` should be of type `MixedPrecisionPolicy`, got {type(mp_policy)}") + + +# Code taken from: https://docs.pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html#saving +class AppState(Stateful): + """This is a useful wrapper for checkpointing the Application State. Since this object is compliant with the + Stateful protocol, DCP will automatically call state_dict/load_stat_dict as needed in the dcp.save/load APIs. + + Note: We take advantage of this wrapper to handle calling distributed state dict methods on the model + and optimizer. + + """ + + def __init__(self, model, optimizers): + self.model = model + self.optimizers = optimizers + + def state_dict(self): + # this line automatically manages FSDP FQN's, + # as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT + model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizers) + return {"model": model_state_dict, "optim": optimizer_state_dict} + + def load_state_dict(self, state_dict): + # sets our state dicts on the model and optimizer, now that we've loaded + set_state_dict( + self.model, self.optimizers, model_state_dict=state_dict["model"], optim_state_dict=state_dict["optim"] + ) From cc6de8246200121e65f8638296e6e46fba3ef709 Mon Sep 17 00:00:00 2001 From: Deependu Date: Wed, 10 Sep 2025 05:24:41 +0000 Subject: [PATCH 04/17] works. 
i'm still worthy --- src/lightning/fabric/strategies/fsdp.py | 5 +++++ src/lightning/pytorch/plugins/__init__.py | 2 ++ src/lightning/pytorch/strategies/__init__.py | 2 ++ src/lightning/pytorch/strategies/fsdp2.py | 16 ++++++++++++---- .../trainer/connectors/accelerator_connector.py | 4 ++++ 5 files changed, 25 insertions(+), 4 deletions(-) diff --git a/src/lightning/fabric/strategies/fsdp.py b/src/lightning/fabric/strategies/fsdp.py index baaee74af0ec9..8550835f9177c 100644 --- a/src/lightning/fabric/strategies/fsdp.py +++ b/src/lightning/fabric/strategies/fsdp.py @@ -31,6 +31,7 @@ from lightning_utilities.core.imports import RequirementCache from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only from torch import Tensor +from torch.distributed.tensor import DTensor from torch.nn import Module from torch.optim import Optimizer from typing_extensions import TypeGuard, override @@ -795,6 +796,10 @@ def _optimizer_has_flat_params(optimizer: Optimizer) -> bool: ) +def _optimizer_has_dtensor_params(optimizer: Optimizer) -> bool: + return any(isinstance(param, DTensor) for group in optimizer.param_groups for param in group["params"]) + + def _get_sharded_state_dict_context(module: Module) -> Generator[None, None, None]: from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp.api import ShardedOptimStateDictConfig, ShardedStateDictConfig, StateDictType diff --git a/src/lightning/pytorch/plugins/__init__.py b/src/lightning/pytorch/plugins/__init__.py index d4fd63807c78d..8ec85a2b5fc67 100644 --- a/src/lightning/pytorch/plugins/__init__.py +++ b/src/lightning/pytorch/plugins/__init__.py @@ -8,6 +8,7 @@ from lightning.pytorch.plugins.precision.deepspeed import DeepSpeedPrecision from lightning.pytorch.plugins.precision.double import DoublePrecision from lightning.pytorch.plugins.precision.fsdp import FSDPPrecision +from lightning.pytorch.plugins.precision.fsdp2 import FSDP2Precision from lightning.pytorch.plugins.precision.half import HalfPrecision from lightning.pytorch.plugins.precision.precision import Precision from lightning.pytorch.plugins.precision.transformer_engine import TransformerEnginePrecision @@ -28,6 +29,7 @@ "Precision", "TransformerEnginePrecision", "FSDPPrecision", + "FSDP2Precision", "XLAPrecision", "LayerSync", "TorchSyncBatchNorm", diff --git a/src/lightning/pytorch/strategies/__init__.py b/src/lightning/pytorch/strategies/__init__.py index 9c2b2a6a3a621..33b28b6ae59c2 100644 --- a/src/lightning/pytorch/strategies/__init__.py +++ b/src/lightning/pytorch/strategies/__init__.py @@ -18,6 +18,7 @@ from lightning.pytorch.strategies.ddp import DDPStrategy from lightning.pytorch.strategies.deepspeed import DeepSpeedStrategy from lightning.pytorch.strategies.fsdp import FSDPStrategy +from lightning.pytorch.strategies.fsdp2 import FSDP2Strategy from lightning.pytorch.strategies.model_parallel import ModelParallelStrategy from lightning.pytorch.strategies.parallel import ParallelStrategy from lightning.pytorch.strategies.single_device import SingleDeviceStrategy @@ -32,6 +33,7 @@ "DDPStrategy", "DeepSpeedStrategy", "FSDPStrategy", + "FSDP2Strategy", "ModelParallelStrategy", "ParallelStrategy", "SingleDeviceStrategy", diff --git a/src/lightning/pytorch/strategies/fsdp2.py b/src/lightning/pytorch/strategies/fsdp2.py index 0b86d2da633ca..0937dbe80387d 100644 --- a/src/lightning/pytorch/strategies/fsdp2.py +++ b/src/lightning/pytorch/strategies/fsdp2.py @@ -41,7 +41,7 @@ _distributed_checkpoint_load, 
_distributed_checkpoint_save, _move_torchmetrics_to_device, - _optimizer_has_flat_params, + _optimizer_has_dtensor_params, ) from lightning.fabric.utilities.distributed import ( _distributed_is_initialized, @@ -139,6 +139,7 @@ def __init__( self.mp_policy = _init_fsdp2_mp_policy(mp_policy) self.device_mesh = device_mesh + self.kwargs = kwargs @property @override @@ -249,12 +250,19 @@ def _setup_model(self, model: Module) -> Module: ) log.debug(f"setting up FSDP model with device id: {self.root_device.index}, kwargs: {self.kwargs}") + if isinstance(self.device_mesh, tuple): + from torch.distributed.device_mesh import DeviceMesh + + self.device_mesh = DeviceMesh("cuda", self.device_mesh) + + if self.mp_policy is None: + raise ValueError("`mp_policy` cannot be None when calling `fully_shard`.") + fully_shard( module=model, mesh=self.device_mesh, mp_policy=self.mp_policy, offload_policy=self.cpu_offload, - cpu_offload=self.cpu_offload, ) if is_on_meta_device: @@ -321,7 +329,7 @@ def setup_optimizers(self, trainer: "pl.Trainer") -> None: raise invalid_params_error = True - if invalid_params_error or any(not _optimizer_has_flat_params(optimizer) for optimizer in self.optimizers): + if invalid_params_error or any(not _optimizer_has_dtensor_params(optimizer) for optimizer in self.optimizers): # We avoid this limitation by setting `use_orig_params=True` raise ValueError( "The optimizer does not seem to reference any FSDP parameters. HINT: Make sure to create the" @@ -428,7 +436,7 @@ def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None: cls._registered_strategies.append("fsdp2") strategy_registry.register( - "fsdp_cpu_offload", + "fsdp2_cpu_offload", cls, description="FSDP2 training with Full Sharding and CPU Offloading", cpu_offload=True, diff --git a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py index 7f44de0589938..434ebc1a4afb6 100644 --- a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py +++ b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py @@ -42,6 +42,7 @@ CheckpointIO, DeepSpeedPrecision, DoublePrecision, + FSDP2Precision, FSDPPrecision, HalfPrecision, MixedPrecision, @@ -53,6 +54,7 @@ from lightning.pytorch.strategies import ( DDPStrategy, DeepSpeedStrategy, + FSDP2Strategy, FSDPStrategy, ModelParallelStrategy, ParallelStrategy, @@ -493,6 +495,8 @@ def _check_and_init_precision(self) -> Precision: return DeepSpeedPrecision(self._precision_flag) # type: ignore[arg-type] if isinstance(self.strategy, FSDPStrategy): return FSDPPrecision(self._precision_flag) # type: ignore[arg-type] + if isinstance(self.strategy, FSDP2Strategy): + return FSDP2Precision(self._precision_flag) # type: ignore[arg-type] if self._precision_flag in ("16-true", "bf16-true"): return HalfPrecision(self._precision_flag) # type: ignore if self._precision_flag == "32-true": From daa86674755ce7301f0787235a5ca8e982c24875 Mon Sep 17 00:00:00 2001 From: Deependu Date: Wed, 10 Sep 2025 05:26:47 +0000 Subject: [PATCH 05/17] update --- src/lightning/fabric/strategies/fsdp.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/lightning/fabric/strategies/fsdp.py b/src/lightning/fabric/strategies/fsdp.py index 8550835f9177c..3809976548849 100644 --- a/src/lightning/fabric/strategies/fsdp.py +++ b/src/lightning/fabric/strategies/fsdp.py @@ -31,7 +31,6 @@ from lightning_utilities.core.imports import RequirementCache from lightning_utilities.core.rank_zero import 
rank_zero_only as utils_rank_zero_only from torch import Tensor -from torch.distributed.tensor import DTensor from torch.nn import Module from torch.optim import Optimizer from typing_extensions import TypeGuard, override @@ -797,6 +796,8 @@ def _optimizer_has_flat_params(optimizer: Optimizer) -> bool: def _optimizer_has_dtensor_params(optimizer: Optimizer) -> bool: + from torch.distributed.tensor import DTensor + return any(isinstance(param, DTensor) for group in optimizer.param_groups for param in group["params"]) From 9279f4923a0763e5c32e124377003cfad27666ff Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Wed, 10 Sep 2025 14:31:58 +0530 Subject: [PATCH 06/17] fix mypy issues and install-pkg ci --- src/lightning/fabric/utilities/registry.py | 4 +-- src/lightning/pytorch/strategies/fsdp2.py | 33 ++++++++++++++++------ 2 files changed, 25 insertions(+), 12 deletions(-) diff --git a/src/lightning/fabric/utilities/registry.py b/src/lightning/fabric/utilities/registry.py index 7d8f6ca17712e..79ca438a75d93 100644 --- a/src/lightning/fabric/utilities/registry.py +++ b/src/lightning/fabric/utilities/registry.py @@ -36,9 +36,7 @@ def _load_external_callbacks(group: str) -> list[Any]: A list of all callbacks collected from external factories. """ - factories = ( - entry_points(group=group) if _PYTHON_GREATER_EQUAL_3_10_0 else entry_points().get(group, {}) # type: ignore[arg-type] - ) + factories = entry_points(group=group) if _PYTHON_GREATER_EQUAL_3_10_0 else entry_points().get(group, {}) external_callbacks: list[Any] = [] for factory in factories: diff --git a/src/lightning/pytorch/strategies/fsdp2.py b/src/lightning/pytorch/strategies/fsdp2.py index 0937dbe80387d..c5a0475422e0d 100644 --- a/src/lightning/pytorch/strategies/fsdp2.py +++ b/src/lightning/pytorch/strategies/fsdp2.py @@ -26,8 +26,6 @@ import torch from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only from torch import Tensor -from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict -from torch.distributed.checkpoint.stateful import Stateful from torch.nn import Module from torch.optim import Optimizer from typing_extensions import override @@ -66,7 +64,15 @@ if TYPE_CHECKING: from torch.distributed.device_mesh import DeviceMesh - from torch.distributed.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy + from torch.distributed.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy, OffloadPolicy + +try: + from torch.distributed.checkpoint.stateful import Stateful +except ImportError: + # define a no-op base class for compatibility + class Stateful: + pass + log = logging.getLogger(__name__) @@ -113,7 +119,7 @@ class FSDP2Strategy(ParallelStrategy): def __init__( self, - device_mesh: Union[tuple[int], "DeviceMesh"] = None, + device_mesh: Optional[Union[tuple[int], "DeviceMesh"]] = None, accelerator: Optional["pl.accelerators.Accelerator"] = None, parallel_devices: Optional[list[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, @@ -270,7 +276,7 @@ def _setup_model(self, model: Module) -> Module: model.to_empty(device=self.root_device) # Run your custom initialization - def init_weights(m): + def init_weights(m: Module) -> None: if isinstance(m, torch.nn.Linear): torch.nn.init.kaiming_uniform_(m.weight) if m.bias is not None: @@ -480,6 +486,11 @@ def save_checkpoint( path.unlink() path.mkdir(parents=True, exist_ok=True) + if self.model is None: + raise RuntimeError( + "Cannot save checkpoint: FSDP2Strategy model is not initialized." 
+ " Please ensure the strategy is set up before saving." + ) state_dict = {"fsdp2_checkpoint_state_dict": AppState(self.model, self.optimizers)} _distributed_checkpoint_save(state_dict, path) @@ -502,7 +513,7 @@ def load_checkpoint(self, checkpoint_path: _PATH) -> dict[str, Any]: return metadata -def _init_fsdp2_cpu_offload(cpu_offload: Optional[Union[bool, "CPUOffloadPolicy"]]) -> "CPUOffloadPolicy": +def _init_fsdp2_cpu_offload(cpu_offload: Optional[Union[bool, "CPUOffloadPolicy"]]) -> "OffloadPolicy": from torch.distributed.fsdp import CPUOffloadPolicy, OffloadPolicy if cpu_offload is None or cpu_offload is False: @@ -539,17 +550,21 @@ class AppState(Stateful): """ - def __init__(self, model, optimizers): + def __init__(self, model: Module, optimizers: list[Optimizer]) -> None: self.model = model self.optimizers = optimizers - def state_dict(self): + def state_dict(self) -> dict[str, Any]: + from torch.distributed.checkpoint.state_dict import get_state_dict + # this line automatically manages FSDP FQN's, # as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizers) return {"model": model_state_dict, "optim": optimizer_state_dict} - def load_state_dict(self, state_dict): + def load_state_dict(self, state_dict: dict[str, Any]) -> None: + from torch.distributed.checkpoint.state_dict import set_state_dict + # sets our state dicts on the model and optimizer, now that we've loaded set_state_dict( self.model, self.optimizers, model_state_dict=state_dict["model"], optim_state_dict=state_dict["optim"] From caa2dd936e5e1befca75c65a246853a47e0e0f57 Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Wed, 10 Sep 2025 14:43:51 +0530 Subject: [PATCH 07/17] update --- src/lightning/fabric/utilities/registry.py | 2 +- src/lightning/pytorch/strategies/fsdp2.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/lightning/fabric/utilities/registry.py b/src/lightning/fabric/utilities/registry.py index 79ca438a75d93..03c30972c98b0 100644 --- a/src/lightning/fabric/utilities/registry.py +++ b/src/lightning/fabric/utilities/registry.py @@ -36,7 +36,7 @@ def _load_external_callbacks(group: str) -> list[Any]: A list of all callbacks collected from external factories. 
""" - factories = entry_points(group=group) if _PYTHON_GREATER_EQUAL_3_10_0 else entry_points().get(group, {}) + factories = entry_points(group=group) if _PYTHON_GREATER_EQUAL_3_10_0 else entry_points().get(group, {}) # type: ignore[arg-type] external_callbacks: list[Any] = [] for factory in factories: diff --git a/src/lightning/pytorch/strategies/fsdp2.py b/src/lightning/pytorch/strategies/fsdp2.py index c5a0475422e0d..db9dfb8d12f65 100644 --- a/src/lightning/pytorch/strategies/fsdp2.py +++ b/src/lightning/pytorch/strategies/fsdp2.py @@ -67,10 +67,10 @@ from torch.distributed.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy, OffloadPolicy try: - from torch.distributed.checkpoint.stateful import Stateful + from torch.distributed.checkpoint.stateful import Stateful as _TorchStateful except ImportError: - # define a no-op base class for compatibility - class Stateful: + + class _TorchStateful: # type: ignore[no-redef] pass @@ -541,7 +541,7 @@ def _init_fsdp2_mp_policy(mp_policy: Optional["MixedPrecisionPolicy"]) -> Option # Code taken from: https://docs.pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html#saving -class AppState(Stateful): +class AppState(_TorchStateful): """This is a useful wrapper for checkpointing the Application State. Since this object is compliant with the Stateful protocol, DCP will automatically call state_dict/load_stat_dict as needed in the dcp.save/load APIs. From 648729599f3ba631cedd674c635a47edbebed59b Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Wed, 10 Sep 2025 16:06:12 +0530 Subject: [PATCH 08/17] fsdp2 tests started --- .../connectors/accelerator_connector.py | 14 +- tests/tests_pytorch/strategies/test_fsdp2.py | 822 ++++++++++++++++++ 2 files changed, 833 insertions(+), 3 deletions(-) create mode 100644 tests/tests_pytorch/strategies/test_fsdp2.py diff --git a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py index 434ebc1a4afb6..9028e0ca02470 100644 --- a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py +++ b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py @@ -453,11 +453,19 @@ def _check_strategy_and_fallback(self) -> None: # TODO this logic should apply to both str and object config strategy_flag = "" if isinstance(self._strategy_flag, Strategy) else self._strategy_flag - if ( + is_fsdp1_str = ( strategy_flag in FSDPStrategy.get_registered_strategies() or type(self._strategy_flag) is FSDPStrategy - ) and not (self._accelerator_flag in ("cuda", "gpu") or isinstance(self._accelerator_flag, CUDAAccelerator)): + ) + is_fsdp2_str = ( + strategy_flag in FSDP2Strategy.get_registered_strategies() or type(self._strategy_flag) is FSDP2Strategy + ) + + if (is_fsdp1_str or is_fsdp2_str) and not ( + self._accelerator_flag in ("cuda", "gpu") or isinstance(self._accelerator_flag, CUDAAccelerator) + ): + strategy_name = FSDP2Strategy.strategy_name if is_fsdp2_str else FSDPStrategy.strategy_name raise ValueError( - f"The strategy `{FSDPStrategy.strategy_name}` requires a GPU accelerator, but received " + f"The strategy `{strategy_name}` requires a GPU accelerator, but received " f"`accelerator={self._accelerator_flag!r}`. Please set `accelerator='cuda'`, `accelerator='gpu'`," " or pass a `CUDAAccelerator()` instance to use FSDP." 
) diff --git a/tests/tests_pytorch/strategies/test_fsdp2.py b/tests/tests_pytorch/strategies/test_fsdp2.py new file mode 100644 index 0000000000000..cac6b9f121002 --- /dev/null +++ b/tests/tests_pytorch/strategies/test_fsdp2.py @@ -0,0 +1,822 @@ +import os +from contextlib import nullcontext +from copy import deepcopy +from datetime import timedelta +from functools import partial +from pathlib import Path +from re import escape +from typing import Optional +from unittest import mock +from unittest.mock import ANY, MagicMock, Mock + +import pytest +import torch +import torch.nn as nn +from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload +from torch.distributed.fsdp.wrap import ModuleWrapPolicy, size_based_auto_wrap_policy +from torchmetrics import Accuracy + +from lightning.fabric.plugins.environments import LightningEnvironment +from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3 +from lightning.fabric.utilities.init import _has_all_dtensor_params_or_buffers +from lightning.fabric.utilities.load import _load_distributed_checkpoint +from lightning.pytorch import Trainer +from lightning.pytorch.callbacks import ModelCheckpoint +from lightning.pytorch.demos.boring_classes import BoringModel +from lightning.pytorch.plugins import HalfPrecision +from lightning.pytorch.strategies import FSDP2Strategy, FSDPStrategy +from lightning.pytorch.trainer.states import TrainerFn +from lightning.pytorch.utilities.consolidate_checkpoint import _format_checkpoint +from tests_pytorch.helpers.runif import RunIf + + +class TestFSDP2Model(BoringModel): + def __init__(self): + super().__init__() + self.layer: Optional[nn.Module] = None + + def _init_model(self) -> None: + self.layer = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2)) + + def configure_optimizers(self): + # There is some issue with SGD optimizer state in FSDP + return torch.optim.AdamW(self.layer.parameters(), lr=0.1) + + def on_train_batch_start(self, batch, batch_idx): + assert batch.dtype == torch.float32 + + def on_train_batch_end(self, _, batch, batch_idx): + assert batch.dtype == torch.float32 + self._assert_layer_fsdp2_instance() + + def on_test_batch_end(self, _, batch, batch_idx): + assert batch.dtype == torch.float32 + self._assert_layer_fsdp2_instance() + + def on_validation_batch_end(self, _, batch, batch_idx): + assert batch.dtype == torch.float32 + self._assert_layer_fsdp2_instance() + + def on_predict_batch_end(self, _, batch, batch_idx): + assert batch.dtype == torch.float32 + self._assert_layer_fsdp2_instance() + + def _assert_layer_fsdp2_instance(self): + # FSDP2 does not wrap modules in a distinct class (like FSDP1’s FullyShardedDataParallel). + # Instead, it injects an internal `_fsdp_state` attribute and replaces all parameters/buffers + # with DTensors in place. These two checks together confirm that `self.layer` is FSDP2-wrapped. 
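The assertions that follow use Lightning-internal helpers; outside of Lightning, a roughly equivalent check can be written against public PyTorch APIs. The sketch below assumes torch>=2.6, where `fully_shard` swaps the module class to a subclass of `FSDPModule` and replaces its parameters with `DTensor`s in place.

import torch.nn as nn
from torch.distributed.fsdp import FSDPModule
from torch.distributed.tensor import DTensor

def is_fsdp2_sharded(module: nn.Module) -> bool:
    # fully_shard() re-classes the module (e.g. Linear -> a subclass of FSDPModule)
    # and turns its parameters into DTensors, so both properties can be checked.
    params_are_dtensors = all(isinstance(p, DTensor) for p in module.parameters())
    return isinstance(module, FSDPModule) and params_are_dtensors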
+ assert hasattr(self.layer, "_fsdp_state") + assert _has_all_dtensor_params_or_buffers(self.layer) + + +class TestBoringModel(BoringModel): + def __init__(self): + super().__init__(self) + + self.save_hyperparameters() + self.layer = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2)) + + def configure_optimizers(self): + # SGD's FSDP optimizer, state is fixed in https://github.com/pytorch/pytorch/pull/99214 + return torch.optim.AdamW(self.parameters(), lr=0.1) + + +class TestFSDP2ModelAutoWrapped(TestBoringModel): + def on_train_batch_start(self, batch, batch_idx): + assert batch.dtype == torch.float32 + + def on_train_batch_end(self, _, batch, batch_idx): + assert batch.dtype == torch.float32 + self._assert_layer_fsdp2_instance() + + def on_test_batch_end(self, _, batch, batch_idx): + assert batch.dtype == torch.float32 + self._assert_layer_fsdp2_instance() + + def on_validation_batch_end(self, _, batch, batch_idx): + assert batch.dtype == torch.float32 + self._assert_layer_fsdp2_instance() + + def on_predict_batch_end(self, _, batch, batch_idx): + assert batch.dtype == torch.float32 + self._assert_layer_fsdp2_instance() + + def _assert_layer_fsdp2_instance(self): + # FSDP2 does not wrap modules in a distinct class (like FSDP1’s FullyShardedDataParallel). + # Instead, it injects an internal `_fsdp_state` attribute and replaces all parameters/buffers + # with DTensors in place. These two checks together confirm that `self.layer` is FSDP2-wrapped. + assert hasattr(self.layer, "_fsdp_state") + assert _has_all_dtensor_params_or_buffers(self.layer) + + +def _run_multiple_stages(trainer, model, model_path: Optional[str] = None): + trainer.fit(model) + trainer.test(model) + + model_path = trainer.strategy.broadcast(model_path) + model_path = Path(model_path if model_path else trainer.checkpoint_callback.last_model_path) + + # Save another checkpoint after testing, without optimizer states + trainer.save_checkpoint(model_path.with_name("after-test")) + trainer.save_checkpoint(model_path, weights_only=True) + + if not model_path.is_dir(): # TODO (@awaelchli): Add support for asserting equality of sharded checkpoints + _assert_save_equality(trainer, model_path, cls=model.__class__) + + with torch.inference_mode(): + # Test entry point + trainer.test(model) # model is wrapped, will not call `configure_model` + + # provide model path, will create a new unwrapped model and load and then call `configure_shared_model` to wrap + trainer.test(ckpt_path=model_path) + + # Predict entry point + trainer.predict(model) # model is wrapped, will not call `configure_model` + + # provide model path, will create a new unwrapped model and load and then call `configure_shared_model` to wrap + trainer.predict(ckpt_path=model_path) + + +def _assert_save_equality(trainer, ckpt_path, cls=TestFSDP2Model): + # Use FullySharded to get the state dict for the sake of comparison + model_state_dict = trainer.strategy.lightning_module_state_dict() + + if trainer.is_global_zero: + saved_model = cls.load_from_checkpoint(ckpt_path) + + # Assert model parameters are identical after loading + for ddp_param, shard_param in zip(model_state_dict.values(), saved_model.state_dict().values()): + assert torch.equal(ddp_param, shard_param) + + +@pytest.mark.parametrize("strategy", ["fsdp2", "fsdp2_cpu_offload"]) +def test_invalid_on_cpu(tmp_path, cuda_count_0, strategy): + """Test to ensure that we raise Misconfiguration for FSDP on CPU.""" + with pytest.raises(ValueError, match="The strategy `fsdp2` 
requires a GPU accelerator"): + trainer = Trainer(accelerator="cpu", default_root_dir=tmp_path, fast_dev_run=True, strategy=strategy) + assert isinstance(trainer.strategy, FSDP2Strategy) + trainer.strategy.setup_environment() + + +def test_custom_mixed_precision(): + """Test to ensure that passing a custom mixed precision config works.""" + from torch.distributed.fsdp import MixedPrecisionPolicy + + # custom mp policy + mp_policy = MixedPrecisionPolicy( + compute_dtype=torch.float16, param_dtype=torch.bfloat16, buffer_dtype=torch.float16, cast_forward_inputs=True + ) + strategy = FSDP2Strategy(mp_policy=mp_policy) + assert strategy.mp_policy == mp_policy + + # default mp policy + strategy = FSDP2Strategy(mp_policy=None) + assert isinstance(strategy.mp_policy, MixedPrecisionPolicy) + assert strategy.mp_policy.param_dtype is None + assert strategy.mp_policy.reduce_dtype is None + assert strategy.mp_policy.output_dtype is None + assert strategy.mp_policy.cast_forward_inputs is True + + # invalid mp policy + class InvalidMPPolicy: + pass + + with pytest.raises(TypeError, match="`mp_policy` should be of type `MixedPrecisionPolicy`"): + FSDP2Strategy(mp_policy=InvalidMPPolicy()) + + +@pytest.mark.filterwarnings("ignore::FutureWarning") +@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True) +def test_strategy_sync_batchnorm(tmp_path): + """Test to ensure that sync_batchnorm works when using FSDP and GPU, and all stages can be run.""" + model = TestFSDP2Model() + trainer = Trainer( + default_root_dir=tmp_path, + accelerator="gpu", + devices=2, + strategy="fsdp2", + precision="16-true", + max_epochs=1, + sync_batchnorm=True, + ) + _run_multiple_stages(trainer, model, os.path.join(tmp_path, "last.ckpt")) + + +@pytest.mark.filterwarnings("ignore::FutureWarning") +@RunIf(min_cuda_gpus=1, skip_windows=True) +def test_modules_without_parameters(tmp_path): + """Test that TorchMetrics get moved to the device despite not having any parameters.""" + + class MetricsModel(BoringModel): + def __init__(self): + super().__init__() + self.metric = Accuracy("multiclass", num_classes=10) + assert self.metric.device == self.metric.tp.device == torch.device("cpu") + + def setup(self, stage) -> None: + assert self.metric.device == self.metric.tp.device == torch.device("cpu") + + def training_step(self, batch, batch_idx): + loss = super().training_step(batch, batch_idx) + assert self.metric.device == self.metric.tp.device == torch.device("cuda", 0) + self.metric(torch.rand(2, 10, device=self.device), torch.randint(0, 10, size=(2,), device=self.device)) + return loss + + model = MetricsModel() + trainer = Trainer( + default_root_dir=tmp_path, + accelerator="cuda", + devices=1, + strategy="fsdp2", + max_steps=1, + ) + trainer.fit(model) + + +@pytest.mark.filterwarnings("ignore::FutureWarning") +@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True) +@pytest.mark.parametrize("precision", ["16-mixed", pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True))]) +def test_strategy_checkpoint(state_dict_type, precision, tmp_path): + """Test to ensure that checkpoint is saved correctly when using a single GPU, and all stages can be run.""" + model = TestFSDP2Model() + strategy = FSDP2Strategy() + trainer = Trainer( + default_root_dir=tmp_path, accelerator="gpu", devices=2, strategy=strategy, precision=precision, max_epochs=1 + ) + _run_multiple_stages(trainer, model, os.path.join(tmp_path, "last.ckpt")) + + +def custom_auto_wrap_policy( + module, + recurse, + nonwrapped_numel: int, +) -> bool: + return 
nonwrapped_numel >= 2 + + +@pytest.mark.filterwarnings("ignore::FutureWarning") +@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True) +@pytest.mark.parametrize( + ("model", "strategy", "strategy_cfg"), + [ + pytest.param(TestFSDP2Model(), "fsdp", None, id="manually_wrapped"), + pytest.param( + TestFSDP2ModelAutoWrapped(), + FSDPStrategy, + {"auto_wrap_policy": custom_auto_wrap_policy}, + id="autowrap_2x", + ), + pytest.param( + TestFSDP2ModelAutoWrapped(), + FSDPStrategy, + { + "auto_wrap_policy": ModuleWrapPolicy({nn.Linear}), + "use_orig_params": True, + }, + id="autowrap_use_orig_params", + ), + ], +) +def test_checkpoint_multi_gpus(tmp_path, model, strategy, strategy_cfg): + """Test to ensure that checkpoint is saved correctly when using multiple GPUs, and all stages can be run.""" + ck = ModelCheckpoint(save_last=True) + + strategy_cfg = strategy_cfg or {} + if not isinstance(strategy, str): + strategy = strategy(**strategy_cfg) + + trainer = Trainer( + default_root_dir=tmp_path, + accelerator="gpu", + devices=2, + strategy=strategy, + precision="16-mixed", + max_epochs=1, + limit_train_batches=2, + limit_val_batches=2, + limit_test_batches=2, + limit_predict_batches=2, + callbacks=[ck], + ) + _run_multiple_stages(trainer, model) + + +@RunIf(min_cuda_gpus=1, skip_windows=True, standalone=True) +@pytest.mark.parametrize("use_orig_params", [None, False, True]) +def test_invalid_parameters_in_optimizer(use_orig_params): + fsdp_kwargs = {} + if use_orig_params is not None: + fsdp_kwargs = {"use_orig_params": use_orig_params} + + trainer = Trainer( + strategy=FSDPStrategy(**fsdp_kwargs), + accelerator="cuda", + devices=1, + fast_dev_run=1, + ) + + class EmptyParametersModel(BoringModel): + def configure_optimizers(self): + return torch.optim.Adam(self.parameters(), lr=1e-2) + + model = EmptyParametersModel() + trainer.fit(model) + + class NoFlatParametersModel(BoringModel): + def configure_optimizers(self): + layer = torch.nn.Linear(4, 5) + return torch.optim.Adam(layer.parameters(), lr=1e-2) + + error_context = ( + nullcontext() + if use_orig_params is not False + else pytest.raises(ValueError, match="The optimizer does not seem to reference any FSDP parameters") + ) + + model = NoFlatParametersModel() + with error_context: + trainer.fit(model) + + +def test_forbidden_precision_raises(): + with pytest.raises(TypeError, match="can only work with the `FSDPPrecision"): + FSDPStrategy(precision_plugin=HalfPrecision()) + + strategy = FSDPStrategy() + with pytest.raises(TypeError, match="can only work with the `FSDPPrecision"): + strategy.precision_plugin = HalfPrecision() + + +def test_activation_checkpointing(): + """Test that the FSDP strategy can apply activation checkpointing to the given layers.""" + + class Block1(nn.Linear): + pass + + class Block2(nn.Linear): + pass + + class Model(BoringModel): + def __init__(self): + super().__init__() + self.layer0 = nn.Sequential(Block1(4, 4), Block1(5, 5)) + self.layer1 = Block2(2, 2) + self.layer2 = nn.Linear(3, 3) + + strategy = FSDPStrategy(activation_checkpointing_policy={Block1}) + assert set(strategy._activation_checkpointing_kwargs) == {"auto_wrap_policy"} + assert isinstance(strategy._activation_checkpointing_kwargs["auto_wrap_policy"], ModuleWrapPolicy) + + strategy = FSDPStrategy(activation_checkpointing_policy=ModuleWrapPolicy({Block1, Block2})) + assert set(strategy._activation_checkpointing_kwargs) == {"auto_wrap_policy"} + assert isinstance(strategy._activation_checkpointing_kwargs["auto_wrap_policy"], ModuleWrapPolicy) + + 
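Stepping back from these FSDP1-era tests for a moment: the core call that `FSDP2Strategy._setup_model` ultimately issues is `fully_shard`. The standalone sketch below applies it directly with a device mesh and a mixed-precision policy. It assumes torch>=2.6 and a torchrun launch; the toy model, the dtypes, and the per-layer wrapping choice are illustrative, and the exact arguments Lightning forwards may differ.

import os
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard

local_rank = int(os.environ.get("LOCAL_RANK", 0))
world_size = int(os.environ.get("WORLD_SIZE", 1))
torch.cuda.set_device(local_rank)

model = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 2)).cuda()

# 1D mesh over all ranks: plain fully-sharded data parallelism (no TP/PP dimensions).
mesh = init_device_mesh("cuda", (world_size,))

mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,  # parameters are gathered and computed in bf16
    reduce_dtype=torch.float32,  # gradients are reduced in fp32 for stability
)

# Shard the parameter-holding blocks first, then the root module.
for layer in model:
    if isinstance(layer, nn.Linear):
        fully_shard(layer, mesh=mesh, mp_policy=mp_policy)
fully_shard(model, mesh=mesh, mp_policy=mp_policy)

# The optimizer must be created after sharding so that it sees the DTensor parameters.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)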
model = Model() + strategy._parallel_devices = [torch.device("cuda", 0)] + strategy._lightning_module = model + strategy._process_group = Mock() + with ( + mock.patch("torch.distributed.fsdp.FullyShardedDataParallel", new=MagicMock), + mock.patch( + "torch.distributed.algorithms._checkpoint.checkpoint_wrapper.apply_activation_checkpointing" + ) as apply_mock, + ): + wrapped = strategy._setup_model(model) + apply_mock.assert_called_with(wrapped, checkpoint_wrapper_fn=ANY, **strategy._activation_checkpointing_kwargs) + + +def test_strategy_cpu_offload(): + """Test the different ways cpu offloading can be enabled.""" + # bool + strategy = FSDPStrategy(cpu_offload=True) + assert strategy.cpu_offload == CPUOffload(offload_params=True) + + # dataclass + config = CPUOffload() + strategy = FSDPStrategy(cpu_offload=config) + assert strategy.cpu_offload == config + + +def test_sharding_strategy(): + """Test the different ways the sharding strategy can be set.""" + from torch.distributed.fsdp import ShardingStrategy + + # default + strategy = FSDPStrategy() + assert strategy.sharding_strategy == ShardingStrategy.FULL_SHARD + + # enum + strategy = FSDPStrategy(sharding_strategy=ShardingStrategy.SHARD_GRAD_OP) + assert strategy.sharding_strategy == ShardingStrategy.SHARD_GRAD_OP + + # string + strategy = FSDPStrategy(sharding_strategy="NO_SHARD") + assert strategy.sharding_strategy == ShardingStrategy.NO_SHARD + strategy = FSDPStrategy(sharding_strategy="no_shard") + assert strategy.sharding_strategy == ShardingStrategy.NO_SHARD + + +@pytest.mark.parametrize("sharding_strategy", ["HYBRID_SHARD", "_HYBRID_SHARD_ZERO2"]) +def test_hybrid_shard_configuration(sharding_strategy, monkeypatch): + """Test that the hybrid sharding strategies can only be used with automatic wrapping or a manually specified pg.""" + with pytest.raises(RuntimeError, match="The hybrid sharding strategy requires you to pass at least one of"): + FSDPStrategy(sharding_strategy=sharding_strategy) + + strategy = FSDPStrategy(auto_wrap_policy={nn.Linear}, sharding_strategy=sharding_strategy) + assert strategy.sharding_strategy.name == sharding_strategy + + process_group = (Mock(), Mock()) + strategy = FSDPStrategy(sharding_strategy=sharding_strategy, process_group=process_group) + assert strategy.sharding_strategy.name == sharding_strategy + assert strategy.kwargs["process_group"] is process_group + + monkeypatch.setattr("lightning.pytorch.strategies.fsdp._TORCH_GREATER_EQUAL_2_2", False) + with pytest.raises(ValueError, match="`device_mesh` argument is only supported in torch >= 2.2."): + FSDPStrategy(device_mesh=Mock()) + + monkeypatch.setattr("lightning.pytorch.strategies.fsdp._TORCH_GREATER_EQUAL_2_2", True) + device_mesh = Mock() + strategy = FSDPStrategy(sharding_strategy=sharding_strategy, device_mesh=device_mesh) + assert strategy.sharding_strategy.name == sharding_strategy + assert strategy.kwargs["device_mesh"] is device_mesh + + with pytest.raises(ValueError, match="process_group.* device_mesh=.* are mutually exclusive"): + FSDPStrategy(sharding_strategy=sharding_strategy, process_group=process_group, device_mesh=device_mesh) + + +def test_use_orig_params(): + """Test that Lightning enables `use_orig_params` automatically.""" + strategy = FSDPStrategy() + assert strategy.kwargs["use_orig_params"] + strategy = FSDPStrategy(use_orig_params=False) + assert not strategy.kwargs["use_orig_params"] + + +@mock.patch("torch.distributed.init_process_group") +def test_set_timeout(init_process_group_mock): + """Test that the timeout gets 
passed to the ``torch.distributed.init_process_group`` function.""" + test_timedelta = timedelta(seconds=30) + strategy = FSDPStrategy(timeout=test_timedelta, parallel_devices=[torch.device("cpu")]) + strategy.cluster_environment = LightningEnvironment() + strategy.accelerator = Mock() + strategy.setup_environment() + process_group_backend = strategy._get_process_group_backend() + global_rank = strategy.cluster_environment.global_rank() + world_size = strategy.cluster_environment.world_size() + kwargs = {} + if _TORCH_GREATER_EQUAL_2_3: + kwargs["device_id"] = strategy.root_device if strategy.root_device.type != "cpu" else None + init_process_group_mock.assert_called_with( + process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta, **kwargs + ) + + +@mock.patch("lightning.pytorch.strategies.fsdp._load_raw_module_state") +def test_strategy_load_optimizer_states_multiple(_, tmp_path): + strategy = FSDPStrategy(parallel_devices=[torch.device("cpu")], state_dict_type="full") + trainer = Trainer() + trainer.state.fn = TrainerFn.FITTING + strategy._lightning_module = Mock(trainer=trainer) + spec = torch.optim.Optimizer + + # More states than optimizers configured + strategy.optimizers = [Mock(spec=spec)] + checkpoint = {"state_dict": {}, "optimizer_states": [{"state": {}}, {"state": {}}]} + torch.save(checkpoint, tmp_path / "two-states.ckpt") + with pytest.raises(RuntimeError, match="1 optimizers but the checkpoint contains 2 optimizers to load"): + strategy.load_checkpoint(tmp_path / "two-states.ckpt") + + # Fewer states than optimizers configured + strategy.optimizers = [Mock(spec=spec), Mock(spec=spec)] + checkpoint = {"state_dict": {}, "optimizer_states": [{"state": {}}]} + torch.save(checkpoint, tmp_path / "one-state.ckpt") + with pytest.raises(RuntimeError, match="2 optimizers but the checkpoint contains 1 optimizers to load"): + strategy.load_checkpoint(tmp_path / "one-state.ckpt") + + +@pytest.mark.filterwarnings("ignore::FutureWarning") +@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True) +@pytest.mark.parametrize("wrap_min_params", [2, 1024, 100000000]) +def test_strategy_save_optimizer_states(tmp_path, wrap_min_params): + """Test to ensure that the full state dict and optimizer states is saved when using FSDP strategy. + + Based on `wrap_min_params`, the model will be fully wrapped, half wrapped, and not wrapped at all. If the model can + be restored to DDP, it means that the optimizer states were saved correctly. 
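The round trip this test relies on (train sharded, then reload the same weights under DDP) works because DCP's state-dict utilities can also produce a consolidated full state dict from a sharded FSDP2 model. A hedged sketch of that export step, assuming a sharded `model`/`optimizer` pair already in scope (names are illustrative) and torch>=2.4 for `StateDictOptions`:

import torch
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_state_dict

# `model` / `optimizer` are assumed to be an already fully_shard-ed pair.
# With full_state_dict + cpu_offload, tensors are gathered, moved to CPU, and only
# rank 0 receives the populated dictionaries (other ranks get empty dicts).
options = StateDictOptions(full_state_dict=True, cpu_offload=True)
model_sd, optim_sd = get_state_dict(model, optimizer, options=options)

if torch.distributed.get_rank() == 0:
    # These dicts now hold plain torch.Tensors and can be consumed by a
    # non-sharded (e.g. DDP) run with ordinary torch.save / load_state_dict.
    torch.save({"model": model_sd, "optim": optim_sd}, "full_checkpoint.pt")  # path is illustrative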
+ + """ + model = TestFSDP2ModelAutoWrapped(wrap_min_params=wrap_min_params) + + strategy = FSDPStrategy(auto_wrap_policy=partial(size_based_auto_wrap_policy, min_num_params=wrap_min_params)) + trainer = Trainer( + default_root_dir=tmp_path, + accelerator="gpu", + devices=2, + strategy=strategy, + precision="16-mixed", + max_epochs=1, + barebones=True, + ) + + trainer.fit(model) + model_path = os.path.join(tmp_path, "last.ckpt") + model_path = trainer.strategy.broadcast(model_path) + trainer.save_checkpoint(model_path) + + model_state_dict = trainer.strategy.lightning_module_state_dict() + optimizer_state_dict = trainer.strategy.optimizer_state(model.optimizers()) + + if trainer.global_rank != 0: + assert len(model_state_dict) == 0 + + if trainer.global_rank != 0: + assert len(optimizer_state_dict) == 0 + + # restore model to ddp + model = TestBoringModel() + trainer = Trainer(default_root_dir=tmp_path, accelerator="gpu", devices=2, strategy="ddp", max_epochs=1) + + # This step will restore the model and optimizer states + trainer.fit(model, ckpt_path=model_path) + + # Get the model and optimizer states from the restored ddp model + restored_model_state_dict = trainer.strategy.lightning_module_state_dict() + restored_optimizer_state_dict = trainer.strategy.optimizer_state(model.optimizers()) + + if trainer.global_rank == 0: + # assert everything is the same + assert len(model_state_dict) == len(restored_model_state_dict) + assert len(optimizer_state_dict) == len(restored_optimizer_state_dict) + + torch.testing.assert_close(model_state_dict, restored_model_state_dict, atol=0, rtol=0) + torch.testing.assert_close(optimizer_state_dict, restored_optimizer_state_dict, atol=0, rtol=0) + + trainer.strategy.barrier() + + +@pytest.mark.filterwarnings("ignore::FutureWarning") +@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True) +@pytest.mark.parametrize("wrap_min_params", [2, 1024, 100000000]) +def test_strategy_load_optimizer_states(wrap_min_params, tmp_path): + """Test to ensure that the full state dict and optimizer states can be load when using FSDP strategy. + + Based on `wrap_min_params`, the model will be fully wrapped, half wrapped, and not wrapped at all. If the DDP model + can be restored to FSDP, it means that the optimizer states were restored correctly. 
+ + """ + + # restore model to ddp + model = TestBoringModel() + trainer = Trainer(default_root_dir=tmp_path, accelerator="gpu", devices=2, strategy="ddp", max_epochs=1) + + # This step will restore the model and optimizer states + trainer.fit(model) + model_path = os.path.join(tmp_path, "last.ckpt") + model_path = trainer.strategy.broadcast(model_path) + trainer.save_checkpoint(model_path) + + # Get the model and optimizer states from the restored ddp model + model_state_dict = trainer.strategy.lightning_module_state_dict() + optimizer_state_dict = trainer.strategy.optimizer_state(model.optimizers()) + + # Build a new FSDP model + model = TestFSDP2ModelAutoWrapped(wrap_min_params=wrap_min_params) + + strategy = FSDPStrategy(auto_wrap_policy=partial(size_based_auto_wrap_policy, min_num_params=wrap_min_params)) + trainer = Trainer( + default_root_dir=tmp_path, + accelerator="gpu", + devices=2, + strategy=strategy, + precision="16-mixed", + max_epochs=1, + barebones=True, + ) + + trainer.fit(model, ckpt_path=model_path) + + restored_model_state_dict = trainer.strategy.lightning_module_state_dict() + restored_optimizer_state_dict = trainer.strategy.optimizer_state(model.optimizers()) + + if trainer.global_rank != 0: + assert len(restored_model_state_dict) == 0 + + if trainer.global_rank != 0: + assert len(restored_optimizer_state_dict) == 0 + + if trainer.global_rank == 0: + # assert everything is the same + assert len(model_state_dict) == len(restored_model_state_dict) + assert len(optimizer_state_dict) == len(restored_optimizer_state_dict) + torch.testing.assert_close(model_state_dict, restored_model_state_dict, atol=0, rtol=0) + torch.testing.assert_close(optimizer_state_dict, restored_optimizer_state_dict, atol=0, rtol=0) + + trainer.strategy.barrier() + + +@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True) +@pytest.mark.parametrize( + ("precision", "expected_dtype"), + [ + ("32-true", torch.float32), + ], +) +def test_configure_model(precision, expected_dtype, tmp_path): + """Test that the module under configure_model gets moved to the right device and dtype.""" + trainer = Trainer( + default_root_dir=tmp_path, + accelerator="cuda", + devices=2, + strategy=FSDP2Strategy(), + precision=precision, + max_epochs=1, + enable_checkpointing=False, + logger=False, + ) + + class MyModel(BoringModel): + def configure_model(self): + self.layer = torch.nn.Linear(32, 2) + # The model is on the CPU until after `.setup()`` + # TODO: Support initialization on meta device + expected_device = torch.device("cpu") + assert self.layer.weight.device == expected_device + assert self.layer.weight.dtype == expected_dtype + + def configure_optimizers(self): + # There is some issue with SGD optimizer state in FSDP + return torch.optim.AdamW(self.layer.parameters(), lr=0.1) + + def on_fit_start(self): + # Parameters get sharded in `.setup()` and moved to the target device + assert self.layer.weight.device == torch.device("cuda", self.local_rank) + assert self.layer.weight.dtype == expected_dtype + + model = MyModel() + trainer.fit(model) + + +def test_save_checkpoint_storage_options(tmp_path): + """Test that the FSDP strategy does not accept storage options for saving checkpoints.""" + strategy = FSDP2Strategy() + with pytest.raises(TypeError, match=escape("FSDP2Strategy.save_checkpoint(..., storage_options=...)` is not")): + strategy.save_checkpoint(filepath=tmp_path, checkpoint=Mock(), storage_options=Mock()) + + +def test_load_unknown_checkpoint_type(tmp_path): + """Test that the strategy validates the 
contents at the checkpoint path.""" + strategy = FSDP2Strategy() + strategy.model = Mock() + strategy._lightning_module = Mock() + path = tmp_path / "empty_dir" # neither a single file nor a directory with meta file + path.mkdir() + with pytest.raises(ValueError, match="does not point to a valid checkpoint"): + strategy.load_checkpoint(checkpoint_path=path) + + +class TestFSDP2CheckpointModel(BoringModel): + def __init__(self, params_to_compare=None): + super().__init__() + self.layer = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2)) + self.params_to_compare = params_to_compare + + def configure_optimizers(self): + # SGD's FSDP optimizer, state is fixed in https://github.com/pytorch/pytorch/pull/99214 + return torch.optim.AdamW(self.parameters(), lr=0.1) + + def on_train_start(self): + if self.params_to_compare is None: + return + for p0, p1 in zip(self.params_to_compare, self.trainer.model.parameters()): + torch.testing.assert_close(p0, p1, atol=0, rtol=0, equal_nan=True) + + +@pytest.mark.filterwarnings("ignore::FutureWarning") +@RunIf(min_cuda_gpus=2, standalone=True) +def test_save_load_sharded_state_dict(tmp_path): + """Test FSDP saving and loading with the sharded state dict format.""" + strategy = FSDP2Strategy() + trainer_kwargs = { + "default_root_dir": tmp_path, + "accelerator": "cuda", + "devices": 2, + "max_epochs": 1, + "enable_progress_bar": False, + "enable_model_summary": False, + "logger": False, + } + + # Initial training + model = TestFSDP2CheckpointModel() + trainer = Trainer(**trainer_kwargs, strategy=strategy) + trainer.fit(model) + params_before = deepcopy(list(trainer.model.parameters())) + + checkpoint_path = Path(trainer.strategy.broadcast(trainer.checkpoint_callback.best_model_path)) + assert set(os.listdir(checkpoint_path)) == {"meta.pt", ".metadata", "__0_0.distcp", "__1_0.distcp"} + + metadata = torch.load(checkpoint_path / "meta.pt", weights_only=True) + assert "pytorch-lightning_version" in metadata + assert len(metadata["callbacks"]) == 1 # model checkpoint callback + assert "state_dict" not in metadata + assert "optimizer_states" not in metadata + + # Load checkpoint and continue training + trainer_kwargs.update(max_epochs=2) + model = TestFSDP2CheckpointModel(params_to_compare=params_before) + strategy = FSDP2Strategy() + trainer = Trainer(**trainer_kwargs, strategy=strategy) + trainer.fit(model, ckpt_path=checkpoint_path) + + +@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True) +@pytest.mark.parametrize( + ("precision", "expected_dtype"), + [ + ("32-true", torch.float32), + ("16-true", torch.float16), + ], +) +def test_module_init_context(precision, expected_dtype, tmp_path): + """Test that the module under the init-context gets moved to the right device and dtype.""" + + class Model(BoringModel): + def configure_optimizers(self): + return torch.optim.Adam(self.parameters(), lr=1e-2) + + def on_train_start(self): + # Parameters get sharded in `FSDPStrategy.setup()` and moved to the target device + assert self.layer.weight.device == torch.device("cuda", self.local_rank) + assert self.layer.weight.dtype == expected_dtype + optimizer = self.optimizers(use_pl_optimizer=False) + assert optimizer.param_groups[0]["params"][0].device.type == "cuda" + + def _run_setup_assertions(empty_init, expected_device): + trainer = Trainer( + default_root_dir=tmp_path, + accelerator="cuda", + devices=2, + strategy=FSDP2Strategy(), + precision=precision, + max_steps=1, + barebones=True, + enable_checkpointing=False, + 
logger=False, + ) + with trainer.init_module(empty_init=empty_init): + model = Model() + + # The model is on the CPU/meta-device until after `FSDPStrategy.setup()` + assert model.layer.weight.device == expected_device + assert model.layer.weight.dtype == expected_dtype + trainer.fit(model) + + # Case 1: No empty init + _run_setup_assertions(empty_init=False, expected_device=torch.device("cpu")) + + # Case 2: Empty-init with meta device + _run_setup_assertions(empty_init=True, expected_device=torch.device("meta")) + + +@pytest.mark.filterwarnings("ignore::FutureWarning") +@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.3.0") +def test_save_sharded_and_consolidate_and_load(tmp_path): + """Test the consolidation of a FSDP2-sharded checkpoint into a single file.""" + + class CustomModel(BoringModel): + def configure_optimizers(self): + # Use Adam instead of SGD for this test because it has state + # In PyTorch >= 2.4, saving an optimizer with empty state would result in a `KeyError: 'state'` + # when loading the optimizer state-dict back. + # TODO: To resolve this, switch to the new `torch.distributed.checkpoint` APIs in FSDPStrategy + return torch.optim.Adam(self.parameters(), lr=0.1) + + model = CustomModel() + trainer = Trainer( + default_root_dir=tmp_path, + accelerator="cuda", + devices=2, + strategy=FSDP2Strategy(), + max_steps=3, + ) + trainer.fit(model) + + checkpoint_path_sharded = trainer.strategy.broadcast(str(trainer.checkpoint_callback.best_model_path)) + assert set(os.listdir(checkpoint_path_sharded)) == {"meta.pt", ".metadata", "__0_0.distcp", "__1_0.distcp"} + + # consolidate the checkpoint to a single file + checkpoint_path_full = trainer.strategy.broadcast(str(tmp_path / "checkpoint_full.ckpt")) + if trainer.global_rank == 0: + checkpoint = _load_distributed_checkpoint(Path(checkpoint_path_sharded)) + checkpoint = _format_checkpoint(checkpoint) + torch.save(checkpoint, checkpoint_path_full) + trainer.strategy.barrier() + + model = CustomModel() + trainer = Trainer( + default_root_dir=tmp_path, + accelerator="cuda", + devices=2, + strategy="ddp", + max_steps=4, + ) + trainer.fit(model, ckpt_path=checkpoint_path_full) From 9389669e4cf6a533608017e9590784d9ac43dab4 Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Wed, 10 Sep 2025 17:12:03 +0530 Subject: [PATCH 09/17] fsdp2 tests --- tests/tests_pytorch/strategies/test_fsdp2.py | 417 +------------------ 1 file changed, 15 insertions(+), 402 deletions(-) diff --git a/tests/tests_pytorch/strategies/test_fsdp2.py b/tests/tests_pytorch/strategies/test_fsdp2.py index cac6b9f121002..01bdfb4ef8197 100644 --- a/tests/tests_pytorch/strategies/test_fsdp2.py +++ b/tests/tests_pytorch/strategies/test_fsdp2.py @@ -1,35 +1,35 @@ import os -from contextlib import nullcontext from copy import deepcopy -from datetime import timedelta -from functools import partial from pathlib import Path from re import escape from typing import Optional -from unittest import mock -from unittest.mock import ANY, MagicMock, Mock +from unittest.mock import Mock import pytest import torch import torch.nn as nn -from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload -from torch.distributed.fsdp.wrap import ModuleWrapPolicy, size_based_auto_wrap_policy from torchmetrics import Accuracy -from lightning.fabric.plugins.environments import LightningEnvironment -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3 from lightning.fabric.utilities.init import _has_all_dtensor_params_or_buffers from lightning.fabric.utilities.load 
import _load_distributed_checkpoint from lightning.pytorch import Trainer -from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.demos.boring_classes import BoringModel -from lightning.pytorch.plugins import HalfPrecision -from lightning.pytorch.strategies import FSDP2Strategy, FSDPStrategy -from lightning.pytorch.trainer.states import TrainerFn +from lightning.pytorch.strategies import FSDP2Strategy from lightning.pytorch.utilities.consolidate_checkpoint import _format_checkpoint from tests_pytorch.helpers.runif import RunIf +# Minimal boring model for FSDP2 tests (used for DDP/FSDP2 checkpoint compatibility) +class TestBoringModel(BoringModel): + def __init__(self): + super().__init__() + self.save_hyperparameters() + self.layer = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2)) + + def configure_optimizers(self): + return torch.optim.AdamW(self.parameters(), lr=0.1) + + class TestFSDP2Model(BoringModel): def __init__(self): super().__init__() @@ -62,25 +62,11 @@ def on_predict_batch_end(self, _, batch, batch_idx): self._assert_layer_fsdp2_instance() def _assert_layer_fsdp2_instance(self): - # FSDP2 does not wrap modules in a distinct class (like FSDP1’s FullyShardedDataParallel). - # Instead, it injects an internal `_fsdp_state` attribute and replaces all parameters/buffers - # with DTensors in place. These two checks together confirm that `self.layer` is FSDP2-wrapped. + # FSDP2 injects an internal `_fsdp_state` attribute and replaces all parameters/buffers with DTensors. assert hasattr(self.layer, "_fsdp_state") assert _has_all_dtensor_params_or_buffers(self.layer) -class TestBoringModel(BoringModel): - def __init__(self): - super().__init__(self) - - self.save_hyperparameters() - self.layer = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2)) - - def configure_optimizers(self): - # SGD's FSDP optimizer, state is fixed in https://github.com/pytorch/pytorch/pull/99214 - return torch.optim.AdamW(self.parameters(), lr=0.1) - - class TestFSDP2ModelAutoWrapped(TestBoringModel): def on_train_batch_start(self, batch, batch_idx): assert batch.dtype == torch.float32 @@ -102,9 +88,6 @@ def on_predict_batch_end(self, _, batch, batch_idx): self._assert_layer_fsdp2_instance() def _assert_layer_fsdp2_instance(self): - # FSDP2 does not wrap modules in a distinct class (like FSDP1’s FullyShardedDataParallel). - # Instead, it injects an internal `_fsdp_state` attribute and replaces all parameters/buffers - # with DTensors in place. These two checks together confirm that `self.layer` is FSDP2-wrapped. 
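Outside of Lightning's `_load_distributed_checkpoint`/`_format_checkpoint` helpers, PyTorch ships a small utility for the same consolidation step used in the sharded-checkpoint test above: converting a DCP checkpoint directory into a single `torch.save` file. A sketch, assuming torch>=2.3 and illustrative paths (the reverse direction exists as `torch_save_to_dcp`):

from torch.distributed.checkpoint.format_utils import dcp_to_torch_save

# Offline, single-process conversion: reads the "*.distcp" shards plus ".metadata"
# from the checkpoint directory and writes one file that plain torch.load can open.
dcp_to_torch_save("checkpoints/epoch=0-step=3.ckpt", "consolidated.ckpt")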
assert hasattr(self.layer, "_fsdp_state") assert _has_all_dtensor_params_or_buffers(self.layer) @@ -164,7 +147,7 @@ def test_custom_mixed_precision(): # custom mp policy mp_policy = MixedPrecisionPolicy( - compute_dtype=torch.float16, param_dtype=torch.bfloat16, buffer_dtype=torch.float16, cast_forward_inputs=True + param_dtype=torch.bfloat16, reduce_dtype=torch.float16, output_dtype=torch.float16, cast_forward_inputs=True ) strategy = FSDP2Strategy(mp_policy=mp_policy) assert strategy.mp_policy == mp_policy @@ -254,365 +237,6 @@ def custom_auto_wrap_policy( return nonwrapped_numel >= 2 -@pytest.mark.filterwarnings("ignore::FutureWarning") -@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True) -@pytest.mark.parametrize( - ("model", "strategy", "strategy_cfg"), - [ - pytest.param(TestFSDP2Model(), "fsdp", None, id="manually_wrapped"), - pytest.param( - TestFSDP2ModelAutoWrapped(), - FSDPStrategy, - {"auto_wrap_policy": custom_auto_wrap_policy}, - id="autowrap_2x", - ), - pytest.param( - TestFSDP2ModelAutoWrapped(), - FSDPStrategy, - { - "auto_wrap_policy": ModuleWrapPolicy({nn.Linear}), - "use_orig_params": True, - }, - id="autowrap_use_orig_params", - ), - ], -) -def test_checkpoint_multi_gpus(tmp_path, model, strategy, strategy_cfg): - """Test to ensure that checkpoint is saved correctly when using multiple GPUs, and all stages can be run.""" - ck = ModelCheckpoint(save_last=True) - - strategy_cfg = strategy_cfg or {} - if not isinstance(strategy, str): - strategy = strategy(**strategy_cfg) - - trainer = Trainer( - default_root_dir=tmp_path, - accelerator="gpu", - devices=2, - strategy=strategy, - precision="16-mixed", - max_epochs=1, - limit_train_batches=2, - limit_val_batches=2, - limit_test_batches=2, - limit_predict_batches=2, - callbacks=[ck], - ) - _run_multiple_stages(trainer, model) - - -@RunIf(min_cuda_gpus=1, skip_windows=True, standalone=True) -@pytest.mark.parametrize("use_orig_params", [None, False, True]) -def test_invalid_parameters_in_optimizer(use_orig_params): - fsdp_kwargs = {} - if use_orig_params is not None: - fsdp_kwargs = {"use_orig_params": use_orig_params} - - trainer = Trainer( - strategy=FSDPStrategy(**fsdp_kwargs), - accelerator="cuda", - devices=1, - fast_dev_run=1, - ) - - class EmptyParametersModel(BoringModel): - def configure_optimizers(self): - return torch.optim.Adam(self.parameters(), lr=1e-2) - - model = EmptyParametersModel() - trainer.fit(model) - - class NoFlatParametersModel(BoringModel): - def configure_optimizers(self): - layer = torch.nn.Linear(4, 5) - return torch.optim.Adam(layer.parameters(), lr=1e-2) - - error_context = ( - nullcontext() - if use_orig_params is not False - else pytest.raises(ValueError, match="The optimizer does not seem to reference any FSDP parameters") - ) - - model = NoFlatParametersModel() - with error_context: - trainer.fit(model) - - -def test_forbidden_precision_raises(): - with pytest.raises(TypeError, match="can only work with the `FSDPPrecision"): - FSDPStrategy(precision_plugin=HalfPrecision()) - - strategy = FSDPStrategy() - with pytest.raises(TypeError, match="can only work with the `FSDPPrecision"): - strategy.precision_plugin = HalfPrecision() - - -def test_activation_checkpointing(): - """Test that the FSDP strategy can apply activation checkpointing to the given layers.""" - - class Block1(nn.Linear): - pass - - class Block2(nn.Linear): - pass - - class Model(BoringModel): - def __init__(self): - super().__init__() - self.layer0 = nn.Sequential(Block1(4, 4), Block1(5, 5)) - self.layer1 = Block2(2, 
2) - self.layer2 = nn.Linear(3, 3) - - strategy = FSDPStrategy(activation_checkpointing_policy={Block1}) - assert set(strategy._activation_checkpointing_kwargs) == {"auto_wrap_policy"} - assert isinstance(strategy._activation_checkpointing_kwargs["auto_wrap_policy"], ModuleWrapPolicy) - - strategy = FSDPStrategy(activation_checkpointing_policy=ModuleWrapPolicy({Block1, Block2})) - assert set(strategy._activation_checkpointing_kwargs) == {"auto_wrap_policy"} - assert isinstance(strategy._activation_checkpointing_kwargs["auto_wrap_policy"], ModuleWrapPolicy) - - model = Model() - strategy._parallel_devices = [torch.device("cuda", 0)] - strategy._lightning_module = model - strategy._process_group = Mock() - with ( - mock.patch("torch.distributed.fsdp.FullyShardedDataParallel", new=MagicMock), - mock.patch( - "torch.distributed.algorithms._checkpoint.checkpoint_wrapper.apply_activation_checkpointing" - ) as apply_mock, - ): - wrapped = strategy._setup_model(model) - apply_mock.assert_called_with(wrapped, checkpoint_wrapper_fn=ANY, **strategy._activation_checkpointing_kwargs) - - -def test_strategy_cpu_offload(): - """Test the different ways cpu offloading can be enabled.""" - # bool - strategy = FSDPStrategy(cpu_offload=True) - assert strategy.cpu_offload == CPUOffload(offload_params=True) - - # dataclass - config = CPUOffload() - strategy = FSDPStrategy(cpu_offload=config) - assert strategy.cpu_offload == config - - -def test_sharding_strategy(): - """Test the different ways the sharding strategy can be set.""" - from torch.distributed.fsdp import ShardingStrategy - - # default - strategy = FSDPStrategy() - assert strategy.sharding_strategy == ShardingStrategy.FULL_SHARD - - # enum - strategy = FSDPStrategy(sharding_strategy=ShardingStrategy.SHARD_GRAD_OP) - assert strategy.sharding_strategy == ShardingStrategy.SHARD_GRAD_OP - - # string - strategy = FSDPStrategy(sharding_strategy="NO_SHARD") - assert strategy.sharding_strategy == ShardingStrategy.NO_SHARD - strategy = FSDPStrategy(sharding_strategy="no_shard") - assert strategy.sharding_strategy == ShardingStrategy.NO_SHARD - - -@pytest.mark.parametrize("sharding_strategy", ["HYBRID_SHARD", "_HYBRID_SHARD_ZERO2"]) -def test_hybrid_shard_configuration(sharding_strategy, monkeypatch): - """Test that the hybrid sharding strategies can only be used with automatic wrapping or a manually specified pg.""" - with pytest.raises(RuntimeError, match="The hybrid sharding strategy requires you to pass at least one of"): - FSDPStrategy(sharding_strategy=sharding_strategy) - - strategy = FSDPStrategy(auto_wrap_policy={nn.Linear}, sharding_strategy=sharding_strategy) - assert strategy.sharding_strategy.name == sharding_strategy - - process_group = (Mock(), Mock()) - strategy = FSDPStrategy(sharding_strategy=sharding_strategy, process_group=process_group) - assert strategy.sharding_strategy.name == sharding_strategy - assert strategy.kwargs["process_group"] is process_group - - monkeypatch.setattr("lightning.pytorch.strategies.fsdp._TORCH_GREATER_EQUAL_2_2", False) - with pytest.raises(ValueError, match="`device_mesh` argument is only supported in torch >= 2.2."): - FSDPStrategy(device_mesh=Mock()) - - monkeypatch.setattr("lightning.pytorch.strategies.fsdp._TORCH_GREATER_EQUAL_2_2", True) - device_mesh = Mock() - strategy = FSDPStrategy(sharding_strategy=sharding_strategy, device_mesh=device_mesh) - assert strategy.sharding_strategy.name == sharding_strategy - assert strategy.kwargs["device_mesh"] is device_mesh - - with pytest.raises(ValueError, 
match="process_group.* device_mesh=.* are mutually exclusive"): - FSDPStrategy(sharding_strategy=sharding_strategy, process_group=process_group, device_mesh=device_mesh) - - -def test_use_orig_params(): - """Test that Lightning enables `use_orig_params` automatically.""" - strategy = FSDPStrategy() - assert strategy.kwargs["use_orig_params"] - strategy = FSDPStrategy(use_orig_params=False) - assert not strategy.kwargs["use_orig_params"] - - -@mock.patch("torch.distributed.init_process_group") -def test_set_timeout(init_process_group_mock): - """Test that the timeout gets passed to the ``torch.distributed.init_process_group`` function.""" - test_timedelta = timedelta(seconds=30) - strategy = FSDPStrategy(timeout=test_timedelta, parallel_devices=[torch.device("cpu")]) - strategy.cluster_environment = LightningEnvironment() - strategy.accelerator = Mock() - strategy.setup_environment() - process_group_backend = strategy._get_process_group_backend() - global_rank = strategy.cluster_environment.global_rank() - world_size = strategy.cluster_environment.world_size() - kwargs = {} - if _TORCH_GREATER_EQUAL_2_3: - kwargs["device_id"] = strategy.root_device if strategy.root_device.type != "cpu" else None - init_process_group_mock.assert_called_with( - process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta, **kwargs - ) - - -@mock.patch("lightning.pytorch.strategies.fsdp._load_raw_module_state") -def test_strategy_load_optimizer_states_multiple(_, tmp_path): - strategy = FSDPStrategy(parallel_devices=[torch.device("cpu")], state_dict_type="full") - trainer = Trainer() - trainer.state.fn = TrainerFn.FITTING - strategy._lightning_module = Mock(trainer=trainer) - spec = torch.optim.Optimizer - - # More states than optimizers configured - strategy.optimizers = [Mock(spec=spec)] - checkpoint = {"state_dict": {}, "optimizer_states": [{"state": {}}, {"state": {}}]} - torch.save(checkpoint, tmp_path / "two-states.ckpt") - with pytest.raises(RuntimeError, match="1 optimizers but the checkpoint contains 2 optimizers to load"): - strategy.load_checkpoint(tmp_path / "two-states.ckpt") - - # Fewer states than optimizers configured - strategy.optimizers = [Mock(spec=spec), Mock(spec=spec)] - checkpoint = {"state_dict": {}, "optimizer_states": [{"state": {}}]} - torch.save(checkpoint, tmp_path / "one-state.ckpt") - with pytest.raises(RuntimeError, match="2 optimizers but the checkpoint contains 1 optimizers to load"): - strategy.load_checkpoint(tmp_path / "one-state.ckpt") - - -@pytest.mark.filterwarnings("ignore::FutureWarning") -@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True) -@pytest.mark.parametrize("wrap_min_params", [2, 1024, 100000000]) -def test_strategy_save_optimizer_states(tmp_path, wrap_min_params): - """Test to ensure that the full state dict and optimizer states is saved when using FSDP strategy. - - Based on `wrap_min_params`, the model will be fully wrapped, half wrapped, and not wrapped at all. If the model can - be restored to DDP, it means that the optimizer states were saved correctly. 
- - """ - model = TestFSDP2ModelAutoWrapped(wrap_min_params=wrap_min_params) - - strategy = FSDPStrategy(auto_wrap_policy=partial(size_based_auto_wrap_policy, min_num_params=wrap_min_params)) - trainer = Trainer( - default_root_dir=tmp_path, - accelerator="gpu", - devices=2, - strategy=strategy, - precision="16-mixed", - max_epochs=1, - barebones=True, - ) - - trainer.fit(model) - model_path = os.path.join(tmp_path, "last.ckpt") - model_path = trainer.strategy.broadcast(model_path) - trainer.save_checkpoint(model_path) - - model_state_dict = trainer.strategy.lightning_module_state_dict() - optimizer_state_dict = trainer.strategy.optimizer_state(model.optimizers()) - - if trainer.global_rank != 0: - assert len(model_state_dict) == 0 - - if trainer.global_rank != 0: - assert len(optimizer_state_dict) == 0 - - # restore model to ddp - model = TestBoringModel() - trainer = Trainer(default_root_dir=tmp_path, accelerator="gpu", devices=2, strategy="ddp", max_epochs=1) - - # This step will restore the model and optimizer states - trainer.fit(model, ckpt_path=model_path) - - # Get the model and optimizer states from the restored ddp model - restored_model_state_dict = trainer.strategy.lightning_module_state_dict() - restored_optimizer_state_dict = trainer.strategy.optimizer_state(model.optimizers()) - - if trainer.global_rank == 0: - # assert everything is the same - assert len(model_state_dict) == len(restored_model_state_dict) - assert len(optimizer_state_dict) == len(restored_optimizer_state_dict) - - torch.testing.assert_close(model_state_dict, restored_model_state_dict, atol=0, rtol=0) - torch.testing.assert_close(optimizer_state_dict, restored_optimizer_state_dict, atol=0, rtol=0) - - trainer.strategy.barrier() - - -@pytest.mark.filterwarnings("ignore::FutureWarning") -@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True) -@pytest.mark.parametrize("wrap_min_params", [2, 1024, 100000000]) -def test_strategy_load_optimizer_states(wrap_min_params, tmp_path): - """Test to ensure that the full state dict and optimizer states can be load when using FSDP strategy. - - Based on `wrap_min_params`, the model will be fully wrapped, half wrapped, and not wrapped at all. If the DDP model - can be restored to FSDP, it means that the optimizer states were restored correctly. 
- - """ - - # restore model to ddp - model = TestBoringModel() - trainer = Trainer(default_root_dir=tmp_path, accelerator="gpu", devices=2, strategy="ddp", max_epochs=1) - - # This step will restore the model and optimizer states - trainer.fit(model) - model_path = os.path.join(tmp_path, "last.ckpt") - model_path = trainer.strategy.broadcast(model_path) - trainer.save_checkpoint(model_path) - - # Get the model and optimizer states from the restored ddp model - model_state_dict = trainer.strategy.lightning_module_state_dict() - optimizer_state_dict = trainer.strategy.optimizer_state(model.optimizers()) - - # Build a new FSDP model - model = TestFSDP2ModelAutoWrapped(wrap_min_params=wrap_min_params) - - strategy = FSDPStrategy(auto_wrap_policy=partial(size_based_auto_wrap_policy, min_num_params=wrap_min_params)) - trainer = Trainer( - default_root_dir=tmp_path, - accelerator="gpu", - devices=2, - strategy=strategy, - precision="16-mixed", - max_epochs=1, - barebones=True, - ) - - trainer.fit(model, ckpt_path=model_path) - - restored_model_state_dict = trainer.strategy.lightning_module_state_dict() - restored_optimizer_state_dict = trainer.strategy.optimizer_state(model.optimizers()) - - if trainer.global_rank != 0: - assert len(restored_model_state_dict) == 0 - - if trainer.global_rank != 0: - assert len(restored_optimizer_state_dict) == 0 - - if trainer.global_rank == 0: - # assert everything is the same - assert len(model_state_dict) == len(restored_model_state_dict) - assert len(optimizer_state_dict) == len(restored_optimizer_state_dict) - torch.testing.assert_close(model_state_dict, restored_model_state_dict, atol=0, rtol=0) - torch.testing.assert_close(optimizer_state_dict, restored_optimizer_state_dict, atol=0, rtol=0) - - trainer.strategy.barrier() - - @RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True) @pytest.mark.parametrize( ("precision", "expected_dtype"), @@ -662,17 +286,6 @@ def test_save_checkpoint_storage_options(tmp_path): strategy.save_checkpoint(filepath=tmp_path, checkpoint=Mock(), storage_options=Mock()) -def test_load_unknown_checkpoint_type(tmp_path): - """Test that the strategy validates the contents at the checkpoint path.""" - strategy = FSDP2Strategy() - strategy.model = Mock() - strategy._lightning_module = Mock() - path = tmp_path / "empty_dir" # neither a single file nor a directory with meta file - path.mkdir() - with pytest.raises(ValueError, match="does not point to a valid checkpoint"): - strategy.load_checkpoint(checkpoint_path=path) - - class TestFSDP2CheckpointModel(BoringModel): def __init__(self, params_to_compare=None): super().__init__() From b3ce371027fc2bcf673391c5bf5c43b2eb537273 Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Thu, 11 Sep 2025 10:28:27 +0530 Subject: [PATCH 10/17] could it be --- src/lightning/fabric/utilities/imports.py | 1 + src/lightning/pytorch/strategies/fsdp2.py | 18 +++++++++++---- tests/tests_pytorch/strategies/test_fsdp2.py | 23 +++++++++++++++----- 3 files changed, 33 insertions(+), 9 deletions(-) diff --git a/src/lightning/fabric/utilities/imports.py b/src/lightning/fabric/utilities/imports.py index 70239baac0e6d..5655f2674638e 100644 --- a/src/lightning/fabric/utilities/imports.py +++ b/src/lightning/fabric/utilities/imports.py @@ -35,6 +35,7 @@ _TORCH_GREATER_EQUAL_2_4 = compare_version("torch", operator.ge, "2.4.0") _TORCH_GREATER_EQUAL_2_4_1 = compare_version("torch", operator.ge, "2.4.1") _TORCH_GREATER_EQUAL_2_5 = compare_version("torch", operator.ge, "2.5.0") +_TORCH_GREATER_EQUAL_2_6 = 
compare_version("torch", operator.ge, "2.6.0") _TORCH_LESS_EQUAL_2_6 = compare_version("torch", operator.le, "2.6.0") _TORCHMETRICS_GREATER_EQUAL_1_0_0 = compare_version("torchmetrics", operator.ge, "1.0.0") _PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10) diff --git a/src/lightning/pytorch/strategies/fsdp2.py b/src/lightning/pytorch/strategies/fsdp2.py index db9dfb8d12f65..97bd6f975175f 100644 --- a/src/lightning/pytorch/strategies/fsdp2.py +++ b/src/lightning/pytorch/strategies/fsdp2.py @@ -48,7 +48,7 @@ _sync_ddp_if_available, ) from lightning.fabric.utilities.distributed import group as _group -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3 +from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_6 from lightning.fabric.utilities.init import _has_all_dtensor_params_or_buffers, _has_meta_device_parameters_or_buffers from lightning.fabric.utilities.optimizer import _optimizers_to_device from lightning.fabric.utilities.seed import reset_seed @@ -66,9 +66,9 @@ from torch.distributed.device_mesh import DeviceMesh from torch.distributed.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy, OffloadPolicy -try: +if _TORCH_GREATER_EQUAL_2_6: from torch.distributed.checkpoint.stateful import Stateful as _TorchStateful -except ImportError: +else: class _TorchStateful: # type: ignore[no-redef] pass @@ -131,6 +131,11 @@ def __init__( mp_policy: Optional["MixedPrecisionPolicy"] = None, **kwargs: Any, ) -> None: + if not _TORCH_GREATER_EQUAL_2_6: + raise ModuleNotFoundError( + "FSDP2Strategy requires torch>=2.6.0. " + f"Found torch {torch.__version__}. Please upgrade torch to use FSDP2Strategy." + ) super().__init__( accelerator=accelerator, parallel_devices=parallel_devices, @@ -206,7 +211,7 @@ def setup_environment(self) -> None: self._process_group_backend = self._get_process_group_backend() assert self.cluster_environment is not None kwargs: dict[str, Any] = {"timeout": self._timeout} - if _TORCH_GREATER_EQUAL_2_3: + if _TORCH_GREATER_EQUAL_2_6: kwargs["device_id"] = self.root_device if self.root_device.type != "cpu" else None _init_dist_connection(self.cluster_environment, self._process_group_backend, **kwargs) @@ -551,6 +556,11 @@ class AppState(_TorchStateful): """ def __init__(self, model: Module, optimizers: list[Optimizer]) -> None: + if not _TORCH_GREATER_EQUAL_2_6: + raise ModuleNotFoundError( + "AppState requires torch>=2.6.0. " + f"Found torch {torch.__version__}. Please upgrade torch to use AppState." 
+ ) self.model = model self.optimizers = optimizers diff --git a/tests/tests_pytorch/strategies/test_fsdp2.py b/tests/tests_pytorch/strategies/test_fsdp2.py index 01bdfb4ef8197..65ea8dcadc044 100644 --- a/tests/tests_pytorch/strategies/test_fsdp2.py +++ b/tests/tests_pytorch/strategies/test_fsdp2.py @@ -132,6 +132,7 @@ def _assert_save_equality(trainer, ckpt_path, cls=TestFSDP2Model): assert torch.equal(ddp_param, shard_param) +@RunIf(min_torch="2.6.0") @pytest.mark.parametrize("strategy", ["fsdp2", "fsdp2_cpu_offload"]) def test_invalid_on_cpu(tmp_path, cuda_count_0, strategy): """Test to ensure that we raise Misconfiguration for FSDP on CPU.""" @@ -141,6 +142,7 @@ def test_invalid_on_cpu(tmp_path, cuda_count_0, strategy): trainer.strategy.setup_environment() +@RunIf(min_torch="2.6.0") def test_custom_mixed_precision(): """Test to ensure that passing a custom mixed precision config works.""" from torch.distributed.fsdp import MixedPrecisionPolicy @@ -168,6 +170,7 @@ class InvalidMPPolicy: FSDP2Strategy(mp_policy=InvalidMPPolicy()) +@RunIf(min_torch="2.6.0") @pytest.mark.filterwarnings("ignore::FutureWarning") @RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True) def test_strategy_sync_batchnorm(tmp_path): @@ -185,6 +188,7 @@ def test_strategy_sync_batchnorm(tmp_path): _run_multiple_stages(trainer, model, os.path.join(tmp_path, "last.ckpt")) +@RunIf(min_torch="2.6.0") @pytest.mark.filterwarnings("ignore::FutureWarning") @RunIf(min_cuda_gpus=1, skip_windows=True) def test_modules_without_parameters(tmp_path): @@ -217,7 +221,7 @@ def training_step(self, batch, batch_idx): @pytest.mark.filterwarnings("ignore::FutureWarning") -@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True) +@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="2.6.0") @pytest.mark.parametrize("precision", ["16-mixed", pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True))]) def test_strategy_checkpoint(state_dict_type, precision, tmp_path): """Test to ensure that checkpoint is saved correctly when using a single GPU, and all stages can be run.""" @@ -237,7 +241,7 @@ def custom_auto_wrap_policy( return nonwrapped_numel >= 2 -@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True) +@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="2.6.0") @pytest.mark.parametrize( ("precision", "expected_dtype"), [ @@ -279,6 +283,7 @@ def on_fit_start(self): trainer.fit(model) +@RunIf(min_torch="2.6.0") def test_save_checkpoint_storage_options(tmp_path): """Test that the FSDP strategy does not accept storage options for saving checkpoints.""" strategy = FSDP2Strategy() @@ -304,7 +309,7 @@ def on_train_start(self): @pytest.mark.filterwarnings("ignore::FutureWarning") -@RunIf(min_cuda_gpus=2, standalone=True) +@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.6.0") def test_save_load_sharded_state_dict(tmp_path): """Test FSDP saving and loading with the sharded state dict format.""" strategy = FSDP2Strategy() @@ -341,7 +346,7 @@ def test_save_load_sharded_state_dict(tmp_path): trainer.fit(model, ckpt_path=checkpoint_path) -@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True) +@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="2.6.0") @pytest.mark.parametrize( ("precision", "expected_dtype"), [ @@ -391,7 +396,7 @@ def _run_setup_assertions(empty_init, expected_device): @pytest.mark.filterwarnings("ignore::FutureWarning") -@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.3.0") +@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.6.0") def 
test_save_sharded_and_consolidate_and_load(tmp_path): """Test the consolidation of a FSDP2-sharded checkpoint into a single file.""" @@ -433,3 +438,11 @@ def configure_optimizers(self): max_steps=4, ) trainer.fit(model, ckpt_path=checkpoint_path_full) + + +@RunIf(max_torch="2.5") +@pytest.mark.parametrize("strategy", ["fsdp2", "fsdp2_cpu_offload"]) +def test_fsdp2_requires_torch_2_6_or_newer(tmp_path, strategy): + """FSDP2 strategies should error on torch < 2.6.""" + with pytest.raises(ValueError, match="FSDP2Strategy requires torch>=2.6.0."): + Trainer(accelerator="cpu", default_root_dir=tmp_path, fast_dev_run=True, strategy=strategy) From 224a12573feccc265992694afbd7b9ca8482da82 Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Thu, 11 Sep 2025 10:35:58 +0530 Subject: [PATCH 11/17] update --- src/lightning/pytorch/strategies/fsdp2.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/lightning/pytorch/strategies/fsdp2.py b/src/lightning/pytorch/strategies/fsdp2.py index 97bd6f975175f..9302e6265e9b8 100644 --- a/src/lightning/pytorch/strategies/fsdp2.py +++ b/src/lightning/pytorch/strategies/fsdp2.py @@ -67,7 +67,13 @@ from torch.distributed.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy, OffloadPolicy if _TORCH_GREATER_EQUAL_2_6: - from torch.distributed.checkpoint.stateful import Stateful as _TorchStateful + try: + from torch.distributed.checkpoint.stateful import Stateful as _TorchStateful + except ImportError: + + class _TorchStateful: # type: ignore[no-redef] + pass + else: class _TorchStateful: # type: ignore[no-redef] From 6b05701c8dd4ff98d3dec00f0e474e6a8e3ec9d5 Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Thu, 11 Sep 2025 10:52:49 +0530 Subject: [PATCH 12/17] meow --- tests/tests_pytorch/strategies/test_fsdp2.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/tests_pytorch/strategies/test_fsdp2.py b/tests/tests_pytorch/strategies/test_fsdp2.py index 65ea8dcadc044..7330c2f1a0a8e 100644 --- a/tests/tests_pytorch/strategies/test_fsdp2.py +++ b/tests/tests_pytorch/strategies/test_fsdp2.py @@ -440,9 +440,9 @@ def configure_optimizers(self): trainer.fit(model, ckpt_path=checkpoint_path_full) -@RunIf(max_torch="2.5") +@RunIf(max_torch="2.5", min_cuda_gpus=1) @pytest.mark.parametrize("strategy", ["fsdp2", "fsdp2_cpu_offload"]) def test_fsdp2_requires_torch_2_6_or_newer(tmp_path, strategy): """FSDP2 strategies should error on torch < 2.6.""" with pytest.raises(ValueError, match="FSDP2Strategy requires torch>=2.6.0."): - Trainer(accelerator="cpu", default_root_dir=tmp_path, fast_dev_run=True, strategy=strategy) + Trainer(default_root_dir=tmp_path, fast_dev_run=True, strategy=strategy) From 3e76d9e63f9426fb0ca096966042a96e9e142e14 Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Thu, 11 Sep 2025 13:06:41 +0530 Subject: [PATCH 13/17] update --- tests/tests_pytorch/strategies/test_fsdp2.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/tests_pytorch/strategies/test_fsdp2.py b/tests/tests_pytorch/strategies/test_fsdp2.py index 7330c2f1a0a8e..4ae4d186e7717 100644 --- a/tests/tests_pytorch/strategies/test_fsdp2.py +++ b/tests/tests_pytorch/strategies/test_fsdp2.py @@ -170,9 +170,8 @@ class InvalidMPPolicy: FSDP2Strategy(mp_policy=InvalidMPPolicy()) -@RunIf(min_torch="2.6.0") @pytest.mark.filterwarnings("ignore::FutureWarning") -@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True) +@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="2.6.0") def 
test_strategy_sync_batchnorm(tmp_path): """Test to ensure that sync_batchnorm works when using FSDP and GPU, and all stages can be run.""" model = TestFSDP2Model() @@ -188,9 +187,8 @@ def test_strategy_sync_batchnorm(tmp_path): _run_multiple_stages(trainer, model, os.path.join(tmp_path, "last.ckpt")) -@RunIf(min_torch="2.6.0") @pytest.mark.filterwarnings("ignore::FutureWarning") -@RunIf(min_cuda_gpus=1, skip_windows=True) +@RunIf(min_cuda_gpus=1, skip_windows=True, min_torch="2.6.0") def test_modules_without_parameters(tmp_path): """Test that TorchMetrics get moved to the device despite not having any parameters.""" From e4002157c30d65db8c556b907856738f2a704700 Mon Sep 17 00:00:00 2001 From: Deependu Date: Thu, 11 Sep 2025 13:18:36 +0530 Subject: [PATCH 14/17] Update src/lightning/fabric/utilities/init.py Co-authored-by: Nicki Skafte Detlefsen --- src/lightning/fabric/utilities/init.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/lightning/fabric/utilities/init.py b/src/lightning/fabric/utilities/init.py index 6243c91a7a402..023ae43c4d7bb 100644 --- a/src/lightning/fabric/utilities/init.py +++ b/src/lightning/fabric/utilities/init.py @@ -116,6 +116,11 @@ def _has_meta_device_parameters_or_buffers(obj: Union[Module, Optimizer], recurs def _has_all_dtensor_params_or_buffers(obj: Union[Module, Optimizer], recurse: bool = True) -> bool: + """ + Check whether all parameters and buffers of a given + :class:`torch.nn.Module` or :class:`torch.optim.Optimizer` are instances of + :class:`torch.distributed.tensor.DTensor`. + """ from torch.distributed.tensor import DTensor if isinstance(obj, Optimizer): From a84b9b12548e3a7b0e43a85d3dd3fdefe4f2d491 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 11 Sep 2025 07:49:16 +0000 Subject: [PATCH 15/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/lightning/fabric/utilities/init.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/lightning/fabric/utilities/init.py b/src/lightning/fabric/utilities/init.py index 023ae43c4d7bb..ffd9b7646d62d 100644 --- a/src/lightning/fabric/utilities/init.py +++ b/src/lightning/fabric/utilities/init.py @@ -116,11 +116,8 @@ def _has_meta_device_parameters_or_buffers(obj: Union[Module, Optimizer], recurs def _has_all_dtensor_params_or_buffers(obj: Union[Module, Optimizer], recurse: bool = True) -> bool: - """ - Check whether all parameters and buffers of a given - :class:`torch.nn.Module` or :class:`torch.optim.Optimizer` are instances of - :class:`torch.distributed.tensor.DTensor`. 
- """ + """Check whether all parameters and buffers of a given :class:`torch.nn.Module` or :class:`torch.optim.Optimizer` + are instances of :class:`torch.distributed.tensor.DTensor`.""" from torch.distributed.tensor import DTensor if isinstance(obj, Optimizer): From cf6bbf15f70682f431f515a1b1ab42c9168b89b4 Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Thu, 11 Sep 2025 13:21:46 +0530 Subject: [PATCH 16/17] update --- src/lightning/pytorch/strategies/fsdp2.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/lightning/pytorch/strategies/fsdp2.py b/src/lightning/pytorch/strategies/fsdp2.py index 9302e6265e9b8..507d82db992a7 100644 --- a/src/lightning/pytorch/strategies/fsdp2.py +++ b/src/lightning/pytorch/strategies/fsdp2.py @@ -102,14 +102,12 @@ class FSDP2Strategy(ParallelStrategy): https://github.com/pytorch/pytorch/issues/114299 Arguments: + mp_policy: A ``MixedPrecisionPolicy`` object that specifies the precision policy for + model parameters and gradients when using mixed precision training with FSDP2. + cpu_offload: A ``CPUOffloadPolicy`` or boolean that specifies whether to offload + model parameters and gradients to CPU memory. If ``True``, offloading is enabled with default settings. device_mesh: A :class:`torch.distributed.device_mesh.DeviceMesh` object that specifies how devices are arranged and how tensors should be sharded/replicated. - parallelize_module: Optional policy function or mapping that specifies how to wrap or - distribute submodules of the model using ``DTensor``. - checkpoint_policy: Defines how checkpoint saving/loading is performed with DTensor-based - modules. See ``torch.distributed.checkpoint`` for available options. - mixed_precision: Optional policy for mixed precision training. Can be used to specify - precision for parameters, gradients, and buffers. \**kwargs: Additional keyword arguments passed to the underlying FSDP2 APIs. .. note:: @@ -125,7 +123,6 @@ class FSDP2Strategy(ParallelStrategy): def __init__( self, - device_mesh: Optional[Union[tuple[int], "DeviceMesh"]] = None, accelerator: Optional["pl.accelerators.Accelerator"] = None, parallel_devices: Optional[list[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, @@ -135,6 +132,7 @@ def __init__( timeout: Optional[timedelta] = default_pg_timeout, cpu_offload: Union[bool, "CPUOffloadPolicy", None] = None, mp_policy: Optional["MixedPrecisionPolicy"] = None, + device_mesh: Optional[Union[tuple[int], "DeviceMesh"]] = None, **kwargs: Any, ) -> None: if not _TORCH_GREATER_EQUAL_2_6: From 029ebffcd24a228f1efc45dd2e576357c6464578 Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Thu, 11 Sep 2025 13:50:30 +0530 Subject: [PATCH 17/17] nitpick. 
and pause fsdp2 dev for now --- tests/tests_pytorch/strategies/test_fsdp2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tests_pytorch/strategies/test_fsdp2.py b/tests/tests_pytorch/strategies/test_fsdp2.py index 4ae4d186e7717..38509feb3f258 100644 --- a/tests/tests_pytorch/strategies/test_fsdp2.py +++ b/tests/tests_pytorch/strategies/test_fsdp2.py @@ -442,5 +442,5 @@ def configure_optimizers(self): @pytest.mark.parametrize("strategy", ["fsdp2", "fsdp2_cpu_offload"]) def test_fsdp2_requires_torch_2_6_or_newer(tmp_path, strategy): """FSDP2 strategies should error on torch < 2.6.""" - with pytest.raises(ValueError, match="FSDP2Strategy requires torch>=2.6.0."): + with pytest.raises(ModuleNotFoundError, match="FSDP2Strategy requires torch>=2.6.0."): Trainer(default_root_dir=tmp_path, fast_dev_run=True, strategy=strategy)
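
Taken together, patches 10 through 17 converge on a single torch-version gate for FSDP2: a `_TORCH_GREATER_EQUAL_2_6` flag in `lightning/fabric/utilities/imports.py`, a guarded import of `torch.distributed.checkpoint.stateful.Stateful` with a no-op fallback class, and an eager `ModuleNotFoundError` raised from the strategy constructor, which the final test now matches. The following is a condensed, self-contained sketch of that pattern, not the patch itself; `_VersionGatedStrategy` is an illustrative stand-in for `FSDP2Strategy.__init__`, and `compare_version` is assumed to come from `lightning_utilities`, as the touched imports module already uses it.

    # Sketch of the torch>=2.6 gating pattern used by the series (illustrative only).
    import operator

    import torch
    from lightning_utilities.core.imports import compare_version

    # Version flag, mirroring the new entry in lightning/fabric/utilities/imports.py.
    _TORCH_GREATER_EQUAL_2_6 = compare_version("torch", operator.ge, "2.6.0")

    if _TORCH_GREATER_EQUAL_2_6:
        try:
            # Real Stateful protocol used by torch.distributed.checkpoint on new torch.
            from torch.distributed.checkpoint.stateful import Stateful as _TorchStateful
        except ImportError:
            # Defensive fallback so module import never fails (patch 11).
            class _TorchStateful:  # type: ignore[no-redef]
                pass
    else:
        # Older torch: keep a no-op placeholder so class definitions still parse.
        class _TorchStateful:  # type: ignore[no-redef]
            pass


    class _VersionGatedStrategy:  # illustrative stand-in for FSDP2Strategy
        def __init__(self) -> None:
            # Fail fast on unsupported torch; the final test asserts this exact error type.
            if not _TORCH_GREATER_EQUAL_2_6:
                raise ModuleNotFoundError(
                    "FSDP2Strategy requires torch>=2.6.0. "
                    f"Found torch {torch.__version__}. Please upgrade torch to use FSDP2Strategy."
                )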
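Patch 16 also moves `device_mesh` out of the first positional slot of `FSDP2Strategy.__init__` and documents `mp_policy` and `cpu_offload`. A minimal usage sketch under those documented arguments could look like the block below; the dtype choices, the `(2,)` mesh shape, and the Trainer flags are illustrative assumptions rather than part of the series, which pauses FSDP2 development in patch 17.

    # Minimal usage sketch based on the documented constructor arguments (assumptions noted above).
    import torch
    from torch.distributed.fsdp import MixedPrecisionPolicy

    from lightning.pytorch import Trainer
    from lightning.pytorch.strategies.fsdp2 import FSDP2Strategy

    strategy = FSDP2Strategy(
        mp_policy=MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32),
        cpu_offload=False,    # or True / a CPUOffloadPolicy instance
        device_mesh=(2,),     # optional: a tuple of ints or a prebuilt DeviceMesh
    )
    trainer = Trainer(accelerator="gpu", devices=2, strategy=strategy)

The string aliases exercised by the tests ("fsdp2", "fsdp2_cpu_offload") remain available as well, and on torch < 2.6 constructing either is expected to raise the `ModuleNotFoundError` shown in the final hunk.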