Commit c1e5656

[Cherry-Pick] support context parallel (#74201) (#74983)
* [Cherry-Pick] support context parallel (#74201)
* refine
* refine
* refine
* add ut
* fix no cp
* refine ut
* fix dense topo
Parent: 025b425

File tree: 6 files changed, +393 -4 lines changed


paddle/fluid/framework/distributed_strategy.proto
(22 additions, 0 deletions)

@@ -137,6 +137,12 @@ message HybridConfig {
   optional EpConfig ep_configs = 17;
   optional MoeShardingConfig moe_sharding_configs = 18;
   optional DefaultCommGroupConfig default_comm_group_configs = 19;
+  optional int32 cp_degree = 20 [ default = 1 ];
+  optional int32 cp_sharding_degree = 21 [ default = 1 ];
+  optional CpConfig cp_configs = 22;
+  optional CpShardingConfig cp_sharding_configs = 23;
+  optional DpCpConfig dp_cp_configs = 24;
+  optional CpMpConfig cp_mp_configs = 25;
 }

 message AMPConfig {

@@ -502,6 +508,22 @@ message MoeShardingConfig {
   optional NCCLConfig check_nccl_config = 2;
 }

+message CpConfig {
+  optional NCCLConfig nccl_config = 1;
+}
+
+message CpShardingConfig {
+  optional NCCLConfig nccl_config = 1;
+}
+
+message DpCpConfig {
+  optional NCCLConfig nccl_config = 1;
+}
+
+message CpMpConfig {
+  optional NCCLConfig nccl_config = 1;
+}
+
 message DefaultCommGroupConfig {
   optional NCCLConfig nccl_config = 1;
 }
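These proto fields back the user-facing hybrid_configs API touched in the Python diff below. A minimal usage sketch, assuming an 8-GPU job; the degree values are illustrative rather than part of this commit, and their product must match the number of launched ranks:

# Hedged sketch: enabling the new "cp" dimension via fleet's
# hybrid_configs. Degrees are illustrative (2 * 2 * 2 = 8 ranks).
import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "dp_degree": 2,  # data parallel replicas
    "cp_degree": 2,  # context parallel, added by this commit
    "mp_degree": 2,  # tensor/model parallel
    "pp_degree": 1,  # pipeline parallel disabled
}
fleet.init(is_collective=True, strategy=strategy)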

python/paddle/distributed/fleet/base/distributed_strategy.py
(3 additions, 0 deletions)

@@ -102,6 +102,7 @@ class _HybridConfig(TypedDict, total=False):
     mp_degree: int
     pp_degree: int
     sep_degree: int
+    cp_degree: int
     sharding_degree: int
     order: list[str]

@@ -325,6 +326,7 @@ def __init__(self) -> None:
             'pp',
             'sharding',
             'sep',
+            'cp',
             'mp',
         ]
         self.sync_param_name: list[str] = ["embedding", "layer_norm", ".b_"]

@@ -1907,6 +1909,7 @@ def hybrid_configs(self) -> _HybridConfig:

         **pp_degree(int)**: set number of GPUs in a pipeline parallel group. Default 1
         **sep_degree(int)**: set number of GPUs in a sep parallel group. Default 1
+        **cp_degree(int)**: set number of GPUs in a context parallel group. Default 1
         **sharding_degree(int)**: set number of GPUs in a sharding parallel group. Default 1
         **order(list(string))**: set hybrid parallel dimensions, the order is from outside to inside. Default ['dp','pp','sharding','sep', 'mp']
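The new 'cp' entry sits between 'sep' and 'mp' in the default order, which lists mesh axes from outermost to innermost. A standalone sketch (plain NumPy, not Paddle code) of how that ordering determines which ranks share a context-parallel group; all degree values are hypothetical:

import numpy as np

def groups_along(order, degrees, axis):
    # Build a rank mesh whose axes follow `order` (outermost first),
    # then collect the groups of ranks that vary only along `axis`.
    shape = [degrees[name] for name in order]
    mesh = np.arange(int(np.prod(shape))).reshape(shape)
    moved = np.moveaxis(mesh, order.index(axis), -1)
    return moved.reshape(-1, degrees[axis]).tolist()

# Default order after this commit; degrees are hypothetical.
order = ['dp', 'pp', 'sharding', 'sep', 'cp', 'mp']
degrees = {'dp': 2, 'pp': 1, 'sharding': 1, 'sep': 1, 'cp': 2, 'mp': 2}
print(groups_along(order, degrees, 'cp'))  # [[0, 2], [1, 3], [4, 6], [5, 7]]
print(groups_along(order, degrees, 'mp'))  # [[0, 1], [2, 3], [4, 5], [6, 7]]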