Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions megatron/core/distributed/fsdp/mcore_fsdp_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,12 @@ def _init_dist_index(self, pg_collection):
single_rank_group = dist.new_group(ranks=[dist.get_rank()])
expt_tp_group = single_rank_group

# Extract AG groups from pg_collection for explicit passing
dp_cp_ag = getattr(pg_collection, 'dp_cp_ag', None) if pg_collection is not None else None
expt_dp_ag = (
getattr(pg_collection, 'expt_dp_ag', None) if pg_collection is not None else None
)

if enable_hsdp:
if expt_dp_group is not None:
expt_mesh = _get_hsdp_tp_mesh(
Expand Down Expand Up @@ -311,6 +317,8 @@ def _init_dist_index(self, pg_collection):
hybrid_fsdp_group=hybrid_fsdp_group,
hybrid_fsdp_expt_group=hybrid_fsdp_expt_group,
expt_device_mesh=expt_device_mesh,
fsdp_group_ag=dp_cp_ag,
expt_fsdp_group_ag=expt_dp_ag,
)
else:
if ep_group is not None:
Expand All @@ -335,6 +343,8 @@ def _init_dist_index(self, pg_collection):
dp_shard_dim="dp_cp",
tp_dim="tp",
expt_device_mesh=expt_device_mesh,
fsdp_group_ag=dp_cp_ag,
expt_fsdp_group_ag=expt_dp_ag,
)

self.tp_group = tp_group
Expand Down
20 changes: 20 additions & 0 deletions megatron/core/distributed/fsdp/src/megatron_fsdp/fully_shard.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ def fully_shard_model(
hybrid_fsdp_group: Optional[torch.distributed.ProcessGroup] = None,
hybrid_fsdp_expt_group: Optional[torch.distributed.ProcessGroup] = None,
expt_device_mesh: Optional[DeviceMesh] = None,
fsdp_group_ag: Optional[torch.distributed.ProcessGroup] = None,
expt_fsdp_group_ag: Optional[torch.distributed.ProcessGroup] = None,
fsdp_unit_modules: Optional[Sequence[Type[torch.nn.Module]] | Sequence[str]] = None,
zero_dp_strategy: str | int = 3,
outer_dp_sharding_strategy: str | int = 0,
Expand Down Expand Up @@ -143,6 +145,17 @@ class that schedules the sharding lifecycle of the model parameters and gradient
Expert parallel device mesh object defining the topology for MoE distributed training.
Utilizes the mesh dimension names specified by the *_dim arguments.

fsdp_group_ag (Optional[torch.distributed.ProcessGroup]):
Independent all-gather process group for overlapping all-gather and reduce-scatter
operations. When provided, enables AG/RS overlap optimization for regular (non-expert)
parameters. Users should create this group with the same ranks as the dp-cp group.
Defaults to None.

expt_fsdp_group_ag (Optional[torch.distributed.ProcessGroup]):
Independent all-gather process group for expert parameters in MoE models. When provided,
enables AG/RS overlap optimization for expert parameters. Users should create this group
with the same ranks as the expert data parallel group. Defaults to None.

fsdp_unit_modules (Optional[Sequence[Type[torch.nn.Module]] | Sequence[str]]):
List of (sub-)module classes or (sub-)module class import paths that are "units",
which are torch.nn.Module(s) that are sharded and scheduled by Megatron-FSDP.
Expand Down Expand Up @@ -368,6 +381,9 @@ class that schedules the sharding lifecycle of the model parameters and gradient
hsdp_outer_dp_shard=_outer_fsdp_sharding,
# Only required for Megatron-FSDP + EP.
expt_device_mesh=expt_device_mesh,
# AG groups for AG/RS overlap optimization.
fsdp_group_ag=fsdp_group_ag,
expt_fsdp_group_ag=expt_fsdp_group_ag,
)

