Skip to content

Commit cfdfe8b

Browse files
committed
Enable AG/RS overlap with explicit process group passing
1 parent a0cc8ca commit cfdfe8b

File tree

9 files changed

+158
-83
lines changed

9 files changed

+158
-83
lines changed

megatron/core/distributed/fsdp/mcore_fsdp_adapter.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,12 @@ def _init_dist_index(self, pg_collection):
242242
single_rank_group = dist.new_group(ranks=[dist.get_rank()])
243243
expt_tp_group = single_rank_group
244244

245+
# Extract AG groups from pg_collection for explicit passing
246+
dp_cp_ag = getattr(pg_collection, 'dp_cp_ag', None) if pg_collection is not None else None
247+
expt_dp_ag = (
248+
getattr(pg_collection, 'expt_dp_ag', None) if pg_collection is not None else None
249+
)
250+
245251
if enable_hsdp:
246252
mesh = _get_hsdp_tp_mesh(outer_fsdp_group, dp_cp_group, tp_group)
247253
dist_index = FSDPDistributedIndex(
@@ -256,6 +262,8 @@ def _init_dist_index(self, pg_collection):
256262
dp_shard_dim="dp_cp",
257263
tp_dim="tp",
258264
hybrid_fsdp_group=hybrid_fsdp_group,
265+
fsdp_group_ag=dp_cp_ag,
266+
expt_fsdp_group_ag=expt_dp_ag,
259267
)
260268
else:
261269
if ep_group is not None:
@@ -280,6 +288,8 @@ def _init_dist_index(self, pg_collection):
280288
dp_shard_dim="dp_cp",
281289
tp_dim="tp",
282290
expt_device_mesh=expt_device_mesh,
291+
fsdp_group_ag=dp_cp_ag,
292+
expt_fsdp_group_ag=expt_dp_ag,
283293
)
284294

285295
self.tp_group = tp_group

megatron/core/distributed/fsdp/src/megatron_fsdp/fully_shard.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,8 @@ def fully_shard_model(
7878
tp_dim: Optional[str] = None,
7979
hybrid_fsdp_group: Optional[torch.distributed.ProcessGroup] = None,
8080
expt_device_mesh: Optional[DeviceMesh] = None,
81+
fsdp_group_ag: Optional[torch.distributed.ProcessGroup] = None,
82+
expt_fsdp_group_ag: Optional[torch.distributed.ProcessGroup] = None,
8183
fsdp_unit_modules: Optional[Sequence[Type[torch.nn.Module]] | Sequence[str]] = None,
8284
zero_dp_strategy: str | int = 3,
8385
outer_dp_sharding_strategy: str | int = 0,
@@ -139,6 +141,17 @@ class that schedules the sharding lifecycle of the model parameters and gradient
139141
Expert parallel device mesh object defining the topology for MoE distributed training.
140142
Utilizes the mesh dimension names specified by the *_dim arguments.
141143
144+
fsdp_group_ag (Optional[torch.distributed.ProcessGroup]):
145+
Independent all-gather process group for overlapping all-gather and reduce-scatter
146+
operations. When provided, enables AG/RS overlap optimization for regular (non-expert)
147+
parameters. Users should create this group with the same ranks as the dp-cp group.
148+
Defaults to None.
149+
150+
expt_fsdp_group_ag (Optional[torch.distributed.ProcessGroup]):
151+
Independent all-gather process group for expert parameters in MoE models. When provided,
152+
enables AG/RS overlap optimization for expert parameters. Users should create this group
153+
with the same ranks as the expert data parallel group. Defaults to None.
154+
142155
fsdp_unit_modules (Optional[Sequence[Type[torch.nn.Module]] | Sequence[str]]):
143156
List of (sub-)module classes or (sub-)module class import paths that are "units",
144157
which are torch.nn.Module(s) that are sharded and scheduled by Megatron-FSDP.
@@ -356,6 +369,9 @@ class that schedules the sharding lifecycle of the model parameters and gradient
356369
hsdp_outer_dp_shard=_outer_fsdp_sharding,
357370
# Only required for Megatron-FSDP + EP.
358371
expt_device_mesh=expt_device_mesh,
372+
# AG groups for AG/RS overlap optimization.
373+
fsdp_group_ag=fsdp_group_ag,
374+
expt_fsdp_group_ag=expt_fsdp_group_ag,
359375
)
360376

