|
60 | 60 | # Expert data parallel group |
61 | 61 | _EXPERT_DATA_PARALLEL_GROUP = None |
62 | 62 | _EXPERT_DATA_PARALLEL_GROUP_GLOO = None |
| 63 | +_EXPERT_DATA_PARALLEL_GROUP_AG = None |
63 | 64 | _INTRA_PARTIAL_EXPERT_DATA_PARALLEL_GROUP = None |
64 | 65 | _INTRA_PARTIAL_EXPERT_DATA_PARALLEL_GROUP_GLOO = None |
65 | 66 | _INTER_PARTIAL_EXPERT_DATA_PARALLEL_GROUP = None |
@@ -1207,6 +1208,8 @@ def initialize_model_parallel( |
1207 | 1208 | assert _EXPERT_DATA_PARALLEL_GROUP is None, "Expert data group is already initialized" |
1208 | 1209 | global _EXPERT_DATA_PARALLEL_GROUP_GLOO |
1209 | 1210 | assert _EXPERT_DATA_PARALLEL_GROUP_GLOO is None, "Expert data group-gloo is already initialized" |
| 1211 | + global _EXPERT_DATA_PARALLEL_GROUP_AG |
| 1212 | + assert _EXPERT_DATA_PARALLEL_GROUP_AG is None, "Expert data parallel group with AG is already initialized" |
1210 | 1213 | global _INTRA_PARTIAL_EXPERT_DATA_PARALLEL_GROUP |
1211 | 1214 | assert ( |
1212 | 1215 | _INTRA_PARTIAL_EXPERT_DATA_PARALLEL_GROUP is None |
@@ -1240,10 +1243,20 @@ def initialize_model_parallel( |
1240 | 1243 | ) |
1241 | 1244 | else: |
1242 | 1245 | group_gloo = None |
| 1246 | + # Create separate all-gather group for expert data parallelism to enable overlap |
| 1247 | + if create_all_gather_group: |
| 1248 | + group_ag = create_group( |
| 1249 | + ranks, |
| 1250 | + timeout=timeout, |
| 1251 | + pg_options=get_nccl_options("ep_dp", nccl_comm_cfgs), |
| 1252 | + group_desc="EXPERT_DATA_PARALLEL_GROUP_AG", |
| 1253 | + ) |
| 1254 | + else: |
| 1255 | + group_ag = None |
1243 | 1256 | if rank in ranks: |
1244 | 1257 | _EXPERT_DATA_PARALLEL_GROUP = group |
1245 | 1258 | _EXPERT_DATA_PARALLEL_GROUP_GLOO = group_gloo |
1246 | | - |
| 1259 | + _EXPERT_DATA_PARALLEL_GROUP_AG = group_ag |
1247 | 1260 | if num_distributed_optimizer_instances > 1: |
1248 | 1261 | # Create groups for Partial DistOpt, one for intra-partial DP domain |
1249 | 1262 | # Another for inter-partial DP domain |
@@ -1397,6 +1410,15 @@ def has_separate_all_gather_group() -> bool: |
1397 | 1410 | return _DATA_PARALLEL_GROUP_WITH_CP_AG is not None |
1398 | 1411 |
|
1399 | 1412 |
|
| 1413 | +def has_separate_expert_all_gather_group() -> bool: |
| 1414 | + """Check if a separate all-gather process group for experts has been created. |
| 1415 | + |
| 1416 | + Returns True if a dedicated all-gather process group for expert parallelism exists |
| 1417 | + for improved communication overlap, False otherwise. |
| 1418 | + """ |
| 1419 | + return _EXPERT_DATA_PARALLEL_GROUP_AG is not None |
| 1420 | + |
| 1421 | + |
1400 | 1422 | def get_data_parallel_group_gloo(with_context_parallel=False, partial_data_parallel=False): |
1401 | 1423 | """Get the Gloo data-parallel group the caller rank belongs to.""" |
1402 | 1424 | if with_context_parallel: |
@@ -1886,8 +1908,14 @@ def get_expert_tensor_model_pipeline_parallel_group(check_initialized=True): |
1886 | 1908 | return _EXPERT_TENSOR_MODEL_PIPELINE_PARALLEL_GROUP |
1887 | 1909 |
|
1888 | 1910 |
|
1889 | | -def get_expert_data_parallel_group(check_initialized=True, partial_expert_data_parallel=False): |
| 1911 | +def get_expert_data_parallel_group(check_initialized=True, partial_expert_data_parallel=False, independent_all_gather=False): |
1890 | 1912 | """Get expert data parallel group.""" |
| 1913 | + if independent_all_gather: |
| 1914 | + if check_initialized: |
| 1915 | + assert ( |
| 1916 | + _EXPERT_DATA_PARALLEL_GROUP_AG is not None |
| 1917 | + ), "Expert data parallel group with AG is not initialized" |
| 1918 | + return _EXPERT_DATA_PARALLEL_GROUP_AG |
1891 | 1919 | if partial_expert_data_parallel: |
1892 | 1920 | if check_initialized: |
1893 | 1921 | assert ( |
@@ -2155,6 +2183,9 @@ def destroy_model_parallel(): |
2155 | 2183 | torch.distributed.destroy_process_group(_EXPERT_DATA_PARALLEL_GROUP_GLOO) |
2156 | 2184 | _EXPERT_DATA_PARALLEL_GROUP_GLOO = None |
2157 | 2185 |
|
| 2186 | + global _EXPERT_DATA_PARALLEL_GROUP_AG |
| 2187 | + _EXPERT_DATA_PARALLEL_GROUP_AG = None |
| 2188 | + |
2158 | 2189 | global _INTRA_PARTIAL_EXPERT_DATA_PARALLEL_GROUP |
2159 | 2190 | _INTRA_PARTIAL_EXPERT_DATA_PARALLEL_GROUP = None |
2160 | 2191 |
|
|
0 commit comments