Commit cf6bbf1

update

1 parent a84b9b1 commit cf6bbf1
File tree

1 file changed: +5, -7 lines

  • src/lightning/pytorch/strategies


src/lightning/pytorch/strategies/fsdp2.py

Lines changed: 5 additions & 7 deletions
@@ -102,14 +102,12 @@ class FSDP2Strategy(ParallelStrategy):
     https://github.com/pytorch/pytorch/issues/114299
 
     Arguments:
+        mp_policy: A ``MixedPrecisionPolicy`` object that specifies the precision policy for
+            model parameters and gradients when using mixed precision training with FSDP2.
+        cpu_offload: A ``CPUOffloadPolicy`` or boolean that specifies whether to offload
+            model parameters and gradients to CPU memory. If ``True``, offloading is enabled with default settings.
         device_mesh: A :class:`torch.distributed.device_mesh.DeviceMesh` object that specifies
             how devices are arranged and how tensors should be sharded/replicated.
-        parallelize_module: Optional policy function or mapping that specifies how to wrap or
-            distribute submodules of the model using ``DTensor``.
-        checkpoint_policy: Defines how checkpoint saving/loading is performed with DTensor-based
-            modules. See ``torch.distributed.checkpoint`` for available options.
-        mixed_precision: Optional policy for mixed precision training. Can be used to specify
-            precision for parameters, gradients, and buffers.
         \**kwargs: Additional keyword arguments passed to the underlying FSDP2 APIs.
 
     .. note::
@@ -125,7 +123,6 @@ class FSDP2Strategy(ParallelStrategy):
 
     def __init__(
         self,
-        device_mesh: Optional[Union[tuple[int], "DeviceMesh"]] = None,
         accelerator: Optional["pl.accelerators.Accelerator"] = None,
         parallel_devices: Optional[list[torch.device]] = None,
         cluster_environment: Optional[ClusterEnvironment] = None,
@@ -135,6 +132,7 @@ def __init__(
         timeout: Optional[timedelta] = default_pg_timeout,
         cpu_offload: Union[bool, "CPUOffloadPolicy", None] = None,
         mp_policy: Optional["MixedPrecisionPolicy"] = None,
+        device_mesh: Optional[Union[tuple[int], "DeviceMesh"]] = None,
+        **kwargs: Any,
     ) -> None:
         if not _TORCH_GREATER_EQUAL_2_6:
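
The net effect of the diff is that ``device_mesh`` moves from the first position to the end of the keyword block, and the documented arguments now match the actual signature (``mp_policy``, ``cpu_offload``). Since every parameter has a default, callers that pass arguments by keyword are unaffected; only positional calls such as ``FSDP2Strategy(mesh)`` would break. Below is a minimal sketch of keyword-based construction against the new signature, assuming PyTorch >= 2.6 (per the ``_TORCH_GREATER_EQUAL_2_6`` guard) and assuming ``FSDP2Strategy`` is re-exported from ``lightning.pytorch.strategies`` like the other strategies; the Trainer settings are illustrative.

import torch
from torch.distributed.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy

import lightning.pytorch as pl
from lightning.pytorch.strategies import FSDP2Strategy  # assumed re-export

strategy = FSDP2Strategy(
    # Compute in bf16 while reducing gradients in fp32.
    mp_policy=MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32),
    # Passing True would enable offloading with default settings; an explicit
    # CPUOffloadPolicy exposes options such as pinned host memory.
    cpu_offload=CPUOffloadPolicy(pin_memory=True),
    # A 1-D mesh given as a tuple: shard parameters across all 8 ranks.
    device_mesh=(8,),
)
trainer = pl.Trainer(accelerator="gpu", devices=8, strategy=strategy)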
