|
34 | 34 | from torch.cuda import _lazy_call, _lazy_init |
35 | 35 | from torch.cuda import device as device_ctx_manager |
36 | 36 | from torch.distributed import DeviceMesh, ProcessGroup |
37 | | -from torch.distributed.device_mesh import _mesh_resources |
38 | 37 |
|
39 | 38 | logger = logging.getLogger(__name__) |
40 | 39 |
|
@@ -150,30 +149,50 @@ def is_float8tensor(tensor: torch.Tensor) -> bool: |
150 | 149 | return HAVE_TE_FP8_TENSOR_CLASS and isinstance(tensor, FP8_TENSOR_CLASS) |
151 | 150 |
|
152 | 151 |
|
153 | | -def get_mesh_names(device_mesh: Optional[DeviceMesh] = None) -> list[str]: |
| 152 | +def get_mesh_names( |
| 153 | + device_mesh: Optional[DeviceMesh] = None, only_submesh_dims: bool = False |
| 154 | +) -> list[str]: |
154 | 155 | """ |
155 | | - Get all the sub-mesh names in the DeviceMesh. |
|  | 156 | +    Get all the sub-mesh ("dp", "cp", etc.) and flattened-mesh ("dp_cp", etc.) names |
|  | 157 | +    in the DeviceMesh. When only_submesh_dims=True, return only the sub-mesh dimension names. |
156 | 158 | """ |
157 | 159 | if device_mesh is None: |
158 | 160 | # Device mesh does not exist. |
159 | 161 | return [] |
160 | | - # Order of the returned list of mesh dimension names must match the order / index |
161 | | - # of the root mesh dimension names followed by children / flattened sub-meshes: |
162 | | - # [<root mesh dimension names>, <child mesh dimension names>] |
163 | | - mesh_dim_names = ( |
| 162 | + |
| 163 | + # Sub-mesh dimension names. |
| 164 | + submesh_dim_names = ( |
164 | 165 | list(device_mesh.mesh_dim_names) if device_mesh.mesh_dim_names is not None else [] |
165 | 166 | ) |
166 | | - submesh_dim_names = [ |
167 | | - submesh_dim_name |
168 | | - for child_mesh, root_mesh in _mesh_resources.child_to_root_mapping.items() |
169 | | - for submesh_dim_name in (child_mesh.mesh_dim_names or []) |
170 | | - if root_mesh == device_mesh |
171 | | - ] |
172 | | - # Combine without duplicate dimensions. |
173 | | - for dim_name in submesh_dim_names: |
174 | | - if dim_name not in mesh_dim_names: |
175 | | - mesh_dim_names.append(dim_name) |
176 | | - return mesh_dim_names |
| 167 | + |
| 168 | + # Flattened mesh dimension names. |
| 169 | + try: |
|  | 170 | +        # Retrieve all flattened meshes associated with the DeviceMesh. |
|  | 171 | +        # The flattened DeviceMeshes are all registered in the _flatten_mapping |
|  | 172 | +        # dictionary of the root DeviceMesh. |
|  | 173 | +        flatten_mesh_names = [ |
|  | 174 | +            flat_dim |
|  | 175 | +            for flat_dim in device_mesh._get_root_mesh()._flatten_mapping |
|  | 176 | +        ] |
| 177 | + except AttributeError: |
| 178 | + # Fallback to the DeviceMesh global state to retrieve flattened |
| 179 | + # meshes associated with the DeviceMesh. |
| 180 | + from torch.distributed.device_mesh import _mesh_resources |
| 181 | + |
| 182 | + flatten_mesh_names = [ |
| 183 | + child_mesh_dim_name |
| 184 | + for child_mesh, root_mesh in _mesh_resources.child_to_root_mapping.items() |
| 185 | + for child_mesh_dim_name in (child_mesh.mesh_dim_names or []) |
| 186 | + if root_mesh == device_mesh and child_mesh_dim_name not in submesh_dim_names |
| 187 | + ] |
| 188 | + |
| 189 | + # Order of the returned list of mesh dimension names must match the index |
| 190 | + # of the root mesh dimension names followed by flattened sub-meshes: |
| 191 | + # [<root mesh dimension names>, <flattened mesh dimension names>] |
| 192 | + if only_submesh_dims: |
| 193 | + return submesh_dim_names |
| 194 | + else: |
| 195 | + return submesh_dim_names + flatten_mesh_names |
177 | 196 |
|
178 | 197 |
|
179 | 198 | def contains_submesh( |
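
For context, here is a minimal usage sketch of the updated helper. It is illustrative rather than part of this change: the dimension names and sizes are hypothetical, it must run under a distributed launcher (e.g. torchrun with 8 ranks), and it relies on the private DeviceMesh._flatten() API that populates the _flatten_mapping read above.

from torch.distributed.device_mesh import init_device_mesh

# get_mesh_names is the helper modified above (import path omitted).

# Hypothetical 2 x 2 x 2 mesh over 8 ranks.
mesh = init_device_mesh("cuda", (2, 2, 2), mesh_dim_names=("dp", "cp", "tp"))

# Flatten "dp" and "cp" into a combined "dp_cp" mesh (private PyTorch API).
mesh["dp", "cp"]._flatten("dp_cp")

print(get_mesh_names(mesh))                          # ["dp", "cp", "tp", "dp_cp"]
print(get_mesh_names(mesh, only_submesh_dims=True))  # ["dp", "cp", "tp"]
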
@@ -720,16 +739,14 @@ def __init__( |
720 | 739 | self.hybrid_fsdp_group = hybrid_fsdp_group |
721 | 740 |
|
722 | 741 | """ |
723 | | - Store a persistent reference to the core device meshes that back Megatron-FSDP. |
724 | | - This is necessary because _MeshEnv (_mesh_resources) may not persist: |
725 | | - - _mesh_resources.child_to_root_mapping |
726 | | - - _mesh_resources.root_to_flatten_mapping |
727 | | - - _mesh_resources.flatten_name_to_root_dims |
728 | | - - ... |
729 | | - during Torch Autograd, so child and flattened sub-meshes may be cleared. |
730 | | - For example, this breaks Megatron-FSDP when self.dp_shard_dim is the flattened |
731 | | - sub-mesh of the DP and CP root mesh dimensions. |
732 | | - FIXME(@cspades): Identify the root cause of this behavior. |
|  | 742 | +        Megatron-FSDP is responsible for storing all of the DeviceMesh objects it |
|  | 743 | +        requires, per the best practices recommended by the DeviceMesh API. |
| 744 | +
|
|  | 745 | +        NOTE(@cspades): In PyTorch 2.11, retrieving flattened mesh dimensions |
|  | 746 | +        will no longer be possible via the device_mesh[...] API. We will require |
|  | 747 | +        all users to correctly _unflatten() their DeviceMesh such that every |
|  | 748 | +        dimension used by Megatron-FSDP is a sub-mesh of the DeviceMesh. At that |
|  | 749 | +        point, contains_submesh(...) will rely on get_mesh_names(only_submesh_dims=True). |
733 | 750 | """ |
734 | 751 | self.mesh_library = {} |
735 | 752 | # TP Mesh |
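
Below is a minimal sketch of the "store every required mesh up front" pattern that the new docstring recommends. It is not Megatron-FSDP's actual initialization code; device_mesh is assumed to be constructed as in the earlier sketch.

# Persist a reference to each sub-mesh once, instead of re-slicing the root
# DeviceMesh (or its flattened dimensions) later via device_mesh[...].
mesh_library = {}
for dim_name in get_mesh_names(device_mesh, only_submesh_dims=True):
    mesh_library[dim_name] = device_mesh[dim_name]
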
@@ -825,6 +842,9 @@ def get_outer_fsdp_group(self) -> ProcessGroup: |
825 | 842 |
|
826 | 843 | def get_root_mesh(self, is_expert_parallel: bool = False) -> DeviceMesh: |
827 | 844 | """Get the device mesh.""" |
| 845 | + # NOTE(@cspades): This is FSDPDistributedIndex's root mesh, NOT the actual |
| 846 | + # root mesh that the DeviceMesh or expert DeviceMesh was un-flattened from. |
|  | 847 | +        # To get that root mesh, use DeviceMesh._get_root_mesh(). |
828 | 848 | if is_expert_parallel: |
829 | 849 | raise NotImplementedError("Expert parallel is not supported in Megatron-FSDP.") |
830 | 850 | return self.device_mesh |
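
To illustrate the NOTE above: the two "root meshes" can differ. In this sketch, dist_index is a hypothetical, already-constructed FSDPDistributedIndex; the snippet is not part of this change.

# The DeviceMesh that Megatron-FSDP was configured with.
configured_mesh = dist_index.get_root_mesh()

# The DeviceMesh that PyTorch tracks as the true root, i.e. the mesh that
# configured_mesh was sliced or un-flattened from (private PyTorch API).
true_root_mesh = configured_mesh._get_root_mesh()
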
|