
Commit cc6de82

works. i'm still worthy
1 parent 35bf1a2 commit cc6de82

File tree

5 files changed: +25 -4 lines changed

    src/lightning/fabric/strategies/fsdp.py
    src/lightning/pytorch/plugins/__init__.py
    src/lightning/pytorch/strategies/__init__.py
    src/lightning/pytorch/strategies/fsdp2.py
    src/lightning/pytorch/trainer/connectors/accelerator_connector.py

src/lightning/fabric/strategies/fsdp.py

Lines changed: 5 additions & 0 deletions
@@ -31,6 +31,7 @@
 from lightning_utilities.core.imports import RequirementCache
 from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only
 from torch import Tensor
+from torch.distributed.tensor import DTensor
 from torch.nn import Module
 from torch.optim import Optimizer
 from typing_extensions import TypeGuard, override
@@ -795,6 +796,10 @@ def _optimizer_has_flat_params(optimizer: Optimizer) -> bool:
     )


+def _optimizer_has_dtensor_params(optimizer: Optimizer) -> bool:
+    return any(isinstance(param, DTensor) for group in optimizer.param_groups for param in group["params"])
+
+
 def _get_sharded_state_dict_context(module: Module) -> Generator[None, None, None]:
     from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
     from torch.distributed.fsdp.api import ShardedOptimStateDictConfig, ShardedStateDictConfig, StateDictType
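The new helper exists because fully_shard (FSDP2) replaces module parameters with DTensor instances instead of flattening them, so the FSDP1-era _optimizer_has_flat_params check no longer recognizes a correctly wired optimizer. A minimal sketch of the behaviour, assuming a recent PyTorch that ships torch.distributed.tensor; the toy model below is illustrative and not part of the commit:

    import torch
    from torch.distributed.tensor import DTensor


    def _optimizer_has_dtensor_params(optimizer: torch.optim.Optimizer) -> bool:
        # True once any parameter in any param group is a DTensor, which is what
        # fully_shard() turns module parameters into.
        return any(isinstance(param, DTensor) for group in optimizer.param_groups for param in group["params"])


    # An optimizer over plain, unsharded parameters fails the check ...
    model = torch.nn.Linear(8, 8)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    assert not _optimizer_has_dtensor_params(optimizer)
    # ... whereas one created after fully_shard(model, ...) would pass it,
    # because its parameters are DTensors.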

src/lightning/pytorch/plugins/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -8,6 +8,7 @@
 from lightning.pytorch.plugins.precision.deepspeed import DeepSpeedPrecision
 from lightning.pytorch.plugins.precision.double import DoublePrecision
 from lightning.pytorch.plugins.precision.fsdp import FSDPPrecision
+from lightning.pytorch.plugins.precision.fsdp2 import FSDP2Precision
 from lightning.pytorch.plugins.precision.half import HalfPrecision
 from lightning.pytorch.plugins.precision.precision import Precision
 from lightning.pytorch.plugins.precision.transformer_engine import TransformerEnginePrecision
@@ -28,6 +29,7 @@
     "Precision",
     "TransformerEnginePrecision",
     "FSDPPrecision",
+    "FSDP2Precision",
     "XLAPrecision",
     "LayerSync",
     "TorchSyncBatchNorm",

src/lightning/pytorch/strategies/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -18,6 +18,7 @@
 from lightning.pytorch.strategies.ddp import DDPStrategy
 from lightning.pytorch.strategies.deepspeed import DeepSpeedStrategy
 from lightning.pytorch.strategies.fsdp import FSDPStrategy
+from lightning.pytorch.strategies.fsdp2 import FSDP2Strategy
 from lightning.pytorch.strategies.model_parallel import ModelParallelStrategy
 from lightning.pytorch.strategies.parallel import ParallelStrategy
 from lightning.pytorch.strategies.single_device import SingleDeviceStrategy
@@ -32,6 +33,7 @@
     "DDPStrategy",
     "DeepSpeedStrategy",
     "FSDPStrategy",
+    "FSDP2Strategy",
     "ModelParallelStrategy",
     "ParallelStrategy",
     "SingleDeviceStrategy",

src/lightning/pytorch/strategies/fsdp2.py

Lines changed: 12 additions & 4 deletions
@@ -41,7 +41,7 @@
     _distributed_checkpoint_load,
     _distributed_checkpoint_save,
     _move_torchmetrics_to_device,
-    _optimizer_has_flat_params,
+    _optimizer_has_dtensor_params,
 )
 from lightning.fabric.utilities.distributed import (
     _distributed_is_initialized,
@@ -139,6 +139,7 @@ def __init__(
         self.mp_policy = _init_fsdp2_mp_policy(mp_policy)

         self.device_mesh = device_mesh
+        self.kwargs = kwargs

     @property
     @override
@@ -249,12 +250,19 @@ def _setup_model(self, model: Module) -> Module:
             )

         log.debug(f"setting up FSDP model with device id: {self.root_device.index}, kwargs: {self.kwargs}")
+        if isinstance(self.device_mesh, tuple):
+            from torch.distributed.device_mesh import DeviceMesh
+
+            self.device_mesh = DeviceMesh("cuda", self.device_mesh)
+
+        if self.mp_policy is None:
+            raise ValueError("`mp_policy` cannot be None when calling `fully_shard`.")
+
         fully_shard(
             module=model,
             mesh=self.device_mesh,
             mp_policy=self.mp_policy,
             offload_policy=self.cpu_offload,
-            cpu_offload=self.cpu_offload,
         )

         if is_on_meta_device:
@@ -321,7 +329,7 @@ def setup_optimizers(self, trainer: "pl.Trainer") -> None:
                 raise
             invalid_params_error = True

-        if invalid_params_error or any(not _optimizer_has_flat_params(optimizer) for optimizer in self.optimizers):
+        if invalid_params_error or any(not _optimizer_has_dtensor_params(optimizer) for optimizer in self.optimizers):
             # We avoid this limitation by setting `use_orig_params=True`
             raise ValueError(
                 "The optimizer does not seem to reference any FSDP parameters. HINT: Make sure to create the"
@@ -428,7 +436,7 @@ def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
         cls._registered_strategies.append("fsdp2")

         strategy_registry.register(
-            "fsdp_cpu_offload",
+            "fsdp2_cpu_offload",
             cls,
             description="FSDP2 training with Full Sharding and CPU Offloading",
             cpu_offload=True,
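For orientation, this is roughly what _setup_model now does, written out by hand against the public PyTorch APIs. A sketch only, assuming PyTorch 2.6+ (where fully_shard, MixedPrecisionPolicy, and CPUOffloadPolicy are exported from torch.distributed.fsdp) and a torchrun launch; it builds the mesh with init_device_mesh rather than constructing DeviceMesh straight from a tuple as the strategy does, and the mesh shape, dtypes, and toy model are illustrative:

    import os

    import torch
    import torch.distributed as dist
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy, fully_shard

    dist.init_process_group("nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    # 1-D mesh over all ranks: plain parameter sharding, no replication dimension.
    mesh = init_device_mesh("cuda", (dist.get_world_size(),))

    model = torch.nn.Sequential(torch.nn.Linear(128, 128), torch.nn.Linear(128, 10))
    fully_shard(
        model,
        mesh=mesh,
        mp_policy=MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32),
        offload_policy=CPUOffloadPolicy(),  # the keyword is offload_policy; there is no cpu_offload kwarg
    )

    # Parameters are DTensors now, so an optimizer built afterwards passes the
    # _optimizer_has_dtensor_params check used in setup_optimizers above.
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)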

src/lightning/pytorch/trainer/connectors/accelerator_connector.py

Lines changed: 4 additions & 0 deletions
@@ -42,6 +42,7 @@
     CheckpointIO,
     DeepSpeedPrecision,
     DoublePrecision,
+    FSDP2Precision,
     FSDPPrecision,
     HalfPrecision,
     MixedPrecision,
@@ -53,6 +54,7 @@
 from lightning.pytorch.strategies import (
     DDPStrategy,
     DeepSpeedStrategy,
+    FSDP2Strategy,
     FSDPStrategy,
     ModelParallelStrategy,
     ParallelStrategy,
@@ -493,6 +495,8 @@ def _check_and_init_precision(self) -> Precision:
             return DeepSpeedPrecision(self._precision_flag)  # type: ignore[arg-type]
         if isinstance(self.strategy, FSDPStrategy):
             return FSDPPrecision(self._precision_flag)  # type: ignore[arg-type]
+        if isinstance(self.strategy, FSDP2Strategy):
+            return FSDP2Precision(self._precision_flag)  # type: ignore[arg-type]
         if self._precision_flag in ("16-true", "bf16-true"):
             return HalfPrecision(self._precision_flag)  # type: ignore
         if self._precision_flag == "32-true":
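Taken together with the strategy registration above, the connector change means a plain string-based Trainer setup now resolves to FSDP2Precision automatically. A usage sketch; the accelerator, device count, and precision flag are only examples:

    from lightning.pytorch import Trainer

    # "fsdp2" (or "fsdp2_cpu_offload") is resolved by the strategy registry, and
    # _check_and_init_precision then returns FSDP2Precision for the given flag.
    trainer = Trainer(accelerator="gpu", devices=2, strategy="fsdp2", precision="bf16-mixed")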
