
Commit 5633283

tianyu-l authored and pytorchmergebot committed
[reland][DTensor][FSDP2] necessary changes to FSDP and TP to unblock EP (pytorch#158204)
This PR is identical to pytorch#157216, which was reverted because it removed an outdated import of `torch._dynamo` (https://www.internalfb.com/diff/D78021229?transaction_fbid=1713683499308113). The issue has been fixed by @weifengpy in D78199546, so this PR should be good to re-land.

Pull Request resolved: pytorch#158204
Approved by: https://github.com/weifengpy
1 parent 5b10b0a commit 5633283

4 files changed: +10 −93 lines

test/distributed/_composable/test_composability/test_2d_composability.py

Lines changed: 0 additions & 15 deletions
@@ -554,21 +554,6 @@ def _compare_params(self, m1, m2):
                     p2 = p2.redistribute(p2.device_mesh, [Replicate()]).to_local()
                 self.assertTrue(torch.allclose(p1, p2), f"{p1} vs {p2}")
 
-    @with_comms
-    @skip_if_lt_x_gpu(4)
-    def test_raise_invalid_tp_composition(self):
-        with self.assertRaisesRegex(
-            RuntimeError, r"Found TP device_mesh on the \d dimension of its parent mesh"
-        ):
-            mesh_2d = init_device_mesh(
-                self.device_type, (2, self.world_size // 2), mesh_dim_names=("tp", "dp")
-            )
-            parallelize_plan = {
-                "net1": ColwiseParallel(),
-                "net2": RowwiseParallel(),
-            }
-            parallelize_module(SimpleModel().cuda(), mesh_2d["tp"], parallelize_plan)
-
     @with_comms
     @skip_if_lt_x_gpu(4)
     def test_2d_fsdp_state_enable_extension(self):

torch/distributed/fsdp/_fully_shard/_fsdp_param.py

Lines changed: 10 additions & 9 deletions
@@ -292,21 +292,22 @@ def _init_sharded_param(
                 dp_global_mesh is None or tp_global_mesh is None
             ):
                 raise AssertionError(
-                    "FSDP requires the DP and TP mesh to have the same parent mesh but got: \n"
-                    f"DP's global mesh: {dp_global_mesh}\nTP's global mesh: {tp_global_mesh}"
+                    "FSDP requires the DP and model parallel TP/EP mesh to have the same parent mesh but got: \n"
+                    f"DP's global mesh: {dp_global_mesh}\nTP/EP's global mesh: {tp_global_mesh}"
                 )
             name_dims_error = "FSDP requires named DeviceMesh dims for ND parallelism"
             assert dp_mesh.mesh_dim_names is not None, name_dims_error
             assert tp_mesh.mesh_dim_names is not None, name_dims_error
             submesh_names = dp_mesh.mesh_dim_names + tp_mesh.mesh_dim_names
             self._spmd_mesh = dp_global_mesh[submesh_names]
-            if len(self._tp_spec.placements) != 1:
+            if len(self._tp_spec.placements) > 2:
                 raise NotImplementedError(
-                    f"FSDP only supports 1D TP, not {self._tp_spec.placements}"
+                    f"FSDP only supports 1D TP/EP or 2D EP+TP, not {self._tp_spec.placements}"
                 )
             split_factor = self._tp_spec.num_shards_map[shard_dim]
-            assert 2 <= self._spmd_mesh.ndim <= 3, (
-                f"_spmd_mesh.ndim can only be 2 or 3 but got {self._spmd_mesh.ndim}."
+            assert 2 <= self._spmd_mesh.ndim <= 4, (
+                "_spmd_mesh.ndim can only be 2 (FSDP+TP/EP), 3 (FSDP+EP+TP, HSDP+TP/EP), "
+                f"or 4 (HSDP+EP+TP) but got {self._spmd_mesh.ndim}."
             )
             self._spmd_placements: tuple[Placement, ...]
             dp_shard_tp_placement = (
@@ -315,11 +316,11 @@ def _init_sharded_param(
                     if split_factor > 1
                     else fsdp_placement
                 ),
-                self._tp_spec.placements[0],
+                *self._tp_spec.placements,
             )
-            if self._spmd_mesh.ndim == 2:
+            if dp_mesh.ndim == 1:  # FSDP
                 self._spmd_placements = dp_shard_tp_placement
-            else:
+            else:  # HSDP
                 assert self.mesh_info.replicate_mesh_dim == 0
                 self._spmd_placements = (Replicate(),) + dp_shard_tp_placement
             self._sharding_spec = DTensorSpec(
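For context, the relaxed checks above are meant to admit parameters whose DTensor spec carries two model-parallel placements (EP + TP) under an FSDP "dp" mesh that shares the same parent mesh. The following is a minimal, hypothetical sketch of that composition; the 3D mesh shape, the `Experts` module, and the chosen placements are illustrative assumptions rather than code from this PR, and it assumes a `torchrun` launch with 8 CUDA ranks:

import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard
from torch.distributed.tensor import Shard, distribute_tensor

# One parent mesh, so the DP and EP/TP submeshes resolve to the same root mesh
# (the AssertionError above still requires this).
mesh_3d = init_device_mesh("cuda", (2, 2, 2), mesh_dim_names=("dp", "ep", "tp"))


class Experts(nn.Module):
    # Hypothetical expert-weight holder; forward is omitted because the sketch
    # only exercises parameter initialization.
    def __init__(self) -> None:
        super().__init__()
        # Two placements (EP shard on dim 0, TP shard on dim 2): previously
        # rejected by `len(placements) != 1`, now admitted by the `> 2` check.
        weight = distribute_tensor(
            torch.randn(4, 16, 16), mesh_3d["ep", "tp"], [Shard(0), Shard(2)]
        )
        self.weight = nn.Parameter(weight)


experts = Experts()
# fully_shard over "dp": _spmd_mesh becomes mesh_3d["dp", "ep", "tp"] (ndim 3,
# within the new 2 <= ndim <= 4 bound), and the FSDP shard placement is
# prepended to the parameter's (EP, TP) placements via `*self._tp_spec.placements`.
fully_shard(experts, mesh=mesh_3d["dp"])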

torch/distributed/tensor/parallel/_utils.py

Lines changed: 0 additions & 67 deletions
This file was deleted.

torch/distributed/tensor/parallel/api.py

Lines changed: 0 additions & 2 deletions
@@ -6,7 +6,6 @@
 import torch
 import torch.nn as nn
 from torch.distributed.device_mesh import _mesh_resources, DeviceMesh
-from torch.distributed.tensor.parallel._utils import _validate_tp_mesh_dim
 from torch.distributed.tensor.parallel.style import ParallelStyle
 
 
@@ -71,7 +70,6 @@ def parallelize_module(  # type: ignore[return]
     torch._C._log_api_usage_once("torch.distributed.tensor.parallel.parallelize_module")
 
     device_mesh = device_mesh or _mesh_resources.get_current_mesh()
-    _validate_tp_mesh_dim(device_mesh)
 
     if parallelize_plan is None:
         warnings.warn(
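For reference, with `_validate_tp_mesh_dim` removed (and the corresponding test deleted above), a call where the "tp" dim is not the innermost dimension of its parent mesh is no longer rejected up front. This is a hedged sketch only; the mesh shape and the two-layer model are illustrative and assume a `torchrun` launch with 4 CUDA ranks:

import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

# ("tp", "dp") ordering puts TP on mesh dim 0, the layout that the deleted
# test_raise_invalid_tp_composition expected to raise a RuntimeError.
mesh_2d = init_device_mesh("cuda", (2, 2), mesh_dim_names=("tp", "dp"))

model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8)).cuda()
parallelize_module(
    model,
    mesh_2d["tp"],
    {"0": ColwiseParallel(), "1": RowwiseParallel()},
)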