# Wrap model in Megatron FSDP.
Expand Down Expand Up @@ -627,6 +643,8 @@ def fully_shard(
hybrid_fsdp_group: Optional[torch.distributed.ProcessGroup] = None,
hybrid_fsdp_expt_group: Optional[torch.distributed.ProcessGroup] = None,
expt_device_mesh: Optional[DeviceMesh] = None,
fsdp_group_ag: Optional[torch.distributed.ProcessGroup] = None,
expt_fsdp_group_ag: Optional[torch.distributed.ProcessGroup] = None,
fsdp_unit_modules: Optional[Sequence[Type[torch.nn.Module]] | Sequence[str]] = None,
zero_dp_strategy: str | int = 3,
outer_dp_sharding_strategy: str | int = 0,
Expand Down Expand Up @@ -676,6 +694,8 @@ def fully_shard(
hybrid_fsdp_group=hybrid_fsdp_group,
hybrid_fsdp_expt_group=hybrid_fsdp_expt_group,
expt_device_mesh=expt_device_mesh,
fsdp_group_ag=fsdp_group_ag,
expt_fsdp_group_ag=expt_fsdp_group_ag,
fsdp_unit_modules=fsdp_unit_modules,
zero_dp_strategy=zero_dp_strategy,
outer_dp_sharding_strategy=outer_dp_sharding_strategy,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1692,9 +1692,6 @@ def __init__(
if self.dist_index.get_fsdp_group(is_expert_parallel=True) is not None:
# Expert-DP group when using EP
self.ubr_groups.append(self.dist_index.get_fsdp_group(is_expert_parallel=True))
if self.dist_index.get_outer_fsdp_group() is not None:
# Outer/Inter-FSDP group when using hybrid FSDP
self.ubr_groups.append(self.dist_index.get_outer_fsdp_group())
if (
self.dist_index.get_fsdp_group(
is_expert_parallel=False, independent_all_gather=True
Expand All @@ -1707,6 +1704,19 @@ def __init__(
is_expert_parallel=False, independent_all_gather=True
)
)
if (
self.dist_index.get_fsdp_group(is_expert_parallel=True, independent_all_gather=True)
is not None
):
# Expert all-gather group used when overlapping all-gather and gradient reduction.
self.ubr_groups.append(
self.dist_index.get_fsdp_group(
is_expert_parallel=True, independent_all_gather=True
)
)
if self.dist_index.get_outer_fsdp_group() is not None:
# Outer/Inter-FSDP group when using hybrid FSDP (IB domain, registered last).
self.ubr_groups.append(self.dist_index.get_outer_fsdp_group())

log_single_rank(
logger,
Expand Down Expand Up @@ -2182,14 +2192,14 @@ def _init_each_parameter_group_buffers(self, meta_device_init_fp8_params):
is_expert_parallel=group.is_expert_param
)

# When --create-all-gather-group is enabled, use a separate process group for
# all-gather operations (model_weight_buffer) to enable overlap with gradient reduction
# operations (main_grad_buffer). This avoids head-of-line blocking between forward
# all-gather and backward reduce-scatter on the same communicator.
# Use separate process group for all-gather operations (model_weight_buffer)
# to enable overlap with gradient reduction operations (main_grad_buffer).
# This avoids head-of-line blocking between forward all-gather and backward
# reduce-scatter on the same communicator.
model_wbuf_dp_group = main_buf_dp_group
if not group.is_expert_param and not should_create_hfsdp_helper_buffers:
if not should_create_hfsdp_helper_buffers:
ag_group = self.dist_index.get_fsdp_group(
is_expert_parallel=False, independent_all_gather=True
is_expert_parallel=group.is_expert_param, independent_all_gather=True
)
if ag_group is not None:
model_wbuf_dp_group = ag_group
Expand Down
28 changes: 14 additions & 14 deletions megatron/core/distributed/fsdp/src/megatron_fsdp/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,6 @@
from importlib.metadata import version
from typing import Callable, Optional, Sequence, Union

try:
import megatron.core.parallel_state as parallel_state

HAVE_MEGATRON_CORE = True
except (ImportError, ModuleNotFoundError):
HAVE_MEGATRON_CORE = False

try:
import einops

Expand Down Expand Up @@ -481,6 +474,8 @@ def __init__(
hybrid_fsdp_expt_group: Optional[torch.distributed.ProcessGroup] = None,
hsdp_outer_dp_shard: bool = False,
expt_device_mesh: Optional[DeviceMesh] = None,
fsdp_group_ag: Optional[torch.distributed.ProcessGroup] = None,
expt_fsdp_group_ag: Optional[torch.distributed.ProcessGroup] = None,
):
"""
Args:
Expand All @@ -502,6 +497,13 @@ def __init__(
just sharding across dp_shard ranks and replicating across dp_outer ranks.
expt_device_mesh (Optional[DeviceMesh]): The expert parallel device mesh
to use for the DistributedIndex.
fsdp_group_ag (Optional[torch.distributed.ProcessGroup]): Independent all-gather
process group for overlapping all-gather and reduce-scatter operations.
When provided, enables AG/RS overlap optimization for regular (non-expert)
parameters.
expt_fsdp_group_ag (Optional[torch.distributed.ProcessGroup]): Independent all-gather
process group for expert parameters in MoE models. When provided, enables AG/RS
overlap optimization for expert parameters.
"""
# Device mesh arguments.
self.device_mesh = device_mesh
Expand All @@ -525,13 +527,9 @@ def __init__(
if contains_submesh(self.device_mesh, self.dp_shard_dim)
else None
)
# AG group comes from parallel_state, not the mesh
# the purpose of this independent group is to overlap all-gather and gradient reduction.
self.fsdp_group_ag = None
if HAVE_MEGATRON_CORE and parallel_state.has_separate_all_gather_group():
self.fsdp_group_ag = parallel_state.get_data_parallel_group(
with_context_parallel=True, independent_all_gather=True
)
# AG groups: supplied via ProcessGroupCollection (Megatron-FSDP entrypoint).
self.fsdp_group_ag = fsdp_group_ag
self.expt_fsdp_group_ag = expt_fsdp_group_ag
# Retrieve the outer-FSDP process group from the DeviceMesh.
self.outer_fsdp_group = (
self.device_mesh[self.dp_outer_dim].get_group()
Expand Down Expand Up @@ -676,6 +674,8 @@ def get_fsdp_group(
) -> ProcessGroup:
"""Get the FSDP process group."""
if is_expert_parallel:
if independent_all_gather:
return self.expt_fsdp_group_ag
return self.expt_fsdp_group
if independent_all_gather:
return self.fsdp_group_ag
Expand Down
Loading
Loading