361377
# Wrap model in Megatron FSDP.
@@ -522,6 +538,8 @@ def fully_shard(
522538
tp_dim: Optional[str] = None,
523539
hybrid_fsdp_group: Optional[torch.distributed.ProcessGroup] = None,
524540
expt_device_mesh: Optional[DeviceMesh] = None,
541+
fsdp_group_ag: Optional[torch.distributed.ProcessGroup] = None,
542+
expt_fsdp_group_ag: Optional[torch.distributed.ProcessGroup] = None,
525543
fsdp_unit_modules: Optional[Sequence[Type[torch.nn.Module]] | Sequence[str]] = None,
526544
zero_dp_strategy: str | int = 3,
527545
outer_dp_sharding_strategy: str | int = 0,
@@ -569,6 +587,8 @@ def fully_shard(
569587
tp_dim=tp_dim,
570588
hybrid_fsdp_group=hybrid_fsdp_group,
571589
expt_device_mesh=expt_device_mesh,
590+
fsdp_group_ag=fsdp_group_ag,
591+
expt_fsdp_group_ag=expt_fsdp_group_ag,
572592
fsdp_unit_modules=fsdp_unit_modules,
573593
zero_dp_strategy=zero_dp_strategy,
574594
outer_dp_sharding_strategy=outer_dp_sharding_strategy,

megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1631,6 +1631,16 @@ def __init__(
16311631
is_expert_parallel=False, independent_all_gather=True
16321632
)
16331633
)
1634+
if (
1635+
self.dist_index.get_fsdp_group(is_expert_parallel=True, independent_all_gather=True)
1636+
is not None
1637+
):
1638+
# Expert all-gather group used when overlapping all-gather and gradient reduction.
1639+
self.ubr_groups.append(
1640+
self.dist_index.get_fsdp_group(
1641+
is_expert_parallel=True, independent_all_gather=True
1642+
)
1643+
)
16341644

16351645
log_single_rank(
16361646
logger,
@@ -1920,14 +1930,14 @@ def _init_each_parameter_group_buffers(self, meta_device_init_fp8_params):
19201930
is_expert_parallel=group.is_expert_param
19211931
)
19221932

1923-
# When --create-all-gather-group is enabled, use a separate process group for
1924-
# all-gather operations (model_weight_buffer) to enable overlap with gradient reduction
1925-
# operations (main_grad_buffer). This avoids head-of-line blocking between forward
1926-
# all-gather and backward reduce-scatter on the same communicator.
1933+
# Use separate process group for all-gather operations (model_weight_buffer)
1934+
# to enable overlap with gradient reduction operations (main_grad_buffer).
1935+
# This avoids head-of-line blocking between forward all-gather and backward
1936+
# reduce-scatter on the same communicator.
19271937
model_wbuf_dp_group = main_buf_dp_group
1928-
if not group.is_expert_param and not should_create_hfsdp_wbuf_and_gbuf:
1938+
if not should_create_hfsdp_wbuf_and_gbuf:
19291939
ag_group = self.dist_index.get_fsdp_group(
1930-
is_expert_parallel=False, independent_all_gather=True
1940+
is_expert_parallel=group.is_expert_param, independent_all_gather=True
19311941
)
19321942
if ag_group is not None:
19331943
model_wbuf_dp_group = ag_group

megatron/core/distributed/fsdp/src/megatron_fsdp/utils.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,6 @@
2121
from importlib.metadata import version
2222
from typing import Callable, Optional, Sequence, Union
2323

24-
try:
25-
import megatron.core.parallel_state as parallel_state
26-
27-
HAVE_MEGATRON_CORE = True
28-
except (ImportError, ModuleNotFoundError):
29-
HAVE_MEGATRON_CORE = False
30-
3124
try:
3225
import einops
3326

@@ -452,6 +445,8 @@ def __init__(
452445
hybrid_fsdp_group: Optional[torch.distributed.ProcessGroup] = None,
453446
hsdp_outer_dp_shard: bool = False,
454447
expt_device_mesh: Optional[DeviceMesh] = None,
448+
fsdp_group_ag: Optional[torch.distributed.ProcessGroup] = None,
449+
expt_fsdp_group_ag: Optional[torch.distributed.ProcessGroup] = None,
455450
):
456451
"""
457452
Args:
@@ -470,6 +465,13 @@ def __init__(
470465
just sharding across dp_shard ranks and replicating across dp_outer ranks.
471466
expt_device_mesh (Optional[DeviceMesh]): The expert parallel device mesh
472467
to use for the DistributedIndex.
468+
fsdp_group_ag (Optional[torch.distributed.ProcessGroup]): Independent all-gather
469+
process group for overlapping all-gather and reduce-scatter operations.
470+
When provided, enables AG/RS overlap optimization for regular (non-expert)
471+
parameters.
472+
expt_fsdp_group_ag (Optional[torch.distributed.ProcessGroup]): Independent all-gather
473+
process group for expert parameters in MoE models. When provided, enables AG/RS
474+
overlap optimization for expert parameters.
473475
"""
474476
# Device mesh arguments.
475477
self.device_mesh = device_mesh
@@ -493,13 +495,10 @@ def __init__(
493495
if contains_submesh(self.device_mesh, self.dp_shard_dim)
494496
else None
495497
)
496-
# AG group comes from parallel_state, not the mesh
497-
# the purpose of this independent group is to overlap all-gather and gradient reduction.
498-
self.fsdp_group_ag = None
499-
if HAVE_MEGATRON_CORE and parallel_state.has_separate_all_gather_group():
500-
self.fsdp_group_ag = parallel_state.get_data_parallel_group(
501-
with_context_parallel=True, independent_all_gather=True
502-
)
498+
# AG groups passed as explicit arguments
499+
# The purpose of independent AG groups is to overlap all-gather and reduce-scatter.
500+
self.fsdp_group_ag = fsdp_group_ag
501+
self.expt_fsdp_group_ag = expt_fsdp_group_ag
503502
# Retrieve the outer-FSDP process group from the DeviceMesh.
504503
self.outer_fsdp_group = (
505504
self.device_mesh[self.dp_outer_dim].get_group()
@@ -639,6 +638,8 @@ def get_fsdp_group(
639638
) -> ProcessGroup:
640639
"""Get the FSDP process group."""
641640
if is_expert_parallel:
641+
if independent_all_gather:
642+
return self.expt_fsdp_group_ag
642643
return self.expt_fsdp_group
643644
if independent_all_gather:
644645
return self.fsdp_group_ag

megatron/core/parallel_state.py

Lines changed: 1 addition & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,6 @@
120120

121121
# Data parallel group information with context parallel combined.
122122
_DATA_PARALLEL_GROUP_WITH_CP = None
123-
_DATA_PARALLEL_GROUP_WITH_CP_AG = None
124123
_DATA_PARALLEL_GROUP_WITH_CP_GLOO = None
125124
_DATA_PARALLEL_GLOBAL_RANKS_WITH_CP = None
126125

@@ -567,7 +566,6 @@ def initialize_model_parallel(
567566
create_gloo_process_groups: bool = True,
568567
high_priority_stream_groups: Optional[List[str]] = None,
569568
sharp_enabled_group: Optional[str] = None,
570-
create_all_gather_group: Optional[bool] = False,
571569
) -> None:
572570
"""Initialize model data parallel groups.
573571
@@ -682,13 +680,6 @@ def initialize_model_parallel(
682680
By default (None), it is enabled from dp group.
683681
Available options (choose one): [dp, dp_replica]
684682
685-
create_all_gather_group (bool, default = False):
686-
Create a separate process group for all-gather operations to avoid
687-
head-of-line blocking with reduce-scatter operations. When enabled,
688-
creates an additional NCCL communicator with identical ranks as the
689-
dp-cp group but with independent progress engines for better communication
690-
overlap.
691-
692683
Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
693684
use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
694685
the model pipeline. The present function will
@@ -825,7 +816,6 @@ def initialize_model_parallel(
825816
global _DATA_PARALLEL_GROUP_GLOO
826817
global _DATA_PARALLEL_GLOBAL_RANKS
827818
global _DATA_PARALLEL_GROUP_WITH_CP
828-
global _DATA_PARALLEL_GROUP_WITH_CP_AG
829819
global _DATA_PARALLEL_GROUP_WITH_CP_GLOO
830820
global _DATA_PARALLEL_GLOBAL_RANKS_WITH_CP
831821
global _INTRA_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP
@@ -857,15 +847,6 @@ def initialize_model_parallel(
857847
pg_options=get_nccl_options("dp_cp", nccl_comm_cfgs),
858848
group_desc="DATA_PARALLEL_GROUP_WITH_CP",
859849
)
860-
if create_all_gather_group:
861-
group_with_cp_ag = create_group(
862-
ranks_with_cp,
863-
timeout=timeout,
864-
pg_options=get_nccl_options("dp_cp", nccl_comm_cfgs),
865-
group_desc="DATA_PARALLEL_GROUP_WITH_CP_AG",
866-
)
867-
else:
868-
group_with_cp_ag = None
869850
if create_gloo_process_groups:
870851
group_with_cp_gloo = create_group(
871852
ranks_with_cp,
@@ -877,7 +858,6 @@ def initialize_model_parallel(
877858
group_with_cp_gloo = None
878859
if rank in ranks_with_cp:
879860
_DATA_PARALLEL_GROUP_WITH_CP = group_with_cp
880-
_DATA_PARALLEL_GROUP_WITH_CP_AG = group_with_cp_ag
881861
_DATA_PARALLEL_GROUP_WITH_CP_GLOO = group_with_cp_gloo
882862
_DATA_PARALLEL_GLOBAL_RANKS_WITH_CP = ranks_with_cp
883863

@@ -1407,21 +1387,14 @@ def get_pipeline_model_parallel_group(check_initialized=True):
14071387
return _PIPELINE_MODEL_PARALLEL_GROUP
14081388

14091389

1410-
def get_data_parallel_group(
1411-
with_context_parallel=False, partial_data_parallel=False, independent_all_gather=False
1412-
):
1390+
def get_data_parallel_group(with_context_parallel=False, partial_data_parallel=False):
14131391
"""Get the data-parallel group the caller rank belongs to."""
14141392
if with_context_parallel:
14151393
if partial_data_parallel:
14161394
assert (
14171395
_INTRA_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP is not None
14181396
), "Intra partial data parallel group is not initialized"
14191397
return _INTRA_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP
1420-
if independent_all_gather:
1421-
assert (
1422-
_DATA_PARALLEL_GROUP_WITH_CP_AG is not None
1423-
), "data parallel group with context parallel AG is not initialized"
1424-
return _DATA_PARALLEL_GROUP_WITH_CP_AG
14251398
assert (
14261399
_DATA_PARALLEL_GROUP_WITH_CP is not None
14271400
), "data parallel group with context parallel combined is not initialized"
@@ -1432,15 +1405,6 @@ def get_data_parallel_group(
14321405
return _DATA_PARALLEL_GROUP
14331406

14341407

1435-
def has_separate_all_gather_group() -> bool:
1436-
"""Check if a separate all-gather process group has been created.
1437-
1438-
Returns True if a dedicated all-gather process group exists for improved
1439-
communication overlap, False otherwise.
1440-
"""
1441-
return _DATA_PARALLEL_GROUP_WITH_CP_AG is not None
1442-
1443-
14441408
def get_data_parallel_group_gloo(with_context_parallel=False, partial_data_parallel=False):
14451409
"""Get the Gloo data-parallel group the caller rank belongs to."""
14461410
if with_context_parallel:
@@ -2101,9 +2065,6 @@ def destroy_model_parallel():
21012065
global _DATA_PARALLEL_GROUP_WITH_CP
21022066
_DATA_PARALLEL_GROUP_WITH_CP = None
21032067

2104-
global _DATA_PARALLEL_GROUP_WITH_CP_AG
2105-
_DATA_PARALLEL_GROUP_WITH_CP_AG = None
2106-
21072068
global _CONTEXT_PARALLEL_GROUP
21082069
_CONTEXT_PARALLEL_GROUP = None
21092070

megatron/core/process_groups_config.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,13 +114,19 @@ class ProcessGroupCollection:
114114
# _DATA_PARALLEL_GROUP_WITH_CP
115115
dp_cp: torch.distributed.ProcessGroup = field(init=False)
116116

117+
# _DATA_PARALLEL_GROUP_WITH_CP_AG
118+
dp_cp_ag: torch.distributed.ProcessGroup = field(init=False)
119+
117120
# MoE layers need expt_dp group for sharded state dict
118121
# we need this workaround until distributed checkpoint is refactored
119122
# to have sharded_state_dict can take the PG and pass it down
120123
# TODO (Hepteract): remove this once distributed checkpoint is refactored
121124
# _EXPERT_DATA_PARALLEL_GROUP
122125
expt_dp: torch.distributed.ProcessGroup = field(init=False)
123126

127+
# _EXPERT_DATA_PARALLEL_GROUP_AG
128+
expt_dp_ag: torch.distributed.ProcessGroup = field(init=False)
129+
124130
# _INTRA_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP
125131
intra_dp_cp: torch.distributed.ProcessGroup = field(init=False)
126132

@@ -210,6 +216,7 @@ def use_mpu_process_groups(cls, required_pgs: Optional[List[str]] = None):
210216
),
211217
'dp': parallel_state.get_data_parallel_group,
212218
'dp_cp': partial(parallel_state.get_data_parallel_group, with_context_parallel=True),
219+
'dp_cp_ag': lambda: None, # AG groups should be created in user code
213220
'intra_dp_cp': partial(
214221
parallel_state.get_data_parallel_group,
215222
with_context_parallel=True,
@@ -232,6 +239,7 @@ def use_mpu_process_groups(cls, required_pgs: Optional[List[str]] = None):
232239
'expt_dp': partial(
233240
parallel_state.get_expert_data_parallel_group, check_initialized=False
234241
),
242+
'expt_dp_ag': lambda: None, # Expert AG groups should be created in user code
235243
'tp_dp_cp': partial(
236244
parallel_state.get_tensor_and_data_parallel_group,
237245
check_initialized=False,

megatron/training/arguments.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2301,9 +2301,6 @@ def _add_distributed_args(parser):
23012301
help='IB SHARP can be enabled from only one communication group. '
23022302
'By default, it is enabled from dp group. '
23032303
'Available options: [dp, dp_replica]')
2304-
group.add_argument('--create-all-gather-group', action='store_true',
2305-
help='Create a separate process group for all-gather operations '
2306-
'to overlap reduce-scatter and all-gather operations.')
23072304
group.add_argument('--use-megatron-fsdp', action='store_true',
23082305
help='Use the Megatron FSDP code path in DDP.')
23092306
group.add_argument('--data-parallel-sharding-strategy', type=str, default='no_shard',

megatron/training/initialize.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,6 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks, s
381381
create_gloo_process_groups=args.enable_gloo_process_groups,
382382
high_priority_stream_groups=args.high_priority_stream_groups,
383383
sharp_enabled_group=args.sharp_enabled_group,
384-
create_all_gather_group=args.create_all_gather_group,
385384
)
386385
print_rank_0(
387386
f"> initialized tensor model parallel with size "

0 commit comments

Comments
 (0)