Skip to content

Commit 58c2d96

Browse files
Committed: Enable AG/RS overlap with explicit process group passing
1 parent 43db8c1 commit 58c2d96

File tree

10 files changed

+268
-84
lines changed

10 files changed

+268
-84
lines changed

megatron/core/distributed/fsdp/mcore_fsdp_adapter.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,12 @@ def _init_dist_index(self, pg_collection):
253253
single_rank_group = dist.new_group(ranks=[dist.get_rank()])
254254
expt_tp_group = single_rank_group
255255

256+
# Extract AG groups from pg_collection for explicit passing
257+
dp_cp_ag = getattr(pg_collection, 'dp_cp_ag', None) if pg_collection is not None else None
258+
expt_dp_ag = (
259+
getattr(pg_collection, 'expt_dp_ag', None) if pg_collection is not None else None
260+
)
261+
256262
if enable_hsdp:
257263
if expt_dp_group is not None:
258264
expt_mesh = _get_hsdp_tp_mesh(
@@ -281,6 +287,8 @@ def _init_dist_index(self, pg_collection):
281287
hybrid_fsdp_group=hybrid_fsdp_group,
282288
hybrid_fsdp_expt_group=hybrid_fsdp_expt_group,
283289
expt_device_mesh=expt_device_mesh,
290+
fsdp_group_ag=dp_cp_ag,
291+
expt_fsdp_group_ag=expt_dp_ag,
284292
)
285293
else:
286294
if ep_group is not None:
@@ -305,6 +313,8 @@ def _init_dist_index(self, pg_collection):
305313
dp_shard_dim="dp_cp",
306314
tp_dim="tp",
307315
expt_device_mesh=expt_device_mesh,
316+
fsdp_group_ag=dp_cp_ag,
317+
expt_fsdp_group_ag=expt_dp_ag,
308318
)
309319

310320
self.tp_group = tp_group

megatron/core/distributed/fsdp/src/megatron_fsdp/fully_shard.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,8 @@ def fully_shard_model(
7979
hybrid_fsdp_group: Optional[torch.distributed.ProcessGroup] = None,
8080
hybrid_fsdp_expt_group: Optional[torch.distributed.ProcessGroup] = None,
8181
expt_device_mesh: Optional[DeviceMesh] = None,
82+
fsdp_group_ag: Optional[torch.distributed.ProcessGroup] = None,
83+
expt_fsdp_group_ag: Optional[torch.distributed.ProcessGroup] = None,
8284
fsdp_unit_modules: Optional[Sequence[Type[torch.nn.Module]] | Sequence[str]] = None,
8385
zero_dp_strategy: str | int = 3,
8486
outer_dp_sharding_strategy: str | int = 0,
@@ -141,6 +143,17 @@ class that schedules the sharding lifecycle of the model parameters and gradient
141143
Expert parallel device mesh object defining the topology for MoE distributed training.
142144
Utilizes the mesh dimension names specified by the *_dim arguments.
143145
146+
fsdp_group_ag (Optional[torch.distributed.ProcessGroup]):
147+
Independent all-gather process group for overlapping all-gather and reduce-scatter
148+
operations. When provided, enables AG/RS overlap optimization for regular (non-expert)
149+
parameters. Users should create this group with the same ranks as the dp-cp group.
150+
Defaults to None.
151+
152+
expt_fsdp_group_ag (Optional[torch.distributed.ProcessGroup]):
153+
Independent all-gather process group for expert parameters in MoE models. When provided,
154+
enables AG/RS overlap optimization for expert parameters. Users should create this group
155+
with the same ranks as the expert data parallel group. Defaults to None.
156+
144157
fsdp_unit_modules (Optional[Sequence[Type[torch.nn.Module]] | Sequence[str]]):
145158
List of (sub-)module classes or (sub-)module class import paths that are "units",
146159
which are torch.nn.Module(s) that are sharded and scheduled by Megatron-FSDP.
@@ -365,6 +378,9 @@ class that schedules the sharding lifecycle of the model parameters and gradient
365378
hsdp_outer_dp_shard=_outer_fsdp_sharding,
366379
# Only required for Megatron-FSDP + EP.
367380
expt_device_mesh=expt_device_mesh,
381+
# AG groups for AG/RS overlap optimization.
382+
fsdp_group_ag=fsdp_group_ag,
383+
expt_fsdp_group_ag=expt_fsdp_group_ag,
368384
)
369385

