diff --git a/src/lightning/fabric/strategies/fsdp.py b/src/lightning/fabric/strategies/fsdp.py index baaee74af0ec9..3809976548849 100644 --- a/src/lightning/fabric/strategies/fsdp.py +++ b/src/lightning/fabric/strategies/fsdp.py @@ -795,6 +795,12 @@ def _optimizer_has_flat_params(optimizer: Optimizer) -> bool: ) +def _optimizer_has_dtensor_params(optimizer: Optimizer) -> bool: + from torch.distributed.tensor import DTensor + + return any(isinstance(param, DTensor) for group in optimizer.param_groups for param in group["params"]) + + def _get_sharded_state_dict_context(module: Module) -> Generator[None, None, None]: from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp.api import ShardedOptimStateDictConfig, ShardedStateDictConfig, StateDictType diff --git a/src/lightning/fabric/utilities/imports.py b/src/lightning/fabric/utilities/imports.py index 70239baac0e6d..5655f2674638e 100644 --- a/src/lightning/fabric/utilities/imports.py +++ b/src/lightning/fabric/utilities/imports.py @@ -35,6 +35,7 @@ _TORCH_GREATER_EQUAL_2_4 = compare_version("torch", operator.ge, "2.4.0") _TORCH_GREATER_EQUAL_2_4_1 = compare_version("torch", operator.ge, "2.4.1") _TORCH_GREATER_EQUAL_2_5 = compare_version("torch", operator.ge, "2.5.0") +_TORCH_GREATER_EQUAL_2_6 = compare_version("torch", operator.ge, "2.6.0") _TORCH_LESS_EQUAL_2_6 = compare_version("torch", operator.le, "2.6.0") _TORCHMETRICS_GREATER_EQUAL_1_0_0 = compare_version("torchmetrics", operator.ge, "1.0.0") _PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10) diff --git a/src/lightning/fabric/utilities/init.py b/src/lightning/fabric/utilities/init.py index 4f8519eec9610..ffd9b7646d62d 100644 --- a/src/lightning/fabric/utilities/init.py +++ b/src/lightning/fabric/utilities/init.py @@ -113,3 +113,23 @@ def _has_meta_device_parameters_or_buffers(obj: Union[Module, Optimizer], recurs if isinstance(obj, Module): return any(t.is_meta for t in itertools.chain(obj.parameters(recurse=recurse), obj.buffers(recurse=recurse))) raise TypeError(f"Expected `torch.nn.Module` or `torch.optim.Optimizer`, got: {type(obj).__name__}") + + +def _has_all_dtensor_params_or_buffers(obj: Union[Module, Optimizer], recurse: bool = True) -> bool: + """Check whether all parameters and buffers of a given :class:`torch.nn.Module` or :class:`torch.optim.Optimizer` + are instances of :class:`torch.distributed.tensor.DTensor`.""" + from torch.distributed.tensor import DTensor + + if isinstance(obj, Optimizer): + return all( + isinstance(t, DTensor) + for param_group in obj.param_groups + for t in param_group["params"] + if isinstance(t, Parameter) + ) + if isinstance(obj, Module): + return all( + isinstance(t, DTensor) + for t in itertools.chain(obj.parameters(recurse=recurse), obj.buffers(recurse=recurse)) + ) + raise TypeError(f"Expected `torch.nn.Module` or `torch.optim.Optimizer`, got: {type(obj).__name__}") diff --git a/src/lightning/fabric/utilities/registry.py b/src/lightning/fabric/utilities/registry.py index 7d8f6ca17712e..03c30972c98b0 100644 --- a/src/lightning/fabric/utilities/registry.py +++ b/src/lightning/fabric/utilities/registry.py @@ -36,9 +36,7 @@ def _load_external_callbacks(group: str) -> list[Any]: A list of all callbacks collected from external factories. 
""" - factories = ( - entry_points(group=group) if _PYTHON_GREATER_EQUAL_3_10_0 else entry_points().get(group, {}) # type: ignore[arg-type] - ) + factories = entry_points(group=group) if _PYTHON_GREATER_EQUAL_3_10_0 else entry_points().get(group, {}) # type: ignore[arg-type] external_callbacks: list[Any] = [] for factory in factories: diff --git a/src/lightning/pytorch/plugins/__init__.py b/src/lightning/pytorch/plugins/__init__.py index d4fd63807c78d..8ec85a2b5fc67 100644 --- a/src/lightning/pytorch/plugins/__init__.py +++ b/src/lightning/pytorch/plugins/__init__.py @@ -8,6 +8,7 @@ from lightning.pytorch.plugins.precision.deepspeed import DeepSpeedPrecision from lightning.pytorch.plugins.precision.double import DoublePrecision from lightning.pytorch.plugins.precision.fsdp import FSDPPrecision +from lightning.pytorch.plugins.precision.fsdp2 import FSDP2Precision from lightning.pytorch.plugins.precision.half import HalfPrecision from lightning.pytorch.plugins.precision.precision import Precision from lightning.pytorch.plugins.precision.transformer_engine import TransformerEnginePrecision @@ -28,6 +29,7 @@ "Precision", "TransformerEnginePrecision", "FSDPPrecision", + "FSDP2Precision", "XLAPrecision", "LayerSync", "TorchSyncBatchNorm", diff --git a/src/lightning/pytorch/plugins/precision/fsdp2.py b/src/lightning/pytorch/plugins/precision/fsdp2.py new file mode 100644 index 0000000000000..2bda780a97282 --- /dev/null +++ b/src/lightning/pytorch/plugins/precision/fsdp2.py @@ -0,0 +1,110 @@ +# Copyright The Lightning AI team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from contextlib import AbstractContextManager +from typing import Any + +import torch +from lightning_utilities import apply_to_collection +from torch import Tensor +from torch.nn import Module +from typing_extensions import get_args, override + +from lightning.fabric.plugins.precision.fsdp import _PRECISION_INPUT +from lightning.fabric.plugins.precision.utils import _convert_fp_tensor, _DtypeContextManager +from lightning.pytorch.plugins.precision.precision import Precision +from lightning.pytorch.utilities.exceptions import MisconfigurationException + + +class FSDP2Precision(Precision): + """Precision plugin for training with FSDP2 (Fully Sharded Data Parallel v2). + + .. warning:: This is an :ref:`experimental ` feature. + + Args: + precision: Full precision (32-true), half precision (16-true, bf16-true) or + mixed precision (16-mixed, bf16-mixed). + scaler: An optional :class:`torch.distributed.fsdp.sharded_grad_scaler.ShardedGradScaler` to use. + + Raises: + ValueError: + If unsupported ``precision`` is provided. + + """ + + def __init__(self, precision: _PRECISION_INPUT, scaler: Any = None) -> None: + supported_precision = get_args(_PRECISION_INPUT) + if precision not in supported_precision: + raise ValueError( + f"`precision={precision!r})` is not supported in FSDP." + f" `precision` must be one of: {supported_precision}." 
+ ) + + if scaler is not None: + raise ValueError( + f"`scaler` is not supported in `{self.__class__.__name__}`, found {scaler}." + "Use `mixed-precision policy` instead to configure the scaler." + ) + + if "mixed" in precision: + raise ValueError( + f"`precision={precision!r}` is not supported in `{self.__class__.__name__}`." + "Only `true` precision is supported." + "Use `mixed-precision policy (mp_policy)` instead to configure mixed precision." + ) + + self.precision = precision + + precision_to_type = { + "bf16-true": torch.bfloat16, + "16-true": torch.float16, + "32-true": torch.float32, + } + self._desired_input_dtype = precision_to_type[self.precision] + + @override + def convert_module(self, module: Module) -> Module: + if "true" in self.precision: + return module.to(dtype=self._desired_input_dtype) + return module + + @override + def clip_grad_by_norm(self, *_: Any, **__: Any) -> None: + # see https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.clip_grad_norm_ + # section `Gradient Clipping`, using `torch.nn.utils.clip_grad_norm_` is incorrect with FSDP. + # To overcome this we need to call root_sharded_module.clip_grad_norm(clip_val), but we don't have a reference + # to the root module + raise MisconfigurationException( + f"`gradient_clip_algorithm='norm'` is currently not supported for `{self.__class__.__name__}`" + ) + + @override + def tensor_init_context(self) -> AbstractContextManager: + return _DtypeContextManager(self._desired_input_dtype) + + @override + def module_init_context(self) -> AbstractContextManager: + # Use float32 for module parameter initialization to ensure numerical stability + return _DtypeContextManager(self._desired_input_dtype) + + @override + def forward_context(self) -> AbstractContextManager: + return _DtypeContextManager(self._desired_input_dtype) + + @override + def convert_input(self, data: Any) -> Any: + return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self._desired_input_dtype) + + @override + def convert_output(self, data: Any) -> Any: + return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype()) diff --git a/src/lightning/pytorch/strategies/__init__.py b/src/lightning/pytorch/strategies/__init__.py index 9c2b2a6a3a621..33b28b6ae59c2 100644 --- a/src/lightning/pytorch/strategies/__init__.py +++ b/src/lightning/pytorch/strategies/__init__.py @@ -18,6 +18,7 @@ from lightning.pytorch.strategies.ddp import DDPStrategy from lightning.pytorch.strategies.deepspeed import DeepSpeedStrategy from lightning.pytorch.strategies.fsdp import FSDPStrategy +from lightning.pytorch.strategies.fsdp2 import FSDP2Strategy from lightning.pytorch.strategies.model_parallel import ModelParallelStrategy from lightning.pytorch.strategies.parallel import ParallelStrategy from lightning.pytorch.strategies.single_device import SingleDeviceStrategy @@ -32,6 +33,7 @@ "DDPStrategy", "DeepSpeedStrategy", "FSDPStrategy", + "FSDP2Strategy", "ModelParallelStrategy", "ParallelStrategy", "SingleDeviceStrategy", diff --git a/src/lightning/pytorch/strategies/fsdp2.py b/src/lightning/pytorch/strategies/fsdp2.py new file mode 100644 index 0000000000000..507d82db992a7 --- /dev/null +++ b/src/lightning/pytorch/strategies/fsdp2.py @@ -0,0 +1,585 @@ +# Copyright The Lightning AI team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from collections.abc import Generator, Mapping +from contextlib import contextmanager, nullcontext +from datetime import timedelta +from pathlib import Path +from typing import ( + TYPE_CHECKING, + Any, + Optional, + Union, +) + +import torch +from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only +from torch import Tensor +from torch.nn import Module +from torch.optim import Optimizer +from typing_extensions import override + +import lightning.pytorch as pl +from lightning.fabric.plugins import CheckpointIO, ClusterEnvironment +from lightning.fabric.plugins.collectives.torch_collective import default_pg_timeout +from lightning.fabric.strategies import _StrategyRegistry +from lightning.fabric.strategies.fsdp import ( + _METADATA_FILENAME, + _distributed_checkpoint_load, + _distributed_checkpoint_save, + _move_torchmetrics_to_device, + _optimizer_has_dtensor_params, +) +from lightning.fabric.utilities.distributed import ( + _distributed_is_initialized, + _get_default_process_group_backend_for_device, + _init_dist_connection, + _sync_ddp_if_available, +) +from lightning.fabric.utilities.distributed import group as _group +from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_6 +from lightning.fabric.utilities.init import _has_all_dtensor_params_or_buffers, _has_meta_device_parameters_or_buffers +from lightning.fabric.utilities.optimizer import _optimizers_to_device +from lightning.fabric.utilities.seed import reset_seed +from lightning.fabric.utilities.types import _PATH, ReduceOp +from lightning.pytorch.plugins.precision import Precision +from lightning.pytorch.plugins.precision.fsdp2 import FSDP2Precision +from lightning.pytorch.strategies.launchers.subprocess_script import _SubprocessScriptLauncher +from lightning.pytorch.strategies.parallel import ParallelStrategy +from lightning.pytorch.strategies.strategy import TBroadcast +from lightning.pytorch.trainer.states import TrainerFn +from lightning.pytorch.utilities.model_helpers import is_overridden +from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_only, rank_zero_warn + +if TYPE_CHECKING: + from torch.distributed.device_mesh import DeviceMesh + from torch.distributed.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy, OffloadPolicy + +if _TORCH_GREATER_EQUAL_2_6: + try: + from torch.distributed.checkpoint.stateful import Stateful as _TorchStateful + except ImportError: + + class _TorchStateful: # type: ignore[no-redef] + pass + +else: + + class _TorchStateful: # type: ignore[no-redef] + pass + + +log = logging.getLogger(__name__) + + +class FSDP2Strategy(ParallelStrategy): + r"""Strategy for Fully Sharded Data Parallel v2 (FSDP2) provided by ``torch.distributed``. + + FSDP2 is the next-generation implementation of Fully Sharded Data Parallel, built on top of the + ``DeviceMesh`` and ``DTensor`` abstractions. It provides a more robust and extensible way of + scaling models across devices, addressing many of the limitations and inconsistencies of the + original FSDP (referred to here as FSDP1). 
+ + Compared to FSDP1, FSDP2 offers: + - Deterministic and composable sharding plans via ``DeviceMesh`` + - A unified tensor abstraction (``DTensor``) that enables interoperability between FSDP, + tensor parallelism, and pipeline parallelism + - Cleaner checkpointing semantics, reducing many of the loading/saving issues seen in FSDP1 + - Forward compatibility, as PyTorch maintainers are actively deprecating FSDP1 in favor of FSDP2 + + For background, see the RFC: + https://github.com/pytorch/pytorch/issues/114299 + + Arguments: + mp_policy: A ``MixedPrecisionPolicy`` object that specifies the precision policy for + model parameters and gradients when using mixed precision training with FSDP2. + cpu_offload: A ``CPUOffloadPolicy`` or boolean that specifies whether to offload + model parameters and gradients to CPU memory. If ``True``, offloading is enabled with default settings. + device_mesh: A :class:`torch.distributed.device_mesh.DeviceMesh` object that specifies + how devices are arranged and how tensors should be sharded/replicated. + \**kwargs: Additional keyword arguments passed to the underlying FSDP2 APIs. + + .. note:: + FSDP2 is still marked as "not fully stable" in PyTorch, but it is the recommended path + forward. FSDP1 will eventually be deprecated. Users are encouraged to migrate to FSDP2 + for new projects, but should test thoroughly before deploying in production-critical + environments. + + """ + + strategy_name = "fsdp2" + _registered_strategies: list[str] = [] + + def __init__( + self, + accelerator: Optional["pl.accelerators.Accelerator"] = None, + parallel_devices: Optional[list[torch.device]] = None, + cluster_environment: Optional[ClusterEnvironment] = None, + checkpoint_io: Optional[CheckpointIO] = None, + precision_plugin: Optional[Precision] = None, + process_group_backend: Optional[str] = None, + timeout: Optional[timedelta] = default_pg_timeout, + cpu_offload: Union[bool, "CPUOffloadPolicy", None] = None, + mp_policy: Optional["MixedPrecisionPolicy"] = None, + device_mesh: Optional[Union[tuple[int], "DeviceMesh"]] = None, + **kwargs: Any, + ) -> None: + if not _TORCH_GREATER_EQUAL_2_6: + raise ModuleNotFoundError( + "FSDP2Strategy requires torch>=2.6.0. " + f"Found torch {torch.__version__}. Please upgrade torch to use FSDP2Strategy." 
+ ) + super().__init__( + accelerator=accelerator, + parallel_devices=parallel_devices, + cluster_environment=cluster_environment, + checkpoint_io=checkpoint_io, + precision_plugin=precision_plugin, + ) + self.num_nodes = 1 + self._process_group_backend = process_group_backend + self._timeout: Optional[timedelta] = timeout + self.cpu_offload = _init_fsdp2_cpu_offload(cpu_offload) + self.mp_policy = _init_fsdp2_mp_policy(mp_policy) + + self.device_mesh = device_mesh + self.kwargs = kwargs + + @property + @override + def root_device(self) -> torch.device: + assert self.parallel_devices is not None + return self.parallel_devices[self.local_rank] + + @property + def num_processes(self) -> int: + return len(self.parallel_devices) if self.parallel_devices is not None else 0 + + @property + def process_group_backend(self) -> Optional[str]: + return self._process_group_backend + + @property + @override + def precision_plugin(self) -> FSDP2Precision: + plugin = self._precision_plugin + if plugin is not None: + assert isinstance(plugin, FSDP2Precision) + return plugin + return FSDP2Precision("32-true") + + @precision_plugin.setter + @override + def precision_plugin(self, precision_plugin: Optional[Precision]) -> None: + if precision_plugin is not None and not isinstance(precision_plugin, FSDP2Precision): + raise TypeError( + f"The FSDP2 strategy can only work with the `FSDP2Precision` plugin, found {precision_plugin}" + ) + self._precision_plugin = precision_plugin + + @property + @override + def distributed_sampler_kwargs(self) -> dict: + return {"num_replicas": (self.num_nodes * self.num_processes), "rank": self.global_rank} + + @property + @override + def restore_checkpoint_after_setup(self) -> bool: + return True + + @property + @override + def lightning_restore_optimizer(self) -> bool: + return False + + @override + def setup_environment(self) -> None: + super().setup_environment() + log.debug(f"{self.__class__.__name__}: setting up distributed...") + reset_seed() + + # determine which process we are and world size + self.set_world_ranks() + + self._process_group_backend = self._get_process_group_backend() + assert self.cluster_environment is not None + kwargs: dict[str, Any] = {"timeout": self._timeout} + if _TORCH_GREATER_EQUAL_2_6: + kwargs["device_id"] = self.root_device if self.root_device.type != "cpu" else None + _init_dist_connection(self.cluster_environment, self._process_group_backend, **kwargs) + + # if device_mesh is None, get the world_size and create a 1D device mesh + if self.device_mesh is None: + world_size = self.cluster_environment.world_size() + self.device_mesh = (world_size,) # a 1-D tuple + # if 'device_mesh' is provided as a tuple, update it into the `DeviceMesh` object here + if isinstance(self.device_mesh, tuple): + from torch.distributed.device_mesh import init_device_mesh + + self.device_mesh = init_device_mesh("cuda", self.device_mesh) + + def _get_process_group_backend(self) -> str: + return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device) + + def set_world_ranks(self) -> None: + if self.cluster_environment is not None: + self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank) + self.cluster_environment.set_world_size(self.num_nodes * self.num_processes) + # `LightningEnvironment.set_global_rank` will do this too, but we cannot rely on that implementation detail + # additionally, for some implementations, the setter is a no-op, so it's safer to access the getter + 
rank_zero_only.rank = utils_rank_zero_only.rank = self.global_rank + + @override + def _configure_launcher(self) -> None: + assert self.cluster_environment is not None + if not self.cluster_environment.creates_processes_externally: + self._launcher = _SubprocessScriptLauncher(self.cluster_environment, self.num_processes, self.num_nodes) + + @override + def _setup_model(self, model: Module) -> Module: + """Wraps the model into a :class:`~torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel` + module.""" + from torch.distributed.fsdp import fully_shard + + if _has_all_dtensor_params_or_buffers(model): + rank_zero_warn("Model is already in DTensor format. FSDP2Strategy will not re-wrap the model.") + + else: + is_on_meta_device = True + if not _has_meta_device_parameters_or_buffers(model): + is_on_meta_device = False + rank_zero_warn( + "To make the best use of FSDP2 strategy, and to avoid unnecessary memory copies, we recommend to" + " initialize your model's parameters and buffers on the `meta` device." + ) + + log.debug(f"setting up FSDP model with device id: {self.root_device.index}, kwargs: {self.kwargs}") + if isinstance(self.device_mesh, tuple): + from torch.distributed.device_mesh import DeviceMesh + + self.device_mesh = DeviceMesh("cuda", self.device_mesh) + + if self.mp_policy is None: + raise ValueError("`mp_policy` cannot be None when calling `fully_shard`.") + + fully_shard( + module=model, + mesh=self.device_mesh, + mp_policy=self.mp_policy, + offload_policy=self.cpu_offload, + ) + + if is_on_meta_device: + # Allocate buffers and sharded parameters on device + model.to_empty(device=self.root_device) + + # Run your custom initialization + def init_weights(m: Module) -> None: + if isinstance(m, torch.nn.Linear): + torch.nn.init.kaiming_uniform_(m.weight) + if m.bias is not None: + torch.nn.init.zeros_(m.bias) + + model.apply(init_weights) + + _move_torchmetrics_to_device(model, self.root_device) + + return model + + @override + def setup(self, trainer: "pl.Trainer") -> None: + assert self.accelerator is not None + self.accelerator.setup(trainer) + + assert self.model is not None + if trainer.state.fn == TrainerFn.FITTING and self._layer_sync: + self.model = self._layer_sync.apply(self.model) + + self.model = self.precision_plugin.convert_module(self.model) + + if is_overridden("configure_sharded_model", self.lightning_module): + # legacy: we don't skip setup with the `configure_model` alternative + rank_zero_info( + "You have overridden `LightningModule.configure_sharded_model` hook. It will assume that all the layers" + " are already wrapped for sharding and won't wrap the entire model using `fully_shard`." 
+ ) + else: + self.model = self._setup_model(self.model) + self.barrier() + + if trainer.state.fn == TrainerFn.FITTING: + self.setup_optimizers(trainer) + self.setup_precision_plugin() + if trainer.state.fn == TrainerFn.FITTING: + _optimizers_to_device(self.optimizers, self.root_device) + + @override + def setup_optimizers(self, trainer: "pl.Trainer") -> None: + # If we're setting up for evaluation after fitting, we need to discard the optimizers + # since we're rewrapping the model, otherwise optimizer param references are no longer valid + # and subsequent checkpoint saving can fail + self._reset_optimizers_and_schedulers() + + if self.kwargs.get("use_orig_params"): + return super().setup_optimizers(trainer) + + invalid_params_error = False + try: + # If `use_orig_params=False` the user needs to do access `self.trainer.model.parameters()` in + # `configure_optimizers()` + super().setup_optimizers(trainer) + except ValueError as ex: + if "optimizer got an empty parameter list" not in str(ex): + raise + invalid_params_error = True + + if invalid_params_error or any(not _optimizer_has_dtensor_params(optimizer) for optimizer in self.optimizers): + # We avoid this limitation by setting `use_orig_params=True` + raise ValueError( + "The optimizer does not seem to reference any FSDP parameters. HINT: Make sure to create the" + " optimizer after setting up the model by referencing `self.trainer.model.parameters()` in the" + " `configure_optimizers()` hook." + ) + return None + + @override + def model_to_device(self) -> None: + # FSDP takes care of moving the model to device + pass + + @contextmanager + @override + def tensor_init_context(self, empty_init: Optional[bool] = None) -> Generator[None, None, None]: + # Materialization happens in `setup`. When modules get wrapped by FSDP, the sequence of operations is: + # 1) materialize module 2) call `reset_parameters()` 3) shard the module. + # These operations are applied to each submodule 'bottom up' in the module hierarchy. + empty_init_context = torch.device("meta") if empty_init in (True, None) else nullcontext() + with empty_init_context, self.precision_plugin.tensor_init_context(): + yield + + @override + def barrier(self, name: Optional[str] = None) -> None: + if not _distributed_is_initialized(): + return + if torch.distributed.get_backend() == "nccl": + torch.distributed.barrier(device_ids=self._determine_device_ids()) + else: + torch.distributed.barrier() + + @override + def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: + if not _distributed_is_initialized(): + return obj + + obj = [obj] + torch.distributed.broadcast_object_list(obj, src, group=_group.WORLD) + return obj[0] + + @override + def reduce( + self, + tensor: Union[Tensor, Any], + group: Optional[Any] = None, + reduce_op: Optional[Union[ReduceOp, str]] = "mean", + ) -> Tensor: + """Reduces a tensor from several distributed processes to one aggregated tensor. + + Args: + tensor: the tensor to sync and reduce + group: the process group to gather results from. Defaults to all processes (world) + reduce_op: the reduction operation. Defaults to 'mean'/'avg'. + Can also be a string 'sum' to calculate the sum during reduction. 
+ + Return: + reduced value, except when the input was not a tensor the output remains is unchanged + + """ + if isinstance(tensor, Tensor): + return _sync_ddp_if_available(tensor, group, reduce_op=reduce_op) + return tensor + + def _determine_device_ids(self) -> list[int]: + return [self.root_device.index] + + @override + def teardown(self) -> None: + log.debug(f"{self.__class__.__name__}: tearing down strategy...") + + pl_module = self.lightning_module + if ( + pl_module is not None + # `self.lightning_module._trainer` can be None if teardown gets called on an exception before + # the trainer gets set on the LightningModule + and pl_module._trainer is not None + and pl_module._trainer.state.fn == TrainerFn.FITTING + and self._layer_sync + ): + assert self.model is not None + self.model = self._layer_sync.revert(self.model) + + assert self.cluster_environment is not None + assert self.accelerator is not None + self.cluster_environment.teardown() + self.precision_plugin.teardown() + self.accelerator.teardown() + + @classmethod + def get_registered_strategies(cls) -> list[str]: + return cls._registered_strategies + + @classmethod + @override + def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None: + if not torch.distributed.is_available(): + return + strategy_registry.register( + "fsdp2", + cls, + description="FSDP2 training", + ) + cls._registered_strategies.append("fsdp2") + + strategy_registry.register( + "fsdp2_cpu_offload", + cls, + description="FSDP2 training with Full Sharding and CPU Offloading", + cpu_offload=True, + ) + cls._registered_strategies.append("fsdp2_cpu_offload") + + @override + def lightning_module_state_dict(self) -> dict[str, Any]: + assert self.model is not None + # Override to do nothing, `save_checkpoint()` method will take care of saving the model state + return {} + + @override + def load_model_state_dict(self, checkpoint: Mapping[str, Any], strict: bool = True) -> None: + # Override to do nothing, FSDP already loaded the states in `load_checkpoint()` + pass + + @override + def optimizer_state(self, optimizer: Optimizer) -> dict[str, Tensor]: + # Override to do nothing, `save_checkpoint()` method will take care of saving the optimizer state + return {} + + @override + def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None: + # Override to do nothing, the FSDP already loaded the states in `load_checkpoint()` + pass + + @override + def save_checkpoint( + self, checkpoint: dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None + ) -> None: + if storage_options is not None: + raise TypeError( + "`FSDP2Strategy.save_checkpoint(..., storage_options=...)` is not supported because" + " `FSDP2Strategy` does not use the `CheckpointIO`." + ) + + path = Path(self.broadcast(filepath)) + + if path.is_file(): + path.unlink() + path.mkdir(parents=True, exist_ok=True) + + if self.model is None: + raise RuntimeError( + "Cannot save checkpoint: FSDP2Strategy model is not initialized." + " Please ensure the strategy is set up before saving." 
+            )
+        state_dict = {"fsdp2_checkpoint_state_dict": AppState(self.model, self.optimizers)}
+        _distributed_checkpoint_save(state_dict, path)
+
+        if self.global_rank == 0:
+            torch.save(checkpoint, path / _METADATA_FILENAME)
+
+    @override
+    def load_checkpoint(self, checkpoint_path: _PATH) -> dict[str, Any]:
+        # broadcast the path from rank 0 to ensure all the states are loaded from a common path
+        path = Path(self.broadcast(checkpoint_path))
+
+        assert self.model is not None
+        assert self.lightning_module is not None
+
+        state_dict = {"fsdp2_checkpoint_state_dict": AppState(self.model, self.optimizers)}
+        _distributed_checkpoint_load(state_dict, path)
+
+        # Load metadata (anything not a module or optimizer)
+        metadata = torch.load(path / _METADATA_FILENAME)
+        return metadata
+
+
+def _init_fsdp2_cpu_offload(cpu_offload: Optional[Union[bool, "CPUOffloadPolicy"]]) -> "OffloadPolicy":
+    from torch.distributed.fsdp import CPUOffloadPolicy, OffloadPolicy
+
+    if cpu_offload is None or cpu_offload is False:
+        return OffloadPolicy()
+
+    if cpu_offload is True:
+        return CPUOffloadPolicy(pin_memory=True)
+
+    if isinstance(cpu_offload, CPUOffloadPolicy):
+        return cpu_offload
+
+    raise TypeError(f"`cpu_offload` should be of type `bool` or `CPUOffloadPolicy`, got {type(cpu_offload)}")
+
+
+def _init_fsdp2_mp_policy(mp_policy: Optional["MixedPrecisionPolicy"]) -> Optional["MixedPrecisionPolicy"]:
+    from torch.distributed.fsdp import MixedPrecisionPolicy
+
+    if mp_policy is None:
+        return MixedPrecisionPolicy(param_dtype=None, reduce_dtype=None, output_dtype=None, cast_forward_inputs=True)
+
+    if isinstance(mp_policy, MixedPrecisionPolicy):
+        return mp_policy
+
+    raise TypeError(f"`mp_policy` should be of type `MixedPrecisionPolicy`, got {type(mp_policy)}")
+
+
+# Code taken from: https://docs.pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html#saving
+class AppState(_TorchStateful):
+    """This is a useful wrapper for checkpointing the Application State. Since this object is compliant with the
+    Stateful protocol, DCP will automatically call state_dict/load_state_dict as needed in the dcp.save/load APIs.
+
+    Note: We take advantage of this wrapper to handle calling distributed state dict methods on the model
+    and optimizer.
+
+    """
+
+    def __init__(self, model: Module, optimizers: list[Optimizer]) -> None:
+        if not _TORCH_GREATER_EQUAL_2_6:
+            raise ModuleNotFoundError(
+                "AppState requires torch>=2.6.0. "
+                f"Found torch {torch.__version__}. Please upgrade torch to use AppState."
+ ) + self.model = model + self.optimizers = optimizers + + def state_dict(self) -> dict[str, Any]: + from torch.distributed.checkpoint.state_dict import get_state_dict + + # this line automatically manages FSDP FQN's, + # as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT + model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizers) + return {"model": model_state_dict, "optim": optimizer_state_dict} + + def load_state_dict(self, state_dict: dict[str, Any]) -> None: + from torch.distributed.checkpoint.state_dict import set_state_dict + + # sets our state dicts on the model and optimizer, now that we've loaded + set_state_dict( + self.model, self.optimizers, model_state_dict=state_dict["model"], optim_state_dict=state_dict["optim"] + ) diff --git a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py index 7f44de0589938..9028e0ca02470 100644 --- a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py +++ b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py @@ -42,6 +42,7 @@ CheckpointIO, DeepSpeedPrecision, DoublePrecision, + FSDP2Precision, FSDPPrecision, HalfPrecision, MixedPrecision, @@ -53,6 +54,7 @@ from lightning.pytorch.strategies import ( DDPStrategy, DeepSpeedStrategy, + FSDP2Strategy, FSDPStrategy, ModelParallelStrategy, ParallelStrategy, @@ -451,11 +453,19 @@ def _check_strategy_and_fallback(self) -> None: # TODO this logic should apply to both str and object config strategy_flag = "" if isinstance(self._strategy_flag, Strategy) else self._strategy_flag - if ( + is_fsdp1_str = ( strategy_flag in FSDPStrategy.get_registered_strategies() or type(self._strategy_flag) is FSDPStrategy - ) and not (self._accelerator_flag in ("cuda", "gpu") or isinstance(self._accelerator_flag, CUDAAccelerator)): + ) + is_fsdp2_str = ( + strategy_flag in FSDP2Strategy.get_registered_strategies() or type(self._strategy_flag) is FSDP2Strategy + ) + + if (is_fsdp1_str or is_fsdp2_str) and not ( + self._accelerator_flag in ("cuda", "gpu") or isinstance(self._accelerator_flag, CUDAAccelerator) + ): + strategy_name = FSDP2Strategy.strategy_name if is_fsdp2_str else FSDPStrategy.strategy_name raise ValueError( - f"The strategy `{FSDPStrategy.strategy_name}` requires a GPU accelerator, but received " + f"The strategy `{strategy_name}` requires a GPU accelerator, but received " f"`accelerator={self._accelerator_flag!r}`. Please set `accelerator='cuda'`, `accelerator='gpu'`," " or pass a `CUDAAccelerator()` instance to use FSDP." 
) @@ -493,6 +503,8 @@ def _check_and_init_precision(self) -> Precision: return DeepSpeedPrecision(self._precision_flag) # type: ignore[arg-type] if isinstance(self.strategy, FSDPStrategy): return FSDPPrecision(self._precision_flag) # type: ignore[arg-type] + if isinstance(self.strategy, FSDP2Strategy): + return FSDP2Precision(self._precision_flag) # type: ignore[arg-type] if self._precision_flag in ("16-true", "bf16-true"): return HalfPrecision(self._precision_flag) # type: ignore if self._precision_flag == "32-true": diff --git a/tests/tests_pytorch/strategies/test_fsdp2.py b/tests/tests_pytorch/strategies/test_fsdp2.py new file mode 100644 index 0000000000000..38509feb3f258 --- /dev/null +++ b/tests/tests_pytorch/strategies/test_fsdp2.py @@ -0,0 +1,446 @@ +import os +from copy import deepcopy +from pathlib import Path +from re import escape +from typing import Optional +from unittest.mock import Mock + +import pytest +import torch +import torch.nn as nn +from torchmetrics import Accuracy + +from lightning.fabric.utilities.init import _has_all_dtensor_params_or_buffers +from lightning.fabric.utilities.load import _load_distributed_checkpoint +from lightning.pytorch import Trainer +from lightning.pytorch.demos.boring_classes import BoringModel +from lightning.pytorch.strategies import FSDP2Strategy +from lightning.pytorch.utilities.consolidate_checkpoint import _format_checkpoint +from tests_pytorch.helpers.runif import RunIf + + +# Minimal boring model for FSDP2 tests (used for DDP/FSDP2 checkpoint compatibility) +class TestBoringModel(BoringModel): + def __init__(self): + super().__init__() + self.save_hyperparameters() + self.layer = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2)) + + def configure_optimizers(self): + return torch.optim.AdamW(self.parameters(), lr=0.1) + + +class TestFSDP2Model(BoringModel): + def __init__(self): + super().__init__() + self.layer: Optional[nn.Module] = None + + def _init_model(self) -> None: + self.layer = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2)) + + def configure_optimizers(self): + # There is some issue with SGD optimizer state in FSDP + return torch.optim.AdamW(self.layer.parameters(), lr=0.1) + + def on_train_batch_start(self, batch, batch_idx): + assert batch.dtype == torch.float32 + + def on_train_batch_end(self, _, batch, batch_idx): + assert batch.dtype == torch.float32 + self._assert_layer_fsdp2_instance() + + def on_test_batch_end(self, _, batch, batch_idx): + assert batch.dtype == torch.float32 + self._assert_layer_fsdp2_instance() + + def on_validation_batch_end(self, _, batch, batch_idx): + assert batch.dtype == torch.float32 + self._assert_layer_fsdp2_instance() + + def on_predict_batch_end(self, _, batch, batch_idx): + assert batch.dtype == torch.float32 + self._assert_layer_fsdp2_instance() + + def _assert_layer_fsdp2_instance(self): + # FSDP2 injects an internal `_fsdp_state` attribute and replaces all parameters/buffers with DTensors. 
+ assert hasattr(self.layer, "_fsdp_state") + assert _has_all_dtensor_params_or_buffers(self.layer) + + +class TestFSDP2ModelAutoWrapped(TestBoringModel): + def on_train_batch_start(self, batch, batch_idx): + assert batch.dtype == torch.float32 + + def on_train_batch_end(self, _, batch, batch_idx): + assert batch.dtype == torch.float32 + self._assert_layer_fsdp2_instance() + + def on_test_batch_end(self, _, batch, batch_idx): + assert batch.dtype == torch.float32 + self._assert_layer_fsdp2_instance() + + def on_validation_batch_end(self, _, batch, batch_idx): + assert batch.dtype == torch.float32 + self._assert_layer_fsdp2_instance() + + def on_predict_batch_end(self, _, batch, batch_idx): + assert batch.dtype == torch.float32 + self._assert_layer_fsdp2_instance() + + def _assert_layer_fsdp2_instance(self): + assert hasattr(self.layer, "_fsdp_state") + assert _has_all_dtensor_params_or_buffers(self.layer) + + +def _run_multiple_stages(trainer, model, model_path: Optional[str] = None): + trainer.fit(model) + trainer.test(model) + + model_path = trainer.strategy.broadcast(model_path) + model_path = Path(model_path if model_path else trainer.checkpoint_callback.last_model_path) + + # Save another checkpoint after testing, without optimizer states + trainer.save_checkpoint(model_path.with_name("after-test")) + trainer.save_checkpoint(model_path, weights_only=True) + + if not model_path.is_dir(): # TODO (@awaelchli): Add support for asserting equality of sharded checkpoints + _assert_save_equality(trainer, model_path, cls=model.__class__) + + with torch.inference_mode(): + # Test entry point + trainer.test(model) # model is wrapped, will not call `configure_model` + + # provide model path, will create a new unwrapped model and load and then call `configure_shared_model` to wrap + trainer.test(ckpt_path=model_path) + + # Predict entry point + trainer.predict(model) # model is wrapped, will not call `configure_model` + + # provide model path, will create a new unwrapped model and load and then call `configure_shared_model` to wrap + trainer.predict(ckpt_path=model_path) + + +def _assert_save_equality(trainer, ckpt_path, cls=TestFSDP2Model): + # Use FullySharded to get the state dict for the sake of comparison + model_state_dict = trainer.strategy.lightning_module_state_dict() + + if trainer.is_global_zero: + saved_model = cls.load_from_checkpoint(ckpt_path) + + # Assert model parameters are identical after loading + for ddp_param, shard_param in zip(model_state_dict.values(), saved_model.state_dict().values()): + assert torch.equal(ddp_param, shard_param) + + +@RunIf(min_torch="2.6.0") +@pytest.mark.parametrize("strategy", ["fsdp2", "fsdp2_cpu_offload"]) +def test_invalid_on_cpu(tmp_path, cuda_count_0, strategy): + """Test to ensure that we raise Misconfiguration for FSDP on CPU.""" + with pytest.raises(ValueError, match="The strategy `fsdp2` requires a GPU accelerator"): + trainer = Trainer(accelerator="cpu", default_root_dir=tmp_path, fast_dev_run=True, strategy=strategy) + assert isinstance(trainer.strategy, FSDP2Strategy) + trainer.strategy.setup_environment() + + +@RunIf(min_torch="2.6.0") +def test_custom_mixed_precision(): + """Test to ensure that passing a custom mixed precision config works.""" + from torch.distributed.fsdp import MixedPrecisionPolicy + + # custom mp policy + mp_policy = MixedPrecisionPolicy( + param_dtype=torch.bfloat16, reduce_dtype=torch.float16, output_dtype=torch.float16, cast_forward_inputs=True + ) + strategy = FSDP2Strategy(mp_policy=mp_policy) + assert 
strategy.mp_policy == mp_policy + + # default mp policy + strategy = FSDP2Strategy(mp_policy=None) + assert isinstance(strategy.mp_policy, MixedPrecisionPolicy) + assert strategy.mp_policy.param_dtype is None + assert strategy.mp_policy.reduce_dtype is None + assert strategy.mp_policy.output_dtype is None + assert strategy.mp_policy.cast_forward_inputs is True + + # invalid mp policy + class InvalidMPPolicy: + pass + + with pytest.raises(TypeError, match="`mp_policy` should be of type `MixedPrecisionPolicy`"): + FSDP2Strategy(mp_policy=InvalidMPPolicy()) + + +@pytest.mark.filterwarnings("ignore::FutureWarning") +@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="2.6.0") +def test_strategy_sync_batchnorm(tmp_path): + """Test to ensure that sync_batchnorm works when using FSDP and GPU, and all stages can be run.""" + model = TestFSDP2Model() + trainer = Trainer( + default_root_dir=tmp_path, + accelerator="gpu", + devices=2, + strategy="fsdp2", + precision="16-true", + max_epochs=1, + sync_batchnorm=True, + ) + _run_multiple_stages(trainer, model, os.path.join(tmp_path, "last.ckpt")) + + +@pytest.mark.filterwarnings("ignore::FutureWarning") +@RunIf(min_cuda_gpus=1, skip_windows=True, min_torch="2.6.0") +def test_modules_without_parameters(tmp_path): + """Test that TorchMetrics get moved to the device despite not having any parameters.""" + + class MetricsModel(BoringModel): + def __init__(self): + super().__init__() + self.metric = Accuracy("multiclass", num_classes=10) + assert self.metric.device == self.metric.tp.device == torch.device("cpu") + + def setup(self, stage) -> None: + assert self.metric.device == self.metric.tp.device == torch.device("cpu") + + def training_step(self, batch, batch_idx): + loss = super().training_step(batch, batch_idx) + assert self.metric.device == self.metric.tp.device == torch.device("cuda", 0) + self.metric(torch.rand(2, 10, device=self.device), torch.randint(0, 10, size=(2,), device=self.device)) + return loss + + model = MetricsModel() + trainer = Trainer( + default_root_dir=tmp_path, + accelerator="cuda", + devices=1, + strategy="fsdp2", + max_steps=1, + ) + trainer.fit(model) + + +@pytest.mark.filterwarnings("ignore::FutureWarning") +@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="2.6.0") +@pytest.mark.parametrize("precision", ["16-mixed", pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True))]) +def test_strategy_checkpoint(state_dict_type, precision, tmp_path): + """Test to ensure that checkpoint is saved correctly when using a single GPU, and all stages can be run.""" + model = TestFSDP2Model() + strategy = FSDP2Strategy() + trainer = Trainer( + default_root_dir=tmp_path, accelerator="gpu", devices=2, strategy=strategy, precision=precision, max_epochs=1 + ) + _run_multiple_stages(trainer, model, os.path.join(tmp_path, "last.ckpt")) + + +def custom_auto_wrap_policy( + module, + recurse, + nonwrapped_numel: int, +) -> bool: + return nonwrapped_numel >= 2 + + +@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="2.6.0") +@pytest.mark.parametrize( + ("precision", "expected_dtype"), + [ + ("32-true", torch.float32), + ], +) +def test_configure_model(precision, expected_dtype, tmp_path): + """Test that the module under configure_model gets moved to the right device and dtype.""" + trainer = Trainer( + default_root_dir=tmp_path, + accelerator="cuda", + devices=2, + strategy=FSDP2Strategy(), + precision=precision, + max_epochs=1, + enable_checkpointing=False, + logger=False, + ) + + class 
MyModel(BoringModel): + def configure_model(self): + self.layer = torch.nn.Linear(32, 2) + # The model is on the CPU until after `.setup()`` + # TODO: Support initialization on meta device + expected_device = torch.device("cpu") + assert self.layer.weight.device == expected_device + assert self.layer.weight.dtype == expected_dtype + + def configure_optimizers(self): + # There is some issue with SGD optimizer state in FSDP + return torch.optim.AdamW(self.layer.parameters(), lr=0.1) + + def on_fit_start(self): + # Parameters get sharded in `.setup()` and moved to the target device + assert self.layer.weight.device == torch.device("cuda", self.local_rank) + assert self.layer.weight.dtype == expected_dtype + + model = MyModel() + trainer.fit(model) + + +@RunIf(min_torch="2.6.0") +def test_save_checkpoint_storage_options(tmp_path): + """Test that the FSDP strategy does not accept storage options for saving checkpoints.""" + strategy = FSDP2Strategy() + with pytest.raises(TypeError, match=escape("FSDP2Strategy.save_checkpoint(..., storage_options=...)` is not")): + strategy.save_checkpoint(filepath=tmp_path, checkpoint=Mock(), storage_options=Mock()) + + +class TestFSDP2CheckpointModel(BoringModel): + def __init__(self, params_to_compare=None): + super().__init__() + self.layer = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2)) + self.params_to_compare = params_to_compare + + def configure_optimizers(self): + # SGD's FSDP optimizer, state is fixed in https://github.com/pytorch/pytorch/pull/99214 + return torch.optim.AdamW(self.parameters(), lr=0.1) + + def on_train_start(self): + if self.params_to_compare is None: + return + for p0, p1 in zip(self.params_to_compare, self.trainer.model.parameters()): + torch.testing.assert_close(p0, p1, atol=0, rtol=0, equal_nan=True) + + +@pytest.mark.filterwarnings("ignore::FutureWarning") +@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.6.0") +def test_save_load_sharded_state_dict(tmp_path): + """Test FSDP saving and loading with the sharded state dict format.""" + strategy = FSDP2Strategy() + trainer_kwargs = { + "default_root_dir": tmp_path, + "accelerator": "cuda", + "devices": 2, + "max_epochs": 1, + "enable_progress_bar": False, + "enable_model_summary": False, + "logger": False, + } + + # Initial training + model = TestFSDP2CheckpointModel() + trainer = Trainer(**trainer_kwargs, strategy=strategy) + trainer.fit(model) + params_before = deepcopy(list(trainer.model.parameters())) + + checkpoint_path = Path(trainer.strategy.broadcast(trainer.checkpoint_callback.best_model_path)) + assert set(os.listdir(checkpoint_path)) == {"meta.pt", ".metadata", "__0_0.distcp", "__1_0.distcp"} + + metadata = torch.load(checkpoint_path / "meta.pt", weights_only=True) + assert "pytorch-lightning_version" in metadata + assert len(metadata["callbacks"]) == 1 # model checkpoint callback + assert "state_dict" not in metadata + assert "optimizer_states" not in metadata + + # Load checkpoint and continue training + trainer_kwargs.update(max_epochs=2) + model = TestFSDP2CheckpointModel(params_to_compare=params_before) + strategy = FSDP2Strategy() + trainer = Trainer(**trainer_kwargs, strategy=strategy) + trainer.fit(model, ckpt_path=checkpoint_path) + + +@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="2.6.0") +@pytest.mark.parametrize( + ("precision", "expected_dtype"), + [ + ("32-true", torch.float32), + ("16-true", torch.float16), + ], +) +def test_module_init_context(precision, expected_dtype, tmp_path): + 
"""Test that the module under the init-context gets moved to the right device and dtype.""" + + class Model(BoringModel): + def configure_optimizers(self): + return torch.optim.Adam(self.parameters(), lr=1e-2) + + def on_train_start(self): + # Parameters get sharded in `FSDPStrategy.setup()` and moved to the target device + assert self.layer.weight.device == torch.device("cuda", self.local_rank) + assert self.layer.weight.dtype == expected_dtype + optimizer = self.optimizers(use_pl_optimizer=False) + assert optimizer.param_groups[0]["params"][0].device.type == "cuda" + + def _run_setup_assertions(empty_init, expected_device): + trainer = Trainer( + default_root_dir=tmp_path, + accelerator="cuda", + devices=2, + strategy=FSDP2Strategy(), + precision=precision, + max_steps=1, + barebones=True, + enable_checkpointing=False, + logger=False, + ) + with trainer.init_module(empty_init=empty_init): + model = Model() + + # The model is on the CPU/meta-device until after `FSDPStrategy.setup()` + assert model.layer.weight.device == expected_device + assert model.layer.weight.dtype == expected_dtype + trainer.fit(model) + + # Case 1: No empty init + _run_setup_assertions(empty_init=False, expected_device=torch.device("cpu")) + + # Case 2: Empty-init with meta device + _run_setup_assertions(empty_init=True, expected_device=torch.device("meta")) + + +@pytest.mark.filterwarnings("ignore::FutureWarning") +@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.6.0") +def test_save_sharded_and_consolidate_and_load(tmp_path): + """Test the consolidation of a FSDP2-sharded checkpoint into a single file.""" + + class CustomModel(BoringModel): + def configure_optimizers(self): + # Use Adam instead of SGD for this test because it has state + # In PyTorch >= 2.4, saving an optimizer with empty state would result in a `KeyError: 'state'` + # when loading the optimizer state-dict back. + # TODO: To resolve this, switch to the new `torch.distributed.checkpoint` APIs in FSDPStrategy + return torch.optim.Adam(self.parameters(), lr=0.1) + + model = CustomModel() + trainer = Trainer( + default_root_dir=tmp_path, + accelerator="cuda", + devices=2, + strategy=FSDP2Strategy(), + max_steps=3, + ) + trainer.fit(model) + + checkpoint_path_sharded = trainer.strategy.broadcast(str(trainer.checkpoint_callback.best_model_path)) + assert set(os.listdir(checkpoint_path_sharded)) == {"meta.pt", ".metadata", "__0_0.distcp", "__1_0.distcp"} + + # consolidate the checkpoint to a single file + checkpoint_path_full = trainer.strategy.broadcast(str(tmp_path / "checkpoint_full.ckpt")) + if trainer.global_rank == 0: + checkpoint = _load_distributed_checkpoint(Path(checkpoint_path_sharded)) + checkpoint = _format_checkpoint(checkpoint) + torch.save(checkpoint, checkpoint_path_full) + trainer.strategy.barrier() + + model = CustomModel() + trainer = Trainer( + default_root_dir=tmp_path, + accelerator="cuda", + devices=2, + strategy="ddp", + max_steps=4, + ) + trainer.fit(model, ckpt_path=checkpoint_path_full) + + +@RunIf(max_torch="2.5", min_cuda_gpus=1) +@pytest.mark.parametrize("strategy", ["fsdp2", "fsdp2_cpu_offload"]) +def test_fsdp2_requires_torch_2_6_or_newer(tmp_path, strategy): + """FSDP2 strategies should error on torch < 2.6.""" + with pytest.raises(ModuleNotFoundError, match="FSDP2Strategy requires torch>=2.6.0."): + Trainer(default_root_dir=tmp_path, fast_dev_run=True, strategy=strategy)