@@ -102,14 +102,12 @@ class FSDP2Strategy(ParallelStrategy):
         https://github.com/pytorch/pytorch/issues/114299
 
     Arguments:
+        mp_policy: A ``MixedPrecisionPolicy`` object that specifies the precision policy for
+            model parameters and gradients when using mixed precision training with FSDP2.
+        cpu_offload: A ``CPUOffloadPolicy`` or boolean that specifies whether to offload
+            model parameters and gradients to CPU memory. If ``True``, offloading is enabled with default settings.
         device_mesh: A :class:`torch.distributed.device_mesh.DeviceMesh` object that specifies
             how devices are arranged and how tensors should be sharded/replicated.
-        parallelize_module: Optional policy function or mapping that specifies how to wrap or
-            distribute submodules of the model using ``DTensor``.
-        checkpoint_policy: Defines how checkpoint saving/loading is performed with DTensor-based
-            modules. See ``torch.distributed.checkpoint`` for available options.
-        mixed_precision: Optional policy for mixed precision training. Can be used to specify
-            precision for parameters, gradients, and buffers.
         \**kwargs: Additional keyword arguments passed to the underlying FSDP2 APIs.
 
     .. note::
@@ -125,7 +123,6 @@ class FSDP2Strategy(ParallelStrategy):
 
     def __init__(
         self,
-        device_mesh: Optional[Union[tuple[int], "DeviceMesh"]] = None,
         accelerator: Optional["pl.accelerators.Accelerator"] = None,
         parallel_devices: Optional[list[torch.device]] = None,
         cluster_environment: Optional[ClusterEnvironment] = None,
@@ -135,6 +132,7 @@ def __init__(
         timeout: Optional[timedelta] = default_pg_timeout,
         cpu_offload: Union[bool, "CPUOffloadPolicy", None] = None,
         mp_policy: Optional["MixedPrecisionPolicy"] = None,
+        device_mesh: Optional[Union[tuple[int], "DeviceMesh"]] = None,
         **kwargs: Any,
     ) -> None:
         if not _TORCH_GREATER_EQUAL_2_6:
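For reference, a minimal usage sketch of the signature after this change. The `lightning.pytorch.strategies.FSDP2Strategy` import path and the `Trainer` wiring are assumptions for illustration; `MixedPrecisionPolicy` and `CPUOffloadPolicy` are the public FSDP2 policy classes in `torch.distributed.fsdp` (PyTorch >= 2.6, matching the version guard above).

```python
# Minimal usage sketch (not part of this diff). Assumption: FSDP2Strategy is
# exposed under lightning.pytorch.strategies; the two policy classes are the
# public FSDP2 APIs in torch.distributed.fsdp (PyTorch >= 2.6).
import torch
from torch.distributed.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy

import lightning as L
from lightning.pytorch.strategies import FSDP2Strategy  # assumed import path

strategy = FSDP2Strategy(
    # Compute in bf16, but accumulate gradient reductions in fp32.
    mp_policy=MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32),
    # An explicit policy instead of `cpu_offload=True` to control pinned memory.
    cpu_offload=CPUOffloadPolicy(pin_memory=True),
    # A 1D mesh over 8 devices; a prebuilt torch DeviceMesh is also accepted.
    device_mesh=(8,),
)

trainer = L.Trainer(accelerator="gpu", devices=8, strategy=strategy)
```

Note that passing `device_mesh` by keyword keeps call sites unaffected by the positional reordering in this diff; only positional callers would need updating.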