370386
# Wrap model in Megatron FSDP.
@@ -532,6 +548,8 @@ def fully_shard(
532548
hybrid_fsdp_group: Optional[torch.distributed.ProcessGroup] = None,
533549
hybrid_fsdp_expt_group: Optional[torch.distributed.ProcessGroup] = None,
534550
expt_device_mesh: Optional[DeviceMesh] = None,
551+
fsdp_group_ag: Optional[torch.distributed.ProcessGroup] = None,
552+
expt_fsdp_group_ag: Optional[torch.distributed.ProcessGroup] = None,
535553
fsdp_unit_modules: Optional[Sequence[Type[torch.nn.Module]] | Sequence[str]] = None,
536554
zero_dp_strategy: str | int = 3,
537555
outer_dp_sharding_strategy: str | int = 0,
@@ -581,6 +599,8 @@ def fully_shard(
581599
hybrid_fsdp_group=hybrid_fsdp_group,
582600
hybrid_fsdp_expt_group=hybrid_fsdp_expt_group,
583601
expt_device_mesh=expt_device_mesh,
602+
fsdp_group_ag=fsdp_group_ag,
603+
expt_fsdp_group_ag=expt_fsdp_group_ag,
584604
fsdp_unit_modules=fsdp_unit_modules,
585605
zero_dp_strategy=zero_dp_strategy,
586606
outer_dp_sharding_strategy=outer_dp_sharding_strategy,

megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1665,6 +1665,16 @@ def __init__(
16651665
is_expert_parallel=False, independent_all_gather=True
16661666
)
16671667
)
1668+
if (
1669+
self.dist_index.get_fsdp_group(is_expert_parallel=True, independent_all_gather=True)
1670+
is not None
1671+
):
1672+
# Expert all-gather group used when overlapping all-gather and gradient reduction.
1673+
self.ubr_groups.append(
1674+
self.dist_index.get_fsdp_group(
1675+
is_expert_parallel=True, independent_all_gather=True
1676+
)
1677+
)
16681678

16691679
log_single_rank(
16701680
logger,
@@ -1962,14 +1972,14 @@ def _init_each_parameter_group_buffers(self, meta_device_init_fp8_params):
19621972
is_expert_parallel=group.is_expert_param
19631973
)
19641974

1965-
# When --create-all-gather-group is enabled, use a separate process group for
1966-
# all-gather operations (model_weight_buffer) to enable overlap with gradient reduction
1967-
# operations (main_grad_buffer). This avoids head-of-line blocking between forward
1968-
# all-gather and backward reduce-scatter on the same communicator.
1975+
# Use separate process group for all-gather operations (model_weight_buffer)
1976+
# to enable overlap with gradient reduction operations (main_grad_buffer).
1977+
# This avoids head-of-line blocking between forward all-gather and backward
1978+
# reduce-scatter on the same communicator.
19691979
model_wbuf_dp_group = main_buf_dp_group
1970-
if not group.is_expert_param and not should_create_hfsdp_wbuf_and_gbuf:
1980+
if not should_create_hfsdp_wbuf_and_gbuf:
19711981
ag_group = self.dist_index.get_fsdp_group(
1972-
is_expert_parallel=False, independent_all_gather=True
1982+
is_expert_parallel=group.is_expert_param, independent_all_gather=True
19731983
)
19741984
if ag_group is not None:
19751985
model_wbuf_dp_group = ag_group

megatron/core/distributed/fsdp/src/megatron_fsdp/utils.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,6 @@
2121
from importlib.metadata import version
2222
from typing import Callable, Optional, Sequence, Union
2323

24-
try:
25-
import megatron.core.parallel_state as parallel_state
26-
27-
HAVE_MEGATRON_CORE = True
28-
except (ImportError, ModuleNotFoundError):
29-
HAVE_MEGATRON_CORE = False
30-
3124
try:
3225
import einops
3326

@@ -453,6 +446,8 @@ def __init__(
453446
hybrid_fsdp_expt_group: Optional[torch.distributed.ProcessGroup] = None,
454447
hsdp_outer_dp_shard: bool = False,
455448
expt_device_mesh: Optional[DeviceMesh] = None,
449+
fsdp_group_ag: Optional[torch.distributed.ProcessGroup] = None,
450+
expt_fsdp_group_ag: Optional[torch.distributed.ProcessGroup] = None,
456451
):
457452
"""
458453
Args:
@@ -474,6 +469,13 @@ def __init__(
474469
just sharding across dp_shard ranks and replicating across dp_outer ranks.
475470
expt_device_mesh (Optional[DeviceMesh]): The expert parallel device mesh
476471
to use for the DistributedIndex.
472+
fsdp_group_ag (Optional[torch.distributed.ProcessGroup]): Independent all-gather
473+
process group for overlapping all-gather and reduce-scatter operations.
474+
When provided, enables AG/RS overlap optimization for regular (non-expert)
475+
parameters.
476+
expt_fsdp_group_ag (Optional[torch.distributed.ProcessGroup]): Independent all-gather
477+
process group for expert parameters in MoE models. When provided, enables AG/RS
478+
overlap optimization for expert parameters.
477479
"""
478480
# Device mesh arguments.
479481
self.device_mesh = device_mesh
@@ -497,13 +499,10 @@ def __init__(
497499
if contains_submesh(self.device_mesh, self.dp_shard_dim)
498500
else None
499501
)
500-
# AG group comes from parallel_state, not the mesh
501-
# the purpose of this independent group is to overlap all-gather and gradient reduction.
502-
self.fsdp_group_ag = None
503-
if HAVE_MEGATRON_CORE and parallel_state.has_separate_all_gather_group():
504-
self.fsdp_group_ag = parallel_state.get_data_parallel_group(
505-
with_context_parallel=True, independent_all_gather=True
506-
)
502+
# AG groups passed as explicit arguments
503+
# The purpose of independent AG groups is to overlap all-gather and reduce-scatter.
504+
self.fsdp_group_ag = fsdp_group_ag
505+
self.expt_fsdp_group_ag = expt_fsdp_group_ag
507506
# Retrieve the outer-FSDP process group from the DeviceMesh.
508507
self.outer_fsdp_group = (
509508
self.device_mesh[self.dp_outer_dim].get_group()
@@ -655,6 +654,8 @@ def get_fsdp_group(
655654
) -> ProcessGroup:
656655
"""Get the FSDP process group."""
657656
if is_expert_parallel:
657+
if independent_all_gather:
658+
return self.expt_fsdp_group_ag
658659
return self.expt_fsdp_group
659660
if independent_all_gather:
660661
return self.fsdp_group_ag

0 commit comments

Comments (0)