41 | 41 | get_optimizer_state_dict, |
42 | 42 | set_optimizer_state_dict, |
43 | 43 | ) |
44 | | -from torch.distributed.device_mesh import init_device_mesh |
| 44 | +from torch.distributed.device_mesh import DeviceMesh, init_device_mesh |
45 | 45 | from torch.distributed.fsdp.fully_sharded_data_parallel import FullStateDictConfig |
| 46 | +from torch.distributed.tensor import Shard |
46 | 47 | from torch.distributed.tensor.parallel import parallelize_module |
47 | 48 | from torch.distributed.tensor.parallel.style import ParallelStyle |
48 | 49 | from torchtnt.utils.device_mesh import GlobalMeshCoordinator |
@@ -199,6 +200,8 @@ class FSDP2Strategy(Strategy): |
199 | 200 | reshard_after_forward: If True, reshards parameters post-forward pass to save memory. |
200 | 201 | mp_policy: Controls mixed precision policy. If only dtype is provided, it will be used to cast all relevant parts of model. If None, no mixed precision is used |
201 | 202 | cpu_offload: If True, enables CPU offloading of model parameters to reduce GPU memory usage. |
 | 203 | + shard_on_largest_dim: If True, shards each parameter on its largest dimension. By default FSDP shards on the first dimension; if that dimension is small, the parameter ends up effectively replicated on all ranks, which increases
 | 204 | + memory usage as the world size grows.
202 | 205 |
|
203 | 206 | Note: |
204 | 207 | It is recommended to specify specific modules to shard to avoid unnecessary sharding of all submodules, which has |
@@ -240,6 +243,9 @@ class FSDP2Strategy(Strategy): |
240 | 243 | mp_policy: Optional[Union[str, torch.dtype, MixedPrecisionPolicy]] = None |
241 | 244 | cpu_offload: bool = False |
242 | 245 |
|
| 246 | + # experimental flag |
| 247 | + shard_on_largest_dim: bool = False |
| 248 | + |
243 | 249 |
|
244 | 250 | @dataclass |
245 | 251 | class TPStrategy(Strategy): |
@@ -409,6 +415,7 @@ def prepare_fsdp2( |
409 | 415 | strategy = strategy or FSDP2Strategy() |
410 | 416 |
|
411 | 417 | # prepare kwargs for fully_shard api |
| 418 | + mesh: DeviceMesh |
412 | 419 | if global_mesh is None: |
413 | 420 | pg = dist.distributed_c10d._get_default_group() |
414 | 421 | mesh = init_device_mesh(device.type, mesh_shape=(pg.size(),)) |
@@ -438,6 +445,22 @@ def prepare_fsdp2( |
438 | 445 | reduce_dtype=mp_policy, |
439 | 446 | output_dtype=mp_policy, |
440 | 447 | ) |
| 448 | + if strategy.shard_on_largest_dim: |
| 449 | + |
| 450 | + # From the docs: https://docs.pytorch.org/docs/stable/distributed.fsdp.fully_shard.html |
| 451 | + # "If sharding on a nonzero dim, we currently require even sharding, i.e. the tensor dim size on that dim must be divisible by the FSDP shard mesh size." |
| 452 | + |
| 453 | + # So we shard on a candidate nonzero dim only when it's divisible by the fsdp world size |
| 454 | + |
| 455 | + def _shard_placement_fn(param: torch.nn.Parameter) -> Optional[Shard]: |
| 456 | + largest_dim_size = max(param.shape) |
| 457 | + idx = param.shape.index(largest_dim_size) |
| 458 | + if idx != 0 and largest_dim_size % mesh.size() != 0: |
| 459 | + # not divisible, so we return None to shard on default dim 0 |
| 460 | + return None |
| 461 | + return Shard(idx) |
| 462 | + |
| 463 | + fsdp_kwargs["shard_placement_fn"] = _shard_placement_fn |
441 | 464 |
|
442 | 465 | # parse out the modules_to_shard argument |
443 | 466 | modules_to_shard = strategy.modules_to_shard |
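
For illustration (a sketch, not part of the diff): the decision made by the new `_shard_placement_fn` can be reproduced locally by passing the FSDP mesh size in explicitly, so no process group is needed. The `pick_shard_dim` helper and the example shapes below are hypothetical.

```python
from typing import Optional

import torch
from torch.distributed.tensor import Shard


def pick_shard_dim(param: torch.nn.Parameter, fsdp_world_size: int) -> Optional[Shard]:
    # Mirrors the logic added in prepare_fsdp2: prefer the largest dimension,
    # but only shard a nonzero dim when it divides evenly across the FSDP ranks.
    largest_dim_size = max(param.shape)
    idx = param.shape.index(largest_dim_size)
    if idx != 0 and largest_dim_size % fsdp_world_size != 0:
        return None  # fall back to the default Shard(0)
    return Shard(idx)


# A (4, 4096) weight with 8 FSDP ranks: dim 1 is largest and 4096 % 8 == 0,
# so Shard(1) is chosen instead of the default dim-0 split, which cannot
# give all 8 ranks a non-empty shard.
print(pick_shard_dim(torch.nn.Parameter(torch.empty(4, 4096)), 8))

# A (4, 4097) weight: dim 1 is largest but 4097 % 8 != 0, so None is returned
# and FSDP keeps the default dim-0 sharding.
print(pick_shard_dim(torch.nn.Parameter(torch.empty(4, 4097)), 8))
```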
|
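A minimal usage sketch of the new flag, assuming `FSDP2Strategy` and `prepare_fsdp2` are importable from torchtnt's prepare_module utilities (the import path, model, and exact call signature here are assumptions for illustration):

```python
import torch
import torch.distributed as dist

from torchtnt.utils.prepare_module import FSDP2Strategy, prepare_fsdp2  # assumed import path

# Requires an initialized process group and one CUDA device per rank (e.g. via torchrun).
dist.init_process_group("nccl")
device = torch.device("cuda", dist.get_rank() % torch.cuda.device_count())
torch.cuda.set_device(device)

strategy = FSDP2Strategy(shard_on_largest_dim=True)  # new experimental flag from this diff

model = torch.nn.Sequential(torch.nn.Linear(4, 4096), torch.nn.Linear(4096, 4)).to(device)
module = prepare_fsdp2(model, device, strategy)  # argument order is an assumption
```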