@@ -353,6 +353,10 @@ def _init_process_groups(
353353 - 1 , self .mesh .size (dim )
354354 )
355355 backend , pg_options = backend_override [dim ]
356+ # We need to explicitly pass in timeout when specified in option, otherwise
357+ # the default timeout will be used to override the timeout set in option.
358+ # TODO: remove this once we have fixed inside c10d level.
359+ timeout = pg_options ._timeout if pg_options else None
356360
357361 # If we have a 2D mesh with mesh_dim_names ("dp", "tp"), the group description
358362 # of the subgroups would be `mesh_dim_dp` and `mesh_name_tp`.
@@ -390,6 +394,7 @@ def _init_process_groups(
390394 ):
391395 dim_group = split_group (
392396 parent_pg = default_group ,
397+ timeout = timeout ,
393398 pg_options = pg_options ,
394399 split_ranks = pg_ranks_by_dim .tolist (),
395400 group_desc = group_desc ,
@@ -410,6 +415,7 @@ def _init_process_groups(
410415 if bound_device_id is None or not has_split_group :
411416 dim_group = new_group (
412417 ranks = subgroup_ranks ,
418+ timeout = timeout ,
413419 backend = backend ,
414420 pg_options = pg_options ,
415421 group_desc = group_desc ,
@@ -1093,6 +1099,133 @@ def _flatten(
10931099
10941100 return self ._create_flatten_mesh (mesh_dim_name , backend_override_tuple )
10951101
def _create_unflatten_mesh(
    self,
    dim: int,
    mesh_sizes: tuple[int, ...],
    mesh_dim_names: tuple[str, ...],
    backend_override: tuple[
        tuple[Optional[str], Optional[C10dBackend.Options]], ...
    ] = ((None, None),),
) -> "DeviceMesh":
    """
    Build the DeviceMesh obtained by splitting mesh dimension ``dim`` of this mesh.

    Args:
        dim: index of the mesh dimension to split.
        mesh_sizes: sizes of the dimensions that replace ``dim``.
        mesh_dim_names: names of the dimensions that replace ``dim``.
        backend_override: per-new-dimension ``(backend, options)`` pairs, forwarded
            to the helper mesh built for the new dimensions.

    Returns:
        A new DeviceMesh that shares this mesh's root mesh and uses the
        unflattened layout. If this mesh has a ``_dim_group_names`` attribute,
        the result gets one as well: existing names are reused for the untouched
        dimensions and the names of a freshly created helper mesh are used for
        the new dimensions.
    """
    root = self._get_root_mesh()
    rank = self.get_rank()
    new_layout = self._layout.unflatten(dim, mesh_sizes)
    ranks_by_dim = new_layout.remap_to_tensor(
        root.mesh,
    )

    # Swap the single name at `dim` for the new names (slice assignment keeps
    # the exact semantics of replacing one entry by several).
    all_dim_names = list(not_none(self.mesh_dim_names))
    all_dim_names[dim : dim + 1] = list(mesh_dim_names)

    new_mesh = DeviceMesh._create_mesh_from_ranks(
        self.device_type,
        ranks_by_dim,
        rank,
        tuple(all_dim_names),
        _init_backend=False,
        _layout=new_layout,
        _root_mesh=root,
    )

    # Nothing more to do when this mesh carries no dim group names.
    if not hasattr(self, "_dim_group_names"):
        return new_mesh

    # This mesh already has dim group names, so the new dims need theirs too.
    # TODO: To make backend init more efficient with cute layout representation and support
    # per dim backend init.
    num_new_dims = len(mesh_sizes)
    new_dims = slice(dim, dim + num_new_dims)
    inner_layout = _MeshLayout(
        tuple(new_layout.sizes[new_dims]),  # type: ignore[index]
        tuple(new_layout.strides[new_dims]),  # type: ignore[index]
    )
    inner_ranks_by_dim = inner_layout.remap_to_tensor(
        root.mesh,
    )
    # Helper mesh over just the new dims; only its `_dim_group_names` are used.
    inner_mesh = DeviceMesh._create_mesh_from_ranks(
        self.device_type,
        inner_ranks_by_dim,
        rank,
        mesh_dim_names,
        backend_override=backend_override,
    )

    def _group_name_for(pos: int) -> str:
        # Dims before the split keep their original group name.
        if pos < dim:
            return self._dim_group_names[pos]
        # Dims after the split map back to the original index, which is
        # shifted by the number of extra dims the split introduced.
        if pos >= dim + num_new_dims:
            return self._dim_group_names[pos - num_new_dims + 1]
        # Dims created by the split take the helper mesh's group name.
        return inner_mesh._dim_group_names[pos - dim]

    new_mesh._dim_group_names = [
        _group_name_for(pos) for pos in range(0, new_mesh.ndim)
    ]

    return new_mesh
1164+
1165+ def _unflatten (
1166+ self ,
1167+ dim : Union [int , str ],
1168+ mesh_sizes : tuple [int , ...],
1169+ mesh_dim_names : tuple [str , ...],
1170+ backend_override : Optional [
1171+ dict [
1172+ str ,
1173+ Union [str , C10dBackend .Options , tuple [str , C10dBackend .Options ]],
1174+ ]
1175+ ] = None ,
1176+ ) -> "DeviceMesh" :
1177+ """
1178+ Returns a DeviceMesh by unflatten the current DeviceMesh.
1179+
1180+ This api can be used to unflatten a N-D DeviceMesh into N-1+len(mesh_sizes)-D meshes or submeshes.
1181+ The dim is the dimension to be unflattened which can be either a string or an integer.
1182+
1183+ The mesh_sizes is a tuple which specifies the shape of the mesh unflatten into for the given dim.
1184+ The mesh_dim_names is a list of strings which specifies the names of the dimensions of the mesh unflatten into.
1185+ Its length must match the length of mesh_sizes.
1186+
1187+ For example, if we have a 1D mesh DeviceMesh([0, 1, 2, 3, 4, 5, 6, 7], mesh_dim_names=("world")),
1188+ calling mesh_1d._unflatten(0, (2, 2, 4), ["dp", "pp", "tp"]) will create a 3D mesh
1189+ DeviceMesh([[[0, 1], [2, 3]], [[4, 5], [6, 7]]], mesh_dim_names=("dp", "cp", "tp")).
1190+
1191+ Note that after calling the unflatten, there is no access to the unflattened dimension in mesh_1d, one can only
1192+ use the newly unflattened mesh to slice out the unflattened mesh dims.
1193+ """
1194+ if isinstance (dim , int ) and dim >= self .ndim :
1195+ raise ValueError (
1196+ f"dim { dim } specified in `_unflatten` is out of range { self .ndim } "
1197+ )
1198+ elif isinstance (dim , str ) and dim in not_none (self .mesh_dim_names ):
1199+ raise ValueError (
1200+ f"dim { dim } specified in `_unflatten` is not in { self .mesh_dim_names } "
1201+ )
1202+
1203+ if len (mesh_sizes ) != len (mesh_dim_names ):
1204+ raise RuntimeError (
1205+ "mesh_dim_names must have same length as mesh_sizes in _unflatten!"
1206+ )
1207+
1208+ if isinstance (dim , str ):
1209+ dim = not_none (self .mesh_dim_names ).index (dim )
1210+
1211+ if backend_override is not None :
1212+ backend_override_tuple = tuple (
1213+ _normalize_backend_override (
1214+ backend_override , # type: ignore[arg-type]
1215+ len (mesh_sizes ),
1216+ mesh_dim_names ,
1217+ )
1218+ )
1219+ else :
1220+ backend_override_tuple = ((None , None ),) * len (mesh_dim_names )
1221+
1222+ return self ._create_unflatten_mesh (
1223+ dim ,
1224+ mesh_sizes ,
1225+ mesh_dim_names ,
1226+ backend_override_tuple ,
1227+ )
1228+
10961229 def _normalize_backend_override (
10971230 backend_override : dict [
10981231 Union [int , str ],
0 commit comments