Commit b83990d

JKSenthil authored and meta-codesync[bot] committed
support fsdp2 sharding on largest dim (#1037)
Summary: Pull Request resolved: #1037

Reviewed By: galrotem

Differential Revision: D83673518

fbshipit-source-id: 13a5c758bfe46ee4e2dd07719db600e2b23a7a21
1 parent e4357a1 commit b83990d

File tree

2 files changed: +47, -1 lines changed


tests/utils/test_prepare_module.py

Lines changed: 23 additions & 0 deletions
@@ -322,6 +322,29 @@ def test_fsdp2_mesh(self, mock_fully_shard: Mock) -> None:
             module, mesh=mock_mesh, reshard_after_forward=False
         )
 
+    @patch("torchtnt.utils.prepare_module.fully_shard")
+    def test_fsdp2_shard_on_largest_dim(self, mock_fully_shard: Mock) -> None:
+        """
+        Test that shard on largest dim function is used
+        """
+
+        module = torch.nn.Linear(2, 2, device="cpu")
+        mock_mesh = MagicMock(spec=DeviceMesh)
+        mock_global_mesh = MagicMock(spec=GlobalMeshCoordinator)
+        mock_global_mesh.dp_mesh = mock_mesh
+
+        strategy = FSDP2Strategy(
+            modules_to_shard=[torch.nn.Linear], shard_on_largest_dim=True
+        )
+        module = prepare_fsdp2(
+            module,
+            torch.device("cpu"),
+            strategy,
+            global_mesh=mock_global_mesh,
+        )
+        # Check that "shard_placement_fn" is in the kwargs passed to fully_shard
+        self.assertIn("shard_placement_fn", mock_fully_shard.call_args.kwargs)
+
     @patch("torchtnt.utils.prepare_module._prepare_module_2d")
     @patch("torchtnt.utils.prepare_module._prepare_module_1d")
     def test_prepare_module_dispatching(

torchtnt/utils/prepare_module.py

Lines changed: 24 additions & 1 deletion
@@ -41,8 +41,9 @@
     get_optimizer_state_dict,
     set_optimizer_state_dict,
 )
-from torch.distributed.device_mesh import init_device_mesh
+from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
 from torch.distributed.fsdp.fully_sharded_data_parallel import FullStateDictConfig
+from torch.distributed.tensor import Shard
 from torch.distributed.tensor.parallel import parallelize_module
 from torch.distributed.tensor.parallel.style import ParallelStyle
 from torchtnt.utils.device_mesh import GlobalMeshCoordinator
@@ -199,6 +200,8 @@ class FSDP2Strategy(Strategy):
         reshard_after_forward: If True, reshards parameters post-forward pass to save memory.
         mp_policy: Controls mixed precision policy. If only dtype is provided, it will be used to cast all relevant parts of model. If None, no mixed precision is used
         cpu_offload: If True, enables CPU offloading of model parameters to reduce GPU memory usage.
+        shard_on_largest_dim: If True, shards on the largest dimension of the parameter. By default FSDP shards on the first dimension, and if it is small will end up replicated on all ranks, which ends up increasing
+            memory usage as world size increases.
 
     Note:
         It is recommended to specify specific modules to shard to avoid unnecessary sharding of all submodules, which has
@@ -240,6 +243,9 @@ class FSDP2Strategy(Strategy):
     mp_policy: Optional[Union[str, torch.dtype, MixedPrecisionPolicy]] = None
     cpu_offload: bool = False
 
+    # experimental flag
+    shard_on_largest_dim: bool = False
+
 
 @dataclass
 class TPStrategy(Strategy):
@@ -409,6 +415,7 @@ def prepare_fsdp2(
     strategy = strategy or FSDP2Strategy()
 
     # prepare kwargs for fully_shard api
+    mesh: DeviceMesh
     if global_mesh is None:
         pg = dist.distributed_c10d._get_default_group()
         mesh = init_device_mesh(device.type, mesh_shape=(pg.size(),))
@@ -438,6 +445,22 @@
             reduce_dtype=mp_policy,
             output_dtype=mp_policy,
         )
+    if strategy.shard_on_largest_dim:
+
+        # From the docs: https://docs.pytorch.org/docs/stable/distributed.fsdp.fully_shard.html
+        # "If sharding on a nonzero dim, we currently require even sharding, i.e. the tensor dim size on that dim must be divisible by the FSDP shard mesh size."
+
+        # So we shard on a candidate nonzero dim only when it's divisible by the fsdp world size
+
+        def _shard_placement_fn(param: torch.nn.Parameter) -> Optional[Shard]:
+            largest_dim_size = max(param.shape)
+            idx = param.shape.index(largest_dim_size)
+            if idx != 0 and largest_dim_size % mesh.size() != 0:
+                # not divisible, so we return None to shard on default dim 0
+                return None
+            return Shard(idx)
+
+        fsdp_kwargs["shard_placement_fn"] = _shard_placement_fn
 
     # parse out the modules_to_shard argument
     modules_to_shard = strategy.modules_to_shard
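
For readers skimming the diff, here is a minimal, self-contained sketch (not part of the commit) of how the dimension-selection rule in _shard_placement_fn behaves. The helper name pick_shard_dim, the world_size value, and the example parameter shapes are all made up for illustration; only the selection logic and the Shard placement type come from the change above.

# Illustrative sketch only: mirrors the dim-selection rule added in this commit so it
# can be sanity-checked without initializing torch.distributed. `pick_shard_dim`,
# `world_size`, and the example shapes are hypothetical.
from typing import Optional

import torch
from torch.distributed.tensor import Shard


def pick_shard_dim(param: torch.nn.Parameter, world_size: int) -> Optional[Shard]:
    # Prefer the largest dimension, but fall back to the default dim 0 when a
    # nonzero dim is not evenly divisible by the FSDP shard mesh size, since
    # fully_shard requires even sharding on dims other than 0.
    largest_dim_size = max(param.shape)
    idx = param.shape.index(largest_dim_size)
    if idx != 0 and largest_dim_size % world_size != 0:
        return None  # None means "keep the default Shard(0)"
    return Shard(idx)


if __name__ == "__main__":
    world_size = 8  # hypothetical FSDP shard mesh size

    # Linear(4096, 4) stores its weight as shape (4, 4096): dim 0 is tiny relative
    # to world_size, which is the case the new flag targets.
    print(pick_shard_dim(torch.nn.Parameter(torch.empty(4, 4096)), world_size))
    # -> Shard(dim=1), because 4096 % 8 == 0

    # Largest dim is nonzero but not divisible by 8 -> fall back to default dim 0.
    print(pick_shard_dim(torch.nn.Parameter(torch.empty(4, 4098)), world_size))
    # -> None

    # Largest dim is already dim 0 -> behaves like the default.
    print(pick_shard_dim(torch.nn.Parameter(torch.empty(4096, 16)), world_size))
    # -> Shard(dim=0)

In training code, enabling this behavior comes down to constructing FSDP2Strategy(modules_to_shard=[torch.nn.Linear], shard_on_largest_dim=True) and passing it to prepare_fsdp2, exactly as the new unit test above does.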
