
Commit 1014220

Cherry-pick bug-fixes into 0.15.X.
Signed-off-by: Cory Ye <cye@nvidia.com>
1 parent 1221b91 commit 1014220

File tree

7 files changed: +476 additions, -284 deletions


megatron/core/distributed/fsdp/src/README.md

Lines changed: 99 additions & 32 deletions
Large diffs are not rendered by default.

megatron/core/distributed/fsdp/src/megatron_fsdp/fully_shard.py

Lines changed: 220 additions & 166 deletions
Large diffs are not rendered by default.

megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py

Lines changed: 6 additions & 0 deletions
@@ -283,8 +283,14 @@ def __init__(
         self._register_fsdp_hooks(self.module)
         self.microbatch_count = 0

+        # Add a reference from the distributed parameters to self for API
+        # accessibility, e.g. when attaching MegatronFSDP scheduled ops
+        # to the distributed optimizer.step() and optimizer.zero_grad().
         self.is_param_fsdp_distributed = False
         self._replace_param_with_distributed_if_needed()
+        for param in self.module.parameters():
+            # Attach MegatronFSDP reference to the parameter.
+            setattr(param, "_megatron_fsdp_model", self)

     def _check_module_parameter_types(self):
         """

megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py

Lines changed: 3 additions & 18 deletions
@@ -31,7 +31,6 @@
 import torch
 from torch.distributed import _coalescing_manager
 from torch.distributed.tensor import DTensor, Replicate, Shard
-from torch.distributed.tensor.device_mesh import _mesh_resources

 from .uneven_dtensor import update_uneven_dtensor_chunk_metadata, validate_uneven_dtensor
 from .utils import _MODEL_PARALLEL_RNG_TRACKER_NAME, FSDPDistributedIndex, get_global_memory_buffer
@@ -94,7 +93,7 @@ def _p_assert(cond: Any, s: str, raise_assertion_error: bool = True) -> None:
     message ``s`` since otherwise, it is swallowed.
     """
     if not cond:
-        print(s)
+        logger.warning(s)
         traceback.print_stack()
         if raise_assertion_error:
             raise AssertionError(s)
@@ -205,7 +204,7 @@ def __exit__(self, *args):
         for group in self.groups[1:]:
             backend = group._get_backend(torch.device("cuda", torch.cuda.current_device()))
             if torch.distributed.get_rank() == 0:
-                print(
+                logger.info(
                     f"[MultiGroupUBRAllocator] Registering mem pool to group {group}, "
                     f"group.group_desc:{group.group_desc}"
                 )
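Because the two hunks above route messages through the module-level logger instead of print, they are only visible if the application configures logging. A minimal sketch of such a setup (an assumption, not part of this commit):

import logging

# INFO level so that the rank-0 [MultiGroupUBRAllocator] registration messages
# are emitted; warnings from _p_assert are shown at any level up to WARNING.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)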
@@ -3525,20 +3524,6 @@ def _get_fsdp_tensor_spec(param, dist_index: FSDPDistributedIndex, is_sharded_pa
     if isinstance(param, DTensor) and cast(DTensor, param)._spec.num_shards > 1:
         # Retrieve original DTensorSpec (for TP).
         dtensor_spec = cast(DTensor, param)._spec
-        dtensor_mesh = getattr(dtensor_spec, "mesh", None)
-
-        # Validate that the DTensor root mesh is identical to the Megatron-FSDP device mesh.
-        megatron_fsdp_global_mesh = dist_index.get_root_mesh()
-        dtensor_global_mesh = _mesh_resources.get_root_mesh(dtensor_mesh)
-        # FIXME(boxiangw): add or megatron_fsdp_global_mesh != dtensor_global_mesh:
-        # _mesh_resources.get_root_mesh(dtensor_mesh) is not getting the correct root mesh
-        if dtensor_global_mesh is None:
-            raise ValueError(
-                f"When utilizing DTensor-based modules with Megatron-FSDP, the DTensor root "
-                f"device mesh must be identical to the Megatron-FSDP root device mesh.\n"
-                f"DTensor Root Mesh: {dtensor_global_mesh} / Megatron-FSDP "
-                f"Root Mesh: {megatron_fsdp_global_mesh}"
-            )

         # Get the placements for the parameter.
         assert len(dtensor_spec.placements) == 1, (
@@ -3724,7 +3709,7 @@ def make_fsdp_dtensor(
             device_mesh=tp_mesh,
             placements=[Shard(tp_dim)],
             run_check=run_check,
-            shape=global_shape,
+            shape=tuple(global_shape),
             stride=torch.empty(global_shape).stride(),
         )

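The last hunk normalizes global_shape to a plain tuple before constructing the DTensor. A sketch of the same call pattern, assuming torch.distributed is initialized, one CUDA device per rank, and a 1-D tensor-parallel mesh; names such as tp_mesh, local_shard, and the shapes are illustrative, not taken from this commit:

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Shard

# Assumes torch.distributed.init_process_group() has already been called.
tp_mesh = init_device_mesh("cuda", (torch.distributed.get_world_size(),), mesh_dim_names=("tp",))
local_shard = torch.randn(128, 4096, device="cuda")
global_shape = torch.Size([128 * tp_mesh.size(), 4096])
param = DTensor.from_local(
    local_shard,
    device_mesh=tp_mesh,
    placements=[Shard(0)],
    run_check=False,
    # Passing a plain tuple (rather than a torch.Size or list) mirrors the fix above.
    shape=tuple(global_shape),
    stride=torch.empty(global_shape).stride(),
)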

megatron/core/distributed/fsdp/src/megatron_fsdp/uneven_dtensor.py

Lines changed: 21 additions & 1 deletion
@@ -25,6 +25,8 @@
 from torch.distributed.checkpoint.planner import TensorWriteData, WriteItem, WriteItemType
 from torch.distributed.tensor.placement_types import Replicate, Shard, _StridedShard

+from .utils import get_mesh_names
+

 def gather_and_compute_chunk_metadata(dtensor: DTensor) -> ChunkStorageMetadata:
     """
@@ -272,7 +274,25 @@ def gather_uneven_dtensor_to_full_tensor(
     if not device_mesh.mesh_dim_names:
         process_group = device_mesh.get_group()
     else:
-        process_group = device_mesh._flatten().get_group()
+        # Check if the fully-flattened mesh exists first.
+        full_flattened_mesh_dim_name = "_".join(device_mesh.mesh_dim_names)
+        if full_flattened_mesh_dim_name in get_mesh_names(device_mesh):
+            # Retrieve the existing flattened DeviceMesh ProcessGroup.
+            try:
+                # Two Cases: Name is a root dimension, or using the old DeviceMesh
+                # API which allows us to get flattened dimensions.
+                process_group = device_mesh[full_flattened_mesh_dim_name].get_group()
+            except:
+                # Name is a flattened dimension that cannot be retrieved from the
+                # DeviceMesh.__getitem__, so fall-back to new DeviceMesh API.
+                process_group = (
+                    device_mesh._get_root_mesh()
+                    ._flatten_mapping[full_flattened_mesh_dim_name]
+                    .get_group()
+                )
+        else:
+            # Create the _-separated flattened DeviceMesh ProcessGroup.
+            process_group = device_mesh._flatten().get_group()

     # Collect chunk metadata for uneven shards (update if missing)
     if not hasattr(dtensor._local_tensor, "__create_chunk_list__"):
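A small sketch of the lookup order the new branch implements, assuming torch.distributed is initialized with 4 ranks arranged as a 2x2 ("dp", "cp") mesh; note that _flatten, _get_root_mesh, and _flatten_mapping are private DeviceMesh APIs whose availability depends on the PyTorch version:

import torch
from torch.distributed.device_mesh import init_device_mesh

# Assumes torch.distributed.init_process_group() has been called with 4 ranks.
mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "cp"))
flat_name = "_".join(mesh.mesh_dim_names)  # "dp_cp"

# Creating the flattened mesh once caches it; the gather above prefers the
# cached ProcessGroup and only calls _flatten() when no such mesh exists yet.
dp_cp_group = mesh._flatten().get_group()

# On newer PyTorch, the cached flattened mesh is also discoverable on the root mesh.
if hasattr(mesh, "_get_root_mesh"):
    assert flat_name in mesh._get_root_mesh()._flatten_mapping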

megatron/core/distributed/fsdp/src/megatron_fsdp/utils.py

Lines changed: 48 additions & 28 deletions
@@ -34,7 +34,6 @@
 from torch.cuda import _lazy_call, _lazy_init
 from torch.cuda import device as device_ctx_manager
 from torch.distributed import DeviceMesh, ProcessGroup
-from torch.distributed.device_mesh import _mesh_resources

 logger = logging.getLogger(__name__)

@@ -150,30 +149,50 @@ def is_float8tensor(tensor: torch.Tensor) -> bool:
     return HAVE_TE_FP8_TENSOR_CLASS and isinstance(tensor, FP8_TENSOR_CLASS)


-def get_mesh_names(device_mesh: Optional[DeviceMesh] = None) -> list[str]:
+def get_mesh_names(
+    device_mesh: Optional[DeviceMesh] = None, only_submesh_dims: bool = False
+) -> list[str]:
     """
-    Get all the sub-mesh names in the DeviceMesh.
+    Get all the sub-mesh ("dp", "cp", etc.) and flattened-mesh ("dp_cp", etc.) names
+    in the DeviceMesh. When only_submesh_dims=True, only checks for sub-mesh dimensions.
     """
     if device_mesh is None:
         # Device mesh does not exist.
         return []
-    # Order of the returned list of mesh dimension names must match the order / index
-    # of the root mesh dimension names followed by children / flattened sub-meshes:
-    # [<root mesh dimension names>, <child mesh dimension names>]
-    mesh_dim_names = (
+
+    # Sub-mesh dimension names.
+    submesh_dim_names = (
         list(device_mesh.mesh_dim_names) if device_mesh.mesh_dim_names is not None else []
     )
-    submesh_dim_names = [
-        submesh_dim_name
-        for child_mesh, root_mesh in _mesh_resources.child_to_root_mapping.items()
-        for submesh_dim_name in (child_mesh.mesh_dim_names or [])
-        if root_mesh == device_mesh
-    ]
-    # Combine without duplicate dimensions.
-    for dim_name in submesh_dim_names:
-        if dim_name not in mesh_dim_names:
-            mesh_dim_names.append(dim_name)
-    return mesh_dim_names
+
+    # Flattened mesh dimension names.
+    try:
+        # Retrieve all flattened meshes associated with DeviceMesh.
+        # The flattened DeviceMesh are all located in the _flatten_mapping
+        # dictionary of the root DeviceMesh.
+        flatten_mesh_names = [
+            flat_dim
+            for flat_dim, flat_mesh in device_mesh._get_root_mesh()._flatten_mapping.items()
+        ]
+    except AttributeError:
+        # Fallback to the DeviceMesh global state to retrieve flattened
+        # meshes associated with the DeviceMesh.
+        from torch.distributed.device_mesh import _mesh_resources
+
+        flatten_mesh_names = [
+            child_mesh_dim_name
+            for child_mesh, root_mesh in _mesh_resources.child_to_root_mapping.items()
+            for child_mesh_dim_name in (child_mesh.mesh_dim_names or [])
+            if root_mesh == device_mesh and child_mesh_dim_name not in submesh_dim_names
+        ]
+
+    # Order of the returned list of mesh dimension names must match the index
+    # of the root mesh dimension names followed by flattened sub-meshes:
+    # [<root mesh dimension names>, <flattened mesh dimension names>]
+    if only_submesh_dims:
+        return submesh_dim_names
+    else:
+        return submesh_dim_names + flatten_mesh_names


 def contains_submesh(
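Hypothetical usage of the updated helper, assuming 4 ranks on a 2x2 mesh and that the module is importable as megatron_fsdp.utils (the import path may differ depending on how the package is installed):

from torch.distributed.device_mesh import init_device_mesh

from megatron_fsdp.utils import get_mesh_names

mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))
print(get_mesh_names(mesh))                          # ["dp", "tp"] (no flattened meshes yet)
mesh._flatten()                                      # creates and caches the "dp_tp" flattened mesh
print(get_mesh_names(mesh))                          # ["dp", "tp", "dp_tp"]
print(get_mesh_names(mesh, only_submesh_dims=True))  # ["dp", "tp"] only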
@@ -720,16 +739,14 @@ def __init__(
         self.hybrid_fsdp_group = hybrid_fsdp_group

         """
-        Store a persistent reference to the core device meshes that back Megatron-FSDP.
-        This is necessary because _MeshEnv (_mesh_resources) may not persist:
-        - _mesh_resources.child_to_root_mapping
-        - _mesh_resources.root_to_flatten_mapping
-        - _mesh_resources.flatten_name_to_root_dims
-        - ...
-        during Torch Autograd, so child and flattened sub-meshes may be cleared.
-        For example, this breaks Megatron-FSDP when self.dp_shard_dim is the flattened
-        sub-mesh of the DP and CP root mesh dimensions.
-        FIXME(@cspades): Identify the root cause of this behavior.
+        Megatron-FSDP is responsible for storing all required DeviceMesh
+        as per best practices recommended by the DeviceMesh API.
+
+        NOTE(@cspades): In PyTorch 2.11, retrieving flattened mesh dimensions
+        will be impossible via the device_mesh[...] API. We will require all
+        users to correctly _unflatten() their DeviceMesh such that all
+        dimensions used by Megatron-FSDP are sub-meshes of the DeviceMesh.
+        contains_submesh(...) -> get_mesh_names(only_submesh_dims=True).
         """
         self.mesh_library = {}
         # TP Mesh
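An illustrative sketch (not the actual FSDPDistributedIndex code) of the practice the updated note above describes: keep persistent references to every sub-mesh Megatron-FSDP needs instead of re-deriving them from DeviceMesh global state later:

from torch.distributed.device_mesh import init_device_mesh

# Assumes 8 ranks; every dimension Megatron-FSDP will ask for is declared
# up front as a named sub-mesh of the root DeviceMesh.
mesh = init_device_mesh("cuda", (2, 2, 2), mesh_dim_names=("dp", "cp", "tp"))
mesh_library = {name: mesh[name] for name in mesh.mesh_dim_names}
tp_mesh = mesh_library["tp"]  # persistent 1-D sub-mesh reference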
@@ -825,6 +842,9 @@ def get_outer_fsdp_group(self) -> ProcessGroup:

     def get_root_mesh(self, is_expert_parallel: bool = False) -> DeviceMesh:
         """Get the device mesh."""
+        # NOTE(@cspades): This is FSDPDistributedIndex's root mesh, NOT the actual
+        # root mesh that the DeviceMesh or expert DeviceMesh was un-flattened from.
+        # To get the root mesh, use: DeviceMesh._get_root_mesh().
         if is_expert_parallel:
             raise NotImplementedError("Expert parallel is not supported in Megatron-FSDP.")
         return self.device_mesh
