
Commit 89298ad

fduwjj authored and pytorchmergebot committed
[device_mesh] Implement _unflatten on top of CuTe layout bookkeeping (pytorch#161224)
Pull Request resolved: pytorch#161224
Approved by: https://github.com/lw, https://github.com/fegin
ghstack dependencies: pytorch#164510
1 parent c467e59 commit 89298ad
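
For orientation, here is a minimal sketch of the new private API, based on the `_unflatten` docstring and tests added in this commit. The 8-rank job and the "cuda" device type are assumptions, not part of the commit:

    from torch.distributed.device_mesh import init_device_mesh

    # Start from a 1D "world" mesh over 8 ranks and split dim 0 into dp/cp/tp,
    # as done for the Expert Parallelism (EP) case in the tests below.
    world = init_device_mesh("cuda", (8,), mesh_dim_names=("world",))
    mesh_3d = world._unflatten(0, (2, 2, 2), ("dp", "cp", "tp"))

    # mesh_3d now has dims ("dp", "cp", "tp"), each backed by its own process group.
    tp_group = mesh_3d["tp"].get_group()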

File tree

3 files changed: +267 -0 lines changed


test/distributed/test_device_mesh.py

Lines changed: 87 additions & 0 deletions
@@ -2,6 +2,7 @@
 # Owner(s): ["oncall: distributed"]
 import os
 import unittest
+from datetime import timedelta
 
 import torch
 import torch.distributed as dist
@@ -40,6 +41,13 @@
 device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
 device_count = torch.accelerator.device_count()
 
+try:
+    import torch._C._distributed_c10d.ProcessGroupNCCL
+
+    _NCCL_AVAILABLE = True
+except ImportError:
+    _NCCL_AVAILABLE = False
+
 
 def _set_env_var(addr="localhost", port="25364", world_size=1, rank=0, local_rank=-1):
     os.environ["MASTER_ADDR"] = addr
@@ -962,6 +970,85 @@ def test_flatten_mesh_4d(self):
         # check flattened mesh dependency
         self.assertEqual(dp_cp_mesh._get_root_mesh(), mesh_4d)
 
+    @with_comms
+    def test_unflatten_mesh_2d(self):
+        mesh_shape = (4, 2)
+        mesh_dim_names = ("dp", "tp")
+        mesh_2d = init_device_mesh(
+            self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names
+        )
+        unflatten_mesh = mesh_2d._unflatten(0, (2, 2), ("dp_shard", "dp_replicate"))
+        self.assertEqual(
+            unflatten_mesh.mesh_dim_names, ["dp_shard", "dp_replicate", "tp"]
+        )
+        self.assertEqual(mesh_2d["tp"].mesh, unflatten_mesh["tp"].mesh)
+        self.assertEqual(mesh_2d["tp"].get_group(), unflatten_mesh["tp"].get_group())
+
+        # Not supporting slicing out unflatten dim name from root mesh.
+        with self.assertRaises(KeyError):
+            self.assertEqual(mesh_2d["dp_shard"].mesh, unflatten_mesh["dp_shard"].mesh)
+
+    @with_comms
+    def test_unflatten_mesh_3d(self):
+        # Test unflatten from a dummy world mesh, which is the case we need for Expert Parallelism (EP).
+        global_mesh = init_device_mesh(
+            self.device_type,
+            (8,),
+            mesh_dim_names=("world",),
+        )
+        non_ep_mesh = global_mesh._unflatten(0, (2, 2, 2), ("dp", "cp", "tp"))
+        ep_mesh = global_mesh._unflatten(0, (2, 2, 2), ("dp", "ep", "ep_tp"))
+        self.assertEqual(non_ep_mesh["cp"].mesh, ep_mesh["ep"].mesh)
+        self.assertEqual(non_ep_mesh["tp"].mesh, ep_mesh["ep_tp"].mesh)
+        mesh_3d = global_mesh._unflatten(0, (4, 2, 1), ("dp", "cp", "tp"))
+        unflatten_mesh = mesh_3d._unflatten(0, (2, 2), ("dp_shard", "dp_replicate"))
+        self.assertEqual(
+            unflatten_mesh.mesh_dim_names, ["dp_shard", "dp_replicate", "cp", "tp"]
+        )
+        self.assertEqual(mesh_3d["tp"].mesh, unflatten_mesh["tp"].mesh)
+        self.assertEqual(mesh_3d["tp"].get_group(), unflatten_mesh["tp"].get_group())
+        self.assertEqual(mesh_3d["cp"].mesh, unflatten_mesh["cp"].mesh)
+        self.assertEqual(mesh_3d["cp"].get_group(), unflatten_mesh["cp"].get_group())
+
+        # Test unflatten with backend override set.
+        if not _NCCL_AVAILABLE:
+            return
+        opts = dist.ProcessGroupNCCL.Options()
+        opts._timeout = timedelta(seconds=30)
+        mesh_2d = global_mesh._unflatten(
+            0,
+            (1, 8),
+            ("pp", "spmd"),
+            backend_override={"pp": "fake", "spmd": ("nccl", opts)},
+        )
+        opts = dist.ProcessGroupNCCL.Options()
+        opts._timeout = timedelta(seconds=60)
+        mesh_4d = mesh_2d._unflatten(
+            1,
+            (2, 2, 2),
+            ("dp", "cp", "tp"),
+            backend_override={"dp": "nccl", "cp": "nccl", "tp": ("nccl", opts)},
+        )
+        self.assertEqual(mesh_4d["pp"].get_group()._get_backend_name(), "custom")
+        spmd_pg = mesh_2d["spmd"].get_group()
+        self.assertEqual(spmd_pg._get_backend_name(), "nccl")
+        w = spmd_pg.allreduce(torch.rand(10).cuda(self.rank))
+        self.assertTrue(
+            spmd_pg._get_backend(
+                torch.device(f"cuda:{self.rank}")
+            )._verify_work_timeout(w, timedelta(seconds=30))
+        )
+        w.wait()
+        tp_pg = mesh_4d["tp"].get_group()
+        self.assertEqual(tp_pg._get_backend_name(), "nccl")
+        w = tp_pg.allreduce(torch.rand(10).cuda(self.rank))
+        self.assertTrue(
+            tp_pg._get_backend(torch.device(f"cuda:{self.rank}"))._verify_work_timeout(
+                w, timedelta(seconds=60)
+            )
+        )
+        w.wait()
+
     @with_comms
     def test_reconstruct_mesh_with_flatten_dim(self):
         mesh_3d = init_device_mesh(

torch/distributed/_mesh_layout.py

Lines changed: 47 additions & 0 deletions
@@ -17,6 +17,7 @@
     is_int,
     is_tuple,
     Layout,
+    suffix_product,
 )
 
 
@@ -148,6 +149,52 @@ def complement(self, world_size: int) -> "_MeshLayout":
         layout = complement(self, world_size)
         return _MeshLayout(layout.shape, layout.stride)
 
+    def unflatten(self, dim: int, unflatten_sizes: tuple[int, ...]) -> "_MeshLayout":
+        """
+        Unflatten a single dimension in the layout by splitting it into multiple dimensions.
+        It takes a dimension at position `dim` and splits it into multiple new dimensions
+        with the specified sizes.
+
+        Args:
+            dim (int): The index of the dimension to unflatten. Must be a valid dimension index.
+            unflatten_sizes (tuple[int, ...]): The new sizes for the dimensions that will replace
+                the original dimension at `dim`. The product of these sizes must equal the size
+                of the original dimension at `dim`.
+
+        Returns:
+            _MeshLayout: A new layout with the specified dimension unflattened.
+
+        Example:
+            Original: sizes=(8,), strides=(1,)  # 8 ranks in 1D
+            Call: unflatten(0, (2, 2, 2))  # Create 3D topology
+            Result: sizes=(2, 2, 2), strides=(4, 2, 1)  # 2*2*2 unflattened topology
+        """
+        # Check that dim is within valid range
+        if dim < 0 or dim >= len(self):
+            raise ValueError(
+                f"dim {dim} is out of range for layout with {len(self)} dimensions. "
+                f"Expected dim to be in range [0, {len(self) - 1}]."
+            )
+
+        # Check that the product of unflatten_sizes equals the original dimension size
+        original_size = self[dim].numel()
+        unflatten_product = math.prod(unflatten_sizes)
+        if unflatten_product != original_size:
+            raise ValueError(
+                f"The product of unflatten_sizes {unflatten_sizes} is {unflatten_product}, "
+                f"but the original dimension at dim={dim} has size {original_size}. "
+                f"These must be equal for unflatten to work correctly."
+            )
+
+        sizes = list(self.sizes)  # type: ignore[arg-type]
+        strides = list(self.strides)  # type: ignore[arg-type]
+        unflatten_layout = self[dim].composition(
+            _MeshLayout(tuple(unflatten_sizes), suffix_product(unflatten_sizes))
+        )
+        sizes[dim : dim + 1] = list(unflatten_layout.sizes)  # type: ignore[arg-type]
+        strides[dim : dim + 1] = list(unflatten_layout.strides)  # type: ignore[arg-type]
+        return _MeshLayout(tuple(sizes), tuple(strides))
+
     def all_ranks_from_zero(self) -> list[int]:
         """
         This function computes all the ranks specified by the layout starting from zero.
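
To make the layout bookkeeping concrete, the sketch below redoes the arithmetic for the simple case where the dimension being split is contiguous in the layout: the new strides are the suffix products of the new sizes scaled by the old stride. The helpers `suffix_product` and `unflatten_dim` here are standalone stand-ins written for illustration only; the real `_MeshLayout.unflatten` goes through `composition` and also handles non-contiguous strides.

    import math


    def suffix_product(sizes: tuple[int, ...]) -> tuple[int, ...]:
        # Row-major strides for the given sizes, e.g. (2, 2, 2) -> (4, 2, 1).
        out = [1] * len(sizes)
        for i in range(len(sizes) - 2, -1, -1):
            out[i] = out[i + 1] * sizes[i + 1]
        return tuple(out)


    def unflatten_dim(sizes, strides, dim, new_sizes):
        # Split sizes[dim] into new_sizes; the split strides are the suffix
        # products of new_sizes scaled by the stride of the original dim.
        assert math.prod(new_sizes) == sizes[dim]
        new_strides = tuple(p * strides[dim] for p in suffix_product(new_sizes))
        return (
            sizes[:dim] + new_sizes + sizes[dim + 1 :],
            strides[:dim] + new_strides + strides[dim + 1 :],
        )


    # Matches the docstring example: an 8-rank 1D layout split into a (2, 2, 2) topology.
    assert unflatten_dim((8,), (1,), 0, (2, 2, 2)) == ((2, 2, 2), (4, 2, 1))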

torch/distributed/device_mesh.py

Lines changed: 133 additions & 0 deletions
@@ -353,6 +353,10 @@ def _init_process_groups(
                 -1, self.mesh.size(dim)
             )
             backend, pg_options = backend_override[dim]
+            # We need to explicitly pass in timeout when specified in option, otherwise
+            # the default timeout will be used to override the timeout set in option.
+            # TODO: remove this once we have fixed inside c10d level.
+            timeout = pg_options._timeout if pg_options else None
 
             # If we have a 2D mesh with mesh_dim_names ("dp", "tp"), the group description
             # of the subgroups would be `mesh_dim_dp` and `mesh_name_tp`.
@@ -390,6 +394,7 @@ def _init_process_groups(
             ):
                 dim_group = split_group(
                     parent_pg=default_group,
+                    timeout=timeout,
                     pg_options=pg_options,
                     split_ranks=pg_ranks_by_dim.tolist(),
                     group_desc=group_desc,
@@ -410,6 +415,7 @@ def _init_process_groups(
             if bound_device_id is None or not has_split_group:
                 dim_group = new_group(
                     ranks=subgroup_ranks,
+                    timeout=timeout,
                     backend=backend,
                     pg_options=pg_options,
                     group_desc=group_desc,
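
Taken together, these three changes let a per-dim timeout carried in the backend options actually reach `split_group`/`new_group` instead of being replaced by the default. A hedged sketch of the call pattern, mirroring the NCCL portion of the test above (the 8-rank CUDA job with NCCL available is an assumption):

    from datetime import timedelta

    import torch.distributed as dist
    from torch.distributed.device_mesh import init_device_mesh

    opts = dist.ProcessGroupNCCL.Options()
    opts._timeout = timedelta(seconds=30)  # forwarded via the new `timeout=` plumbing

    world = init_device_mesh("cuda", (8,), mesh_dim_names=("world",))
    mesh_2d = world._unflatten(
        0,
        (1, 8),
        ("pp", "spmd"),
        backend_override={"pp": "fake", "spmd": ("nccl", opts)},
    )
    # The "spmd" group is an NCCL group whose timeout comes from `opts`, not the default.
    spmd_group = mesh_2d["spmd"].get_group()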
@@ -1093,6 +1099,133 @@ def _flatten(
 
         return self._create_flatten_mesh(mesh_dim_name, backend_override_tuple)
 
+    def _create_unflatten_mesh(
+        self,
+        dim: int,
+        mesh_sizes: tuple[int, ...],
+        mesh_dim_names: tuple[str, ...],
+        backend_override: tuple[
+            tuple[Optional[str], Optional[C10dBackend.Options]], ...
+        ] = ((None, None),),
+    ) -> "DeviceMesh":
+        root_mesh = self._get_root_mesh()
+        cur_rank = self.get_rank()
+        unflattened_layout = self._layout.unflatten(dim, mesh_sizes)
+        pg_ranks_by_dim = unflattened_layout.remap_to_tensor(
+            root_mesh.mesh,
+        )
+        unflattened_mesh_dim_names = list(not_none(self.mesh_dim_names))
+        unflattened_mesh_dim_names[dim : dim + 1] = list(mesh_dim_names)
+        res_mesh = DeviceMesh._create_mesh_from_ranks(
+            self.device_type,
+            pg_ranks_by_dim,
+            cur_rank,
+            tuple(unflattened_mesh_dim_names),
+            _init_backend=False,
+            _layout=unflattened_layout,
+            _root_mesh=root_mesh,
+        )
+
+        # If the original mesh has initialized its backend, we need to initialize the
+        # backend of the unflattened dims as well.
+        # TODO: Make backend init more efficient with the CuTe layout representation and
+        # support per-dim backend init.
+        if hasattr(self, "_dim_group_names"):
+            unflatten_length = len(mesh_sizes)
+            unflatten_layout = _MeshLayout(
+                tuple(unflattened_layout.sizes[dim : dim + unflatten_length]),  # type: ignore[index]
+                tuple(unflattened_layout.strides[dim : dim + unflatten_length]),  # type: ignore[index]
+            )
+            unflatten_pg_ranks_by_dim = unflatten_layout.remap_to_tensor(
+                root_mesh.mesh,
+            )
+            unflatten_submesh = DeviceMesh._create_mesh_from_ranks(
+                self.device_type,
+                unflatten_pg_ranks_by_dim,
+                cur_rank,
+                mesh_dim_names,
+                backend_override=backend_override,
+            )
+            dim_group_names = []
+            for idx in range(0, res_mesh.ndim):
+                if idx < dim:
+                    dim_group_names.append(self._dim_group_names[idx])
+                elif idx >= dim + unflatten_length:
+                    dim_group_names.append(
+                        self._dim_group_names[idx - unflatten_length + 1]
+                    )
+                else:
+                    dim_group_names.append(
+                        unflatten_submesh._dim_group_names[idx - dim]
+                    )
+            res_mesh._dim_group_names = dim_group_names
+
+        return res_mesh
+
+    def _unflatten(
+        self,
+        dim: Union[int, str],
+        mesh_sizes: tuple[int, ...],
+        mesh_dim_names: tuple[str, ...],
+        backend_override: Optional[
+            dict[
+                str,
+                Union[str, C10dBackend.Options, tuple[str, C10dBackend.Options]],
+            ]
+        ] = None,
+    ) -> "DeviceMesh":
+        """
+        Returns a DeviceMesh obtained by unflattening one dimension of the current DeviceMesh.
+
+        This API can be used to unflatten an N-D DeviceMesh into an (N - 1 + len(mesh_sizes))-D mesh.
+        The dim is the dimension to be unflattened, which can be either a string or an integer.
+
+        mesh_sizes is a tuple which specifies the shape that the given dim is unflattened into.
+        mesh_dim_names is a tuple of strings which specifies the names of the dimensions the given
+        dim is unflattened into. Its length must match the length of mesh_sizes.
+
+        For example, if we have a 1D mesh DeviceMesh([0, 1, 2, 3, 4, 5, 6, 7], mesh_dim_names=("world",)),
+        calling mesh_1d._unflatten(0, (2, 2, 2), ("dp", "cp", "tp")) will create a 3D mesh
+        DeviceMesh([[[0, 1], [2, 3]], [[4, 5], [6, 7]]], mesh_dim_names=("dp", "cp", "tp")).
+
+        Note that after calling the unflatten, there is no access to the unflattened dimensions in mesh_1d;
+        one can only use the newly unflattened mesh to slice out the unflattened mesh dims.
+        """
+        if isinstance(dim, int) and dim >= self.ndim:
+            raise ValueError(
+                f"dim {dim} specified in `_unflatten` is out of range {self.ndim}"
+            )
+        elif isinstance(dim, str) and dim not in not_none(self.mesh_dim_names):
+            raise ValueError(
+                f"dim {dim} specified in `_unflatten` is not in {self.mesh_dim_names}"
+            )
+
+        if len(mesh_sizes) != len(mesh_dim_names):
+            raise RuntimeError(
+                "mesh_dim_names must have same length as mesh_sizes in _unflatten!"
+            )
+
+        if isinstance(dim, str):
+            dim = not_none(self.mesh_dim_names).index(dim)
+
+        if backend_override is not None:
+            backend_override_tuple = tuple(
+                _normalize_backend_override(
+                    backend_override,  # type: ignore[arg-type]
+                    len(mesh_sizes),
+                    mesh_dim_names,
+                )
+            )
+        else:
+            backend_override_tuple = ((None, None),) * len(mesh_dim_names)
+
+        return self._create_unflatten_mesh(
+            dim,
+            mesh_sizes,
+            mesh_dim_names,
+            backend_override_tuple,
+        )
+
 def _normalize_backend_override(
     backend_override: dict[
         Union[int, str],
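
As a usage note for the docstring above, here is a short sketch of the two entry points it describes, under an assumed 4x2 mesh on 8 CUDA ranks (the string-dim form relies on the name lookup shown in the code above):

    from torch.distributed.device_mesh import init_device_mesh

    mesh_2d = init_device_mesh("cuda", (4, 2), mesh_dim_names=("dp", "tp"))

    # Unflatten the "dp" dim (by name or by index) into ("dp_shard", "dp_replicate").
    # The existing "tp" group is reused; new groups are created only for the new dims.
    mesh_3d = mesh_2d._unflatten("dp", (2, 2), ("dp_shard", "dp_replicate"))

    # mesh_sizes and mesh_dim_names must have the same length, otherwise a
    # RuntimeError is raised before any layout work happens.
    try:
        mesh_2d._unflatten("dp", (2, 2), ("dp_shard",))
    except RuntimeError:
        pass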
