Skip to content

Commit c5b9670

Browse files
committed
add expert all-gather process-group for overlapping
1 parent 259ce47 commit c5b9670

File tree

4 files changed: +112 additions, −22 deletions

megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1602,9 +1602,28 @@ def __init__(
16021602
if self.dist_index.get_outer_fsdp_group() is not None:
16031603
# Outer/Inter-FSDP group when using hybrid FSDP
16041604
self.ubr_groups.append(self.dist_index.get_outer_fsdp_group())
1605-
if self.dist_index.get_fsdp_group(is_expert_parallel=False, independent_all_gather=True) is not None:
1605+
if (
1606+
self.dist_index.get_fsdp_group(
1607+
is_expert_parallel=False, independent_all_gather=True
1608+
)
1609+
is not None
1610+
):
16061611
# All-gather group used when overlapping all-gather and gradient reduction.
1607-
self.ubr_groups.append(self.dist_index.get_fsdp_group(is_expert_parallel=False, independent_all_gather=True))
1612+
self.ubr_groups.append(
1613+
self.dist_index.get_fsdp_group(
1614+
is_expert_parallel=False, independent_all_gather=True
1615+
)
1616+
)
1617+
if (
1618+
self.dist_index.get_fsdp_group(is_expert_parallel=True, independent_all_gather=True)
1619+
is not None
1620+
):
1621+
# Expert all-gather group used when overlapping all-gather and gradient reduction.
1622+
self.ubr_groups.append(
1623+
self.dist_index.get_fsdp_group(
1624+
is_expert_parallel=True, independent_all_gather=True
1625+
)
1626+
)
16081627

16091628
if torch.distributed.get_rank() == 0:
16101629
logging.info(
@@ -1896,9 +1915,9 @@ def _init_each_parameter_group_buffers(self, meta_device_init_fp8_params):
18961915
# operations (main_grad_buffer). This avoids head-of-line blocking between forward
18971916
# all-gather and backward reduce-scatter on the same communicator.
18981917
model_wbuf_dp_group = main_buf_dp_group
1899-
if not group.is_expert_param and not should_create_hfsdp_wbuf_and_gbuf:
1918+
if not should_create_hfsdp_wbuf_and_gbuf:
19001919
ag_group = self.dist_index.get_fsdp_group(
1901-
is_expert_parallel=False, independent_all_gather=True
1920+
is_expert_parallel=group.is_expert_param, independent_all_gather=True
19021921
)
19031922
if ag_group is not None:
19041923
model_wbuf_dp_group = ag_group

megatron/core/distributed/fsdp/src/megatron_fsdp/utils.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -498,8 +498,7 @@ def __init__(
498498
self.fsdp_group_ag = None
499499
if HAVE_MEGATRON_CORE and parallel_state.has_separate_all_gather_group():
500500
self.fsdp_group_ag = parallel_state.get_data_parallel_group(
501-
with_context_parallel=True,
502-
independent_all_gather=True
501+
with_context_parallel=True, independent_all_gather=True
503502
)
504503
# Retrieve the outer-FSDP process group from the DeviceMesh.
505504
self.outer_fsdp_group = (
@@ -518,6 +517,12 @@ def __init__(
518517
and contains_submesh(self.expt_device_mesh, self.dp_shard_dim)
519518
else None
520519
)
520+
# Expert AG group for overlap
521+
self.expt_fsdp_group_ag = None
522+
if HAVE_MEGATRON_CORE and parallel_state.has_separate_expert_all_gather_group():
523+
self.expt_fsdp_group_ag = parallel_state.get_expert_data_parallel_group(
524+
independent_all_gather=True
525+
)
521526

522527
"""
523528
Megatron-FSDP is responsible for storing all required DeviceMesh
@@ -635,9 +640,13 @@ def get_dp_group(self, is_expert_parallel: bool = False) -> ProcessGroup:
635640
return self.hybrid_fsdp_group
636641
return self.fsdp_group
637642

638-
def get_fsdp_group(self, is_expert_parallel: bool = False, independent_all_gather: bool = False) -> ProcessGroup:
643+
def get_fsdp_group(
644+
self, is_expert_parallel: bool = False, independent_all_gather: bool = False
645+
) -> ProcessGroup:
639646
"""Get the FSDP process group."""
640647
if is_expert_parallel:
648+
if independent_all_gather:
649+
return self.expt_fsdp_group_ag
641650
return self.expt_fsdp_group
642651
if independent_all_gather:
643652
return self.fsdp_group_ag

megatron/core/parallel_state.py

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
# Expert data parallel group
6262
_EXPERT_DATA_PARALLEL_GROUP = None
6363
_EXPERT_DATA_PARALLEL_GROUP_GLOO = None
64+
_EXPERT_DATA_PARALLEL_GROUP_AG = None
6465
_INTRA_PARTIAL_EXPERT_DATA_PARALLEL_GROUP = None
6566
_INTRA_PARTIAL_EXPERT_DATA_PARALLEL_GROUP_GLOO = None
6667
_INTER_PARTIAL_EXPERT_DATA_PARALLEL_GROUP = None
@@ -1249,6 +1250,10 @@ def initialize_model_parallel(
12491250
assert _EXPERT_DATA_PARALLEL_GROUP is None, "Expert data group is already initialized"
12501251
global _EXPERT_DATA_PARALLEL_GROUP_GLOO
12511252
assert _EXPERT_DATA_PARALLEL_GROUP_GLOO is None, "Expert data group-gloo is already initialized"
1253+
global _EXPERT_DATA_PARALLEL_GROUP_AG
1254+
assert (
1255+
_EXPERT_DATA_PARALLEL_GROUP_AG is None
1256+
), "Expert data parallel group with AG is already initialized"
12521257
global _INTRA_PARTIAL_EXPERT_DATA_PARALLEL_GROUP
12531258
assert (
12541259
_INTRA_PARTIAL_EXPERT_DATA_PARALLEL_GROUP is None
@@ -1282,10 +1287,20 @@ def initialize_model_parallel(
12821287
)
12831288
else:
12841289
group_gloo = None
1290+
# Create separate all-gather group for expert data parallelism to enable overlap
1291+
if create_all_gather_group:
1292+
group_ag = create_group(
1293+
ranks,
1294+
timeout=timeout,
1295+
pg_options=get_nccl_options("ep_dp", nccl_comm_cfgs),
1296+
group_desc="EXPERT_DATA_PARALLEL_GROUP_AG",
1297+
)
1298+
else:
1299+
group_ag = None
12851300
if rank in ranks:
12861301
_EXPERT_DATA_PARALLEL_GROUP = group
12871302
_EXPERT_DATA_PARALLEL_GROUP_GLOO = group_gloo
1288-
1303+
_EXPERT_DATA_PARALLEL_GROUP_AG = group_ag
12891304
if num_distributed_optimizer_instances > 1:
12901305
# Create groups for Partial DistOpt, one for intra-partial DP domain
12911306
# Another for inter-partial DP domain
@@ -1407,7 +1422,9 @@ def get_pipeline_model_parallel_group(check_initialized=True):
14071422
return _PIPELINE_MODEL_PARALLEL_GROUP
14081423

14091424

1410-
def get_data_parallel_group(with_context_parallel=False, partial_data_parallel=False, independent_all_gather=False):
1425+
def get_data_parallel_group(
1426+
with_context_parallel=False, partial_data_parallel=False, independent_all_gather=False
1427+
):
14111428
"""Get the data-parallel group the caller rank belongs to."""
14121429
if with_context_parallel:
14131430
if partial_data_parallel:
@@ -1432,13 +1449,22 @@ def get_data_parallel_group(with_context_parallel=False, partial_data_parallel=F
14321449

14331450
def has_separate_all_gather_group() -> bool:
14341451
"""Check if a separate all-gather process group has been created.
1435-
1452+
14361453
Returns True if a dedicated all-gather process group exists for improved
14371454
communication overlap, False otherwise.
14381455
"""
14391456
return _DATA_PARALLEL_GROUP_WITH_CP_AG is not None
14401457

14411458

1459+
def has_separate_expert_all_gather_group() -> bool:
1460+
"""Check if a separate all-gather process group for experts has been created.
1461+
1462+
Returns True if a dedicated all-gather process group for expert parallelism exists
1463+
for improved communication overlap, False otherwise.
1464+
"""
1465+
return _EXPERT_DATA_PARALLEL_GROUP_AG is not None
1466+
1467+
14421468
def get_data_parallel_group_gloo(with_context_parallel=False, partial_data_parallel=False):
14431469
"""Get the Gloo data-parallel group the caller rank belongs to."""
14441470
if with_context_parallel:
@@ -1940,8 +1966,16 @@ def get_expert_tensor_model_pipeline_parallel_group(check_initialized=True):
19401966
return _EXPERT_TENSOR_MODEL_PIPELINE_PARALLEL_GROUP
19411967

19421968

1943-
def get_expert_data_parallel_group(check_initialized=True, partial_expert_data_parallel=False):
1969+
def get_expert_data_parallel_group(
1970+
check_initialized=True, partial_expert_data_parallel=False, independent_all_gather=False
1971+
):
19441972
"""Get expert data parallel group."""
1973+
if independent_all_gather:
1974+
if check_initialized:
1975+
assert (
1976+
_EXPERT_DATA_PARALLEL_GROUP_AG is not None
1977+
), "Expert data parallel group with AG is not initialized"
1978+
return _EXPERT_DATA_PARALLEL_GROUP_AG
19451979
if partial_expert_data_parallel:
19461980
if check_initialized:
19471981
assert (
@@ -2209,6 +2243,9 @@ def destroy_model_parallel():
22092243
torch.distributed.destroy_process_group(_EXPERT_DATA_PARALLEL_GROUP_GLOO)
22102244
_EXPERT_DATA_PARALLEL_GROUP_GLOO = None
22112245

2246+
global _EXPERT_DATA_PARALLEL_GROUP_AG
2247+
_EXPERT_DATA_PARALLEL_GROUP_AG = None
2248+
22122249
global _INTRA_PARTIAL_EXPERT_DATA_PARALLEL_GROUP
22132250
_INTRA_PARTIAL_EXPERT_DATA_PARALLEL_GROUP = None
22142251

tests/unit_tests/test_parallel_state.py

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -533,26 +533,18 @@ def test_hybrid_dp_cp_groups(world_size, tp_size, cp_size, dp_size):
533533
def test_separate_all_gather_group():
534534
"""Test separate all-gather group for improved communication overlap."""
535535
# Test without creating AG group (default)
536-
Utils.initialize_model_parallel(
537-
context_parallel_size=world_size,
538-
create_all_gather_group=False,
539-
)
536+
Utils.initialize_model_parallel(context_parallel_size=world_size, create_all_gather_group=False)
540537
assert not ps.has_separate_all_gather_group()
541538
assert ps._DATA_PARALLEL_GROUP_WITH_CP_AG is None
542539
Utils.destroy_model_parallel()
543540

544541
# Test with creating AG group
545-
Utils.initialize_model_parallel(
546-
context_parallel_size=world_size,
547-
create_all_gather_group=True,
548-
)
542+
Utils.initialize_model_parallel(context_parallel_size=world_size, create_all_gather_group=True)
549543
assert ps.has_separate_all_gather_group()
550544
assert ps._DATA_PARALLEL_GROUP_WITH_CP_AG is not None
551545

552546
# Verify it returns the correct group
553-
ag_group = ps.get_data_parallel_group(
554-
with_context_parallel=True, independent_all_gather=True
555-
)
547+
ag_group = ps.get_data_parallel_group(with_context_parallel=True, independent_all_gather=True)
556548
regular_group = ps.get_data_parallel_group(
557549
with_context_parallel=True, independent_all_gather=False
558550
)
@@ -564,3 +556,36 @@ def test_separate_all_gather_group():
564556
assert ag_ranks == regular_ranks
565557

566558
Utils.destroy_model_parallel()
559+
560+
561+
@pytest.mark.parametrize('order', test_parallel_order)
562+
@pytest.mark.flaky
563+
@pytest.mark.flaky_in_dev
564+
def test_separate_expert_all_gather_group(order):
565+
"""Test separate all-gather group for expert parallelism to enable communication overlap."""
566+
# Test without creating expert AG group (default)
567+
Utils.initialize_model_parallel(
568+
expert_model_parallel_size=world_size, create_all_gather_group=False, order=order
569+
)
570+
assert not ps.has_separate_expert_all_gather_group()
571+
assert ps._EXPERT_DATA_PARALLEL_GROUP_AG is None
572+
Utils.destroy_model_parallel()
573+
574+
# Test with creating expert AG group
575+
Utils.initialize_model_parallel(
576+
expert_model_parallel_size=world_size, create_all_gather_group=True, order=order
577+
)
578+
assert ps.has_separate_expert_all_gather_group()
579+
assert ps._EXPERT_DATA_PARALLEL_GROUP_AG is not None
580+
581+
# Verify it returns the correct group
582+
ag_group = ps.get_expert_data_parallel_group(independent_all_gather=True)
583+
regular_group = ps.get_expert_data_parallel_group(independent_all_gather=False)
584+
assert ag_group is not None
585+
assert regular_group is not None
586+
# They should have the same ranks but different communicators
587+
ag_ranks = torch.distributed.get_process_group_ranks(ag_group)
588+
regular_ranks = torch.distributed.get_process_group_ranks(regular_group)
589+
assert ag_ranks == regular_ranks
590+
591+
Utils.destroy_model_parallel()

0 commit comments

Comments (0)