Skip to content

Commit 33ede31

Browse files
committed
add expert all-gather process-group for overlapping all-gather with gradient reduction
1 parent a25421e commit 33ede31

File tree

4 files changed

+83
-4
lines changed

4 files changed

+83
-4
lines changed

megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1605,6 +1605,9 @@ def __init__(
16051605
if self.dist_index.get_fsdp_group(is_expert_parallel=False, independent_all_gather=True) is not None:
16061606
# All-gather group used when overlapping all-gather and gradient reduction.
16071607
self.ubr_groups.append(self.dist_index.get_fsdp_group(is_expert_parallel=False, independent_all_gather=True))
1608+
if self.dist_index.get_fsdp_group(is_expert_parallel=True, independent_all_gather=True) is not None:
1609+
# Expert all-gather group used when overlapping all-gather and gradient reduction.
1610+
self.ubr_groups.append(self.dist_index.get_fsdp_group(is_expert_parallel=True, independent_all_gather=True))
16081611

16091612
if torch.distributed.get_rank() == 0:
16101613
logging.info(
@@ -1896,9 +1899,9 @@ def _init_each_parameter_group_buffers(self, meta_device_init_fp8_params):
18961899
# operations (main_grad_buffer). This avoids head-of-line blocking between forward
18971900
# all-gather and backward reduce-scatter on the same communicator.
18981901
model_wbuf_dp_group = main_buf_dp_group
1899-
if not group.is_expert_param and not should_create_hfsdp_wbuf_and_gbuf:
1902+
if not should_create_hfsdp_wbuf_and_gbuf:
19001903
ag_group = self.dist_index.get_fsdp_group(
1901-
is_expert_parallel=False, independent_all_gather=True
1904+
is_expert_parallel=group.is_expert_param, independent_all_gather=True
19021905
)
19031906
if ag_group is not None:
19041907
model_wbuf_dp_group = ag_group

megatron/core/distributed/fsdp/src/megatron_fsdp/utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -518,6 +518,12 @@ def __init__(
518518
and contains_submesh(self.expt_device_mesh, self.dp_shard_dim)
519519
else None
520520
)
521+
# Expert AG group for overlap
522+
self.expt_fsdp_group_ag = None
523+
if HAVE_MEGATRON_CORE and parallel_state.has_separate_expert_all_gather_group():
524+
self.expt_fsdp_group_ag = parallel_state.get_expert_data_parallel_group(
525+
independent_all_gather=True
526+
)
521527

522528
"""
523529
Megatron-FSDP is responsible for storing all required DeviceMesh
@@ -638,6 +644,8 @@ def get_dp_group(self, is_expert_parallel: bool = False) -> ProcessGroup:
638644
def get_fsdp_group(self, is_expert_parallel: bool = False, independent_all_gather: bool = False) -> ProcessGroup:
639645
"""Get the FSDP process group."""
640646
if is_expert_parallel:
647+
if independent_all_gather:
648+
return self.expt_fsdp_group_ag
641649
return self.expt_fsdp_group
642650
if independent_all_gather:
643651
return self.fsdp_group_ag

megatron/core/parallel_state.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
# Expert data parallel group
6161
_EXPERT_DATA_PARALLEL_GROUP = None
6262
_EXPERT_DATA_PARALLEL_GROUP_GLOO = None
63+
_EXPERT_DATA_PARALLEL_GROUP_AG = None
6364
_INTRA_PARTIAL_EXPERT_DATA_PARALLEL_GROUP = None
6465
_INTRA_PARTIAL_EXPERT_DATA_PARALLEL_GROUP_GLOO = None
6566
_INTER_PARTIAL_EXPERT_DATA_PARALLEL_GROUP = None
@@ -1207,6 +1208,8 @@ def initialize_model_parallel(
12071208
assert _EXPERT_DATA_PARALLEL_GROUP is None, "Expert data group is already initialized"
12081209
global _EXPERT_DATA_PARALLEL_GROUP_GLOO
12091210
assert _EXPERT_DATA_PARALLEL_GROUP_GLOO is None, "Expert data group-gloo is already initialized"
1211+
global _EXPERT_DATA_PARALLEL_GROUP_AG
1212+
assert _EXPERT_DATA_PARALLEL_GROUP_AG is None, "Expert data parallel group with AG is already initialized"
12101213
global _INTRA_PARTIAL_EXPERT_DATA_PARALLEL_GROUP
12111214
assert (
12121215
_INTRA_PARTIAL_EXPERT_DATA_PARALLEL_GROUP is None
@@ -1240,10 +1243,20 @@ def initialize_model_parallel(
12401243
)
12411244
else:
12421245
group_gloo = None
1246+
# Create separate all-gather group for expert data parallelism to enable overlap
1247+
if create_all_gather_group:
1248+
group_ag = create_group(
1249+
ranks,
1250+
timeout=timeout,
1251+
pg_options=get_nccl_options("ep_dp", nccl_comm_cfgs),
1252+
group_desc="EXPERT_DATA_PARALLEL_GROUP_AG",
1253+
)
1254+
else:
1255+
group_ag = None
12431256
if rank in ranks:
12441257
_EXPERT_DATA_PARALLEL_GROUP = group
12451258
_EXPERT_DATA_PARALLEL_GROUP_GLOO = group_gloo
1246-
1259+
_EXPERT_DATA_PARALLEL_GROUP_AG = group_ag
12471260
if num_distributed_optimizer_instances > 1:
12481261
# Create groups for Partial DistOpt, one for intra-partial DP domain
12491262
# Another for inter-partial DP domain
@@ -1397,6 +1410,15 @@ def has_separate_all_gather_group() -> bool:
13971410
return _DATA_PARALLEL_GROUP_WITH_CP_AG is not None
13981411

13991412

1413+
def has_separate_expert_all_gather_group() -> bool:
    """Return whether a dedicated expert all-gather process group exists.

    A separate communicator for expert data-parallel all-gathers lets the
    all-gather overlap with gradient reduction on the regular group.
    True once ``_EXPERT_DATA_PARALLEL_GROUP_AG`` has been created.
    """
    return _EXPERT_DATA_PARALLEL_GROUP_AG is not None
1420+
1421+
14001422
def get_data_parallel_group_gloo(with_context_parallel=False, partial_data_parallel=False):
14011423
"""Get the Gloo data-parallel group the caller rank belongs to."""
14021424
if with_context_parallel:
@@ -1886,8 +1908,14 @@ def get_expert_tensor_model_pipeline_parallel_group(check_initialized=True):
18861908
return _EXPERT_TENSOR_MODEL_PIPELINE_PARALLEL_GROUP
18871909

18881910

1889-
def get_expert_data_parallel_group(check_initialized=True, partial_expert_data_parallel=False):
1911+
def get_expert_data_parallel_group(check_initialized=True, partial_expert_data_parallel=False, independent_all_gather=False):
18901912
"""Get expert data parallel group."""
1913+
if independent_all_gather:
1914+
if check_initialized:
1915+
assert (
1916+
_EXPERT_DATA_PARALLEL_GROUP_AG is not None
1917+
), "Expert data parallel group with AG is not initialized"
1918+
return _EXPERT_DATA_PARALLEL_GROUP_AG
18911919
if partial_expert_data_parallel:
18921920
if check_initialized:
18931921
assert (
@@ -2155,6 +2183,9 @@ def destroy_model_parallel():
21552183
torch.distributed.destroy_process_group(_EXPERT_DATA_PARALLEL_GROUP_GLOO)
21562184
_EXPERT_DATA_PARALLEL_GROUP_GLOO = None
21572185

2186+
global _EXPERT_DATA_PARALLEL_GROUP_AG
2187+
_EXPERT_DATA_PARALLEL_GROUP_AG = None
2188+
21582189
global _INTRA_PARTIAL_EXPERT_DATA_PARALLEL_GROUP
21592190
_INTRA_PARTIAL_EXPERT_DATA_PARALLEL_GROUP = None
21602191

tests/unit_tests/test_parallel_state.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -535,3 +535,40 @@ def test_separate_all_gather_group():
535535
assert ag_ranks == regular_ranks
536536

537537
Utils.destroy_model_parallel()
538+
539+
540+
@pytest.mark.parametrize('order', test_parallel_order)
@pytest.mark.flaky
@pytest.mark.flaky_in_dev
def test_separate_expert_all_gather_group(order):
    """Exercise the dedicated expert all-gather group used for communication overlap."""
    # Default path: no dedicated expert AG communicator is created.
    Utils.initialize_model_parallel(
        expert_model_parallel_size=world_size,
        create_all_gather_group=False,
        order=order,
    )
    assert ps._EXPERT_DATA_PARALLEL_GROUP_AG is None
    assert not ps.has_separate_expert_all_gather_group()
    Utils.destroy_model_parallel()

    # Opt-in path: the expert AG communicator exists and is reported as present.
    Utils.initialize_model_parallel(
        expert_model_parallel_size=world_size,
        create_all_gather_group=True,
        order=order,
    )
    assert ps.has_separate_expert_all_gather_group()
    assert ps._EXPERT_DATA_PARALLEL_GROUP_AG is not None

    # The AG group mirrors the regular expert-DP ranks on a separate communicator.
    ag_group = ps.get_expert_data_parallel_group(independent_all_gather=True)
    regular_group = ps.get_expert_data_parallel_group(independent_all_gather=False)
    assert ag_group is not None and regular_group is not None
    ag_rank_list = torch.distributed.get_process_group_ranks(ag_group)
    base_rank_list = torch.distributed.get_process_group_ranks(regular_group)
    assert ag_rank_list == base_rank_list

    Utils.destroy_model_parallel()

0 commit comments

Comments
 (0)