Skip to content

Commit ce44736

Browse files
committed
Add a dedicated expert data-parallel all-gather process group so the forward all-gather can overlap with gradient reduction
1 parent a25421e commit ce44736

File tree

4 files changed

+112
-22
lines changed

4 files changed

+112
-22
lines changed

megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1602,9 +1602,28 @@ def __init__(
16021602
if self.dist_index.get_outer_fsdp_group() is not None:
16031603
# Outer/Inter-FSDP group when using hybrid FSDP
16041604
self.ubr_groups.append(self.dist_index.get_outer_fsdp_group())
1605-
if self.dist_index.get_fsdp_group(is_expert_parallel=False, independent_all_gather=True) is not None:
1605+
if (
1606+
self.dist_index.get_fsdp_group(
1607+
is_expert_parallel=False, independent_all_gather=True
1608+
)
1609+
is not None
1610+
):
16061611
# All-gather group used when overlapping all-gather and gradient reduction.
1607-
self.ubr_groups.append(self.dist_index.get_fsdp_group(is_expert_parallel=False, independent_all_gather=True))
1612+
self.ubr_groups.append(
1613+
self.dist_index.get_fsdp_group(
1614+
is_expert_parallel=False, independent_all_gather=True
1615+
)
1616+
)
1617+
if (
1618+
self.dist_index.get_fsdp_group(is_expert_parallel=True, independent_all_gather=True)
1619+
is not None
1620+
):
1621+
# Expert all-gather group used when overlapping all-gather and gradient reduction.
1622+
self.ubr_groups.append(
1623+
self.dist_index.get_fsdp_group(
1624+
is_expert_parallel=True, independent_all_gather=True
1625+
)
1626+
)
16081627

16091628
if torch.distributed.get_rank() == 0:
16101629
logging.info(
@@ -1896,9 +1915,9 @@ def _init_each_parameter_group_buffers(self, meta_device_init_fp8_params):
18961915
# operations (main_grad_buffer). This avoids head-of-line blocking between forward
18971916
# all-gather and backward reduce-scatter on the same communicator.
18981917
model_wbuf_dp_group = main_buf_dp_group
1899-
if not group.is_expert_param and not should_create_hfsdp_wbuf_and_gbuf:
1918+
if not should_create_hfsdp_wbuf_and_gbuf:
19001919
ag_group = self.dist_index.get_fsdp_group(
1901-
is_expert_parallel=False, independent_all_gather=True
1920+
is_expert_parallel=group.is_expert_param, independent_all_gather=True
19021921
)
19031922
if ag_group is not None:
19041923
model_wbuf_dp_group = ag_group

megatron/core/distributed/fsdp/src/megatron_fsdp/utils.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -498,8 +498,7 @@ def __init__(
498498
self.fsdp_group_ag = None
499499
if HAVE_MEGATRON_CORE and parallel_state.has_separate_all_gather_group():
500500
self.fsdp_group_ag = parallel_state.get_data_parallel_group(
501-
with_context_parallel=True,
502-
independent_all_gather=True
501+
with_context_parallel=True, independent_all_gather=True
503502
)
504503
# Retrieve the outer-FSDP process group from the DeviceMesh.
505504
self.outer_fsdp_group = (
@@ -518,6 +517,12 @@ def __init__(
518517
and contains_submesh(self.expt_device_mesh, self.dp_shard_dim)
519518
else None
520519
)
520+
# Expert AG group for overlap
521+
self.expt_fsdp_group_ag = None
522+
if HAVE_MEGATRON_CORE and parallel_state.has_separate_expert_all_gather_group():
523+
self.expt_fsdp_group_ag = parallel_state.get_expert_data_parallel_group(
524+
independent_all_gather=True
525+
)
521526

522527
"""
523528
Megatron-FSDP is responsible for storing all required DeviceMesh
@@ -635,9 +640,13 @@ def get_dp_group(self, is_expert_parallel: bool = False) -> ProcessGroup:
635640
return self.hybrid_fsdp_group
636641
return self.fsdp_group
637642

638-
def get_fsdp_group(self, is_expert_parallel: bool = False, independent_all_gather: bool = False) -> ProcessGroup:
643+
def get_fsdp_group(
644+
self, is_expert_parallel: bool = False, independent_all_gather: bool = False
645+
) -> ProcessGroup:
639646
"""Get the FSDP process group."""
640647
if is_expert_parallel:
648+
if independent_all_gather:
649+
return self.expt_fsdp_group_ag
641650
return self.expt_fsdp_group
642651
if independent_all_gather:
643652
return self.fsdp_group_ag

megatron/core/parallel_state.py

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
# Expert data parallel group
6161
_EXPERT_DATA_PARALLEL_GROUP = None
6262
_EXPERT_DATA_PARALLEL_GROUP_GLOO = None
63+
_EXPERT_DATA_PARALLEL_GROUP_AG = None
6364
_INTRA_PARTIAL_EXPERT_DATA_PARALLEL_GROUP = None
6465
_INTRA_PARTIAL_EXPERT_DATA_PARALLEL_GROUP_GLOO = None
6566
_INTER_PARTIAL_EXPERT_DATA_PARALLEL_GROUP = None
@@ -1207,6 +1208,10 @@ def initialize_model_parallel(
12071208
assert _EXPERT_DATA_PARALLEL_GROUP is None, "Expert data group is already initialized"
12081209
global _EXPERT_DATA_PARALLEL_GROUP_GLOO
12091210
assert _EXPERT_DATA_PARALLEL_GROUP_GLOO is None, "Expert data group-gloo is already initialized"
1211+
global _EXPERT_DATA_PARALLEL_GROUP_AG
1212+
assert (
1213+
_EXPERT_DATA_PARALLEL_GROUP_AG is None
1214+
), "Expert data parallel group with AG is already initialized"
12101215
global _INTRA_PARTIAL_EXPERT_DATA_PARALLEL_GROUP
12111216
assert (
12121217
_INTRA_PARTIAL_EXPERT_DATA_PARALLEL_GROUP is None
@@ -1240,10 +1245,20 @@ def initialize_model_parallel(
12401245
)
12411246
else:
12421247
group_gloo = None
1248+
# Create separate all-gather group for expert data parallelism to enable overlap
1249+
if create_all_gather_group:
1250+
group_ag = create_group(
1251+
ranks,
1252+
timeout=timeout,
1253+
pg_options=get_nccl_options("ep_dp", nccl_comm_cfgs),
1254+
group_desc="EXPERT_DATA_PARALLEL_GROUP_AG",
1255+
)
1256+
else:
1257+
group_ag = None
12431258
if rank in ranks:
12441259
_EXPERT_DATA_PARALLEL_GROUP = group
12451260
_EXPERT_DATA_PARALLEL_GROUP_GLOO = group_gloo
1246-
1261+
_EXPERT_DATA_PARALLEL_GROUP_AG = group_ag
12471262
if num_distributed_optimizer_instances > 1:
12481263
# Create groups for Partial DistOpt, one for intra-partial DP domain
12491264
# Another for inter-partial DP domain
@@ -1365,7 +1380,9 @@ def get_pipeline_model_parallel_group(check_initialized=True):
13651380
return _PIPELINE_MODEL_PARALLEL_GROUP
13661381

13671382

1368-
def get_data_parallel_group(with_context_parallel=False, partial_data_parallel=False, independent_all_gather=False):
1383+
def get_data_parallel_group(
1384+
with_context_parallel=False, partial_data_parallel=False, independent_all_gather=False
1385+
):
13691386
"""Get the data-parallel group the caller rank belongs to."""
13701387
if with_context_parallel:
13711388
if partial_data_parallel:
@@ -1390,13 +1407,22 @@ def get_data_parallel_group(with_context_parallel=False, partial_data_parallel=F
13901407

13911408
def has_separate_all_gather_group() -> bool:
    """Report whether a dedicated all-gather process group exists.

    A separate all-gather communicator is only created on request at
    initialization time; it allows forward all-gathers to overlap with
    backward reduce-scatters. Returns False when it was never created.
    """
    ag_group = _DATA_PARALLEL_GROUP_WITH_CP_AG
    return ag_group is not None
13981415

13991416

1417+
def has_separate_expert_all_gather_group() -> bool:
    """Report whether a dedicated expert all-gather process group exists.

    A separate expert data-parallel all-gather communicator is only created
    on request at initialization time, enabling communication overlap for
    expert parallelism. Returns False when it was never created.
    """
    expert_ag_group = _EXPERT_DATA_PARALLEL_GROUP_AG
    return expert_ag_group is not None
1424+
1425+
14001426
def get_data_parallel_group_gloo(with_context_parallel=False, partial_data_parallel=False):
14011427
"""Get the Gloo data-parallel group the caller rank belongs to."""
14021428
if with_context_parallel:
@@ -1886,8 +1912,16 @@ def get_expert_tensor_model_pipeline_parallel_group(check_initialized=True):
18861912
return _EXPERT_TENSOR_MODEL_PIPELINE_PARALLEL_GROUP
18871913

18881914

1889-
def get_expert_data_parallel_group(check_initialized=True, partial_expert_data_parallel=False):
1915+
def get_expert_data_parallel_group(
1916+
check_initialized=True, partial_expert_data_parallel=False, independent_all_gather=False
1917+
):
18901918
"""Get expert data parallel group."""
1919+
if independent_all_gather:
1920+
if check_initialized:
1921+
assert (
1922+
_EXPERT_DATA_PARALLEL_GROUP_AG is not None
1923+
), "Expert data parallel group with AG is not initialized"
1924+
return _EXPERT_DATA_PARALLEL_GROUP_AG
18911925
if partial_expert_data_parallel:
18921926
if check_initialized:
18931927
assert (
@@ -2155,6 +2189,9 @@ def destroy_model_parallel():
21552189
torch.distributed.destroy_process_group(_EXPERT_DATA_PARALLEL_GROUP_GLOO)
21562190
_EXPERT_DATA_PARALLEL_GROUP_GLOO = None
21572191

2192+
global _EXPERT_DATA_PARALLEL_GROUP_AG
2193+
_EXPERT_DATA_PARALLEL_GROUP_AG = None
2194+
21582195
global _INTRA_PARTIAL_EXPERT_DATA_PARALLEL_GROUP
21592196
_INTRA_PARTIAL_EXPERT_DATA_PARALLEL_GROUP = None
21602197

tests/unit_tests/test_parallel_state.py

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -504,26 +504,18 @@ def golden_rank_result_from_past_code(
504504
def test_separate_all_gather_group():
505505
"""Test separate all-gather group for improved communication overlap."""
506506
# Test without creating AG group (default)
507-
Utils.initialize_model_parallel(
508-
context_parallel_size=world_size,
509-
create_all_gather_group=False,
510-
)
507+
Utils.initialize_model_parallel(context_parallel_size=world_size, create_all_gather_group=False)
511508
assert not ps.has_separate_all_gather_group()
512509
assert ps._DATA_PARALLEL_GROUP_WITH_CP_AG is None
513510
Utils.destroy_model_parallel()
514511

515512
# Test with creating AG group
516-
Utils.initialize_model_parallel(
517-
context_parallel_size=world_size,
518-
create_all_gather_group=True,
519-
)
513+
Utils.initialize_model_parallel(context_parallel_size=world_size, create_all_gather_group=True)
520514
assert ps.has_separate_all_gather_group()
521515
assert ps._DATA_PARALLEL_GROUP_WITH_CP_AG is not None
522516

523517
# Verify it returns the correct group
524-
ag_group = ps.get_data_parallel_group(
525-
with_context_parallel=True, independent_all_gather=True
526-
)
518+
ag_group = ps.get_data_parallel_group(with_context_parallel=True, independent_all_gather=True)
527519
regular_group = ps.get_data_parallel_group(
528520
with_context_parallel=True, independent_all_gather=False
529521
)
@@ -535,3 +527,36 @@ def test_separate_all_gather_group():
535527
assert ag_ranks == regular_ranks
536528

537529
Utils.destroy_model_parallel()
530+
531+
532+
@pytest.mark.parametrize('order', test_parallel_order)
@pytest.mark.flaky
@pytest.mark.flaky_in_dev
def test_separate_expert_all_gather_group(order):
    """Test separate all-gather group for expert parallelism to enable communication overlap."""
    # Default path: no dedicated expert all-gather communicator is built.
    Utils.initialize_model_parallel(
        expert_model_parallel_size=world_size, create_all_gather_group=False, order=order
    )
    assert not ps.has_separate_expert_all_gather_group()
    assert ps._EXPERT_DATA_PARALLEL_GROUP_AG is None
    Utils.destroy_model_parallel()

    # Opt-in path: the dedicated expert all-gather communicator must exist.
    Utils.initialize_model_parallel(
        expert_model_parallel_size=world_size, create_all_gather_group=True, order=order
    )
    assert ps.has_separate_expert_all_gather_group()
    assert ps._EXPERT_DATA_PARALLEL_GROUP_AG is not None

    # The AG group and the regular expert-DP group must span the same ranks,
    # differing only in the underlying communicator.
    independent_group = ps.get_expert_data_parallel_group(independent_all_gather=True)
    default_group = ps.get_expert_data_parallel_group(independent_all_gather=False)
    assert independent_group is not None
    assert default_group is not None
    ranks_of = torch.distributed.get_process_group_ranks
    assert ranks_of(independent_group) == ranks_of(default_group)

    Utils.destroy_model_parallel()

0 commit comments

Comments
 (0)