@@ -1194,28 +1194,36 @@ def is_segment_parallel_supported():
 
         elif self.enable_auto_parallel:
             self.tensor_parallel_degree = max(self.tensor_parallel_degree, 1)
+            self.sep_parallel_degree = max(self.sep_parallel_degree, 1)
             self.pipeline_parallel_degree = max(self.pipeline_parallel_degree, 1)
 
             assert (
                 world_size % (self.tensor_parallel_degree * self.pipeline_parallel_degree) == 0
             ), f"Total world_size:{world_size} shoule be devided by tensor_parallel_degree: {self.tensor_parallel_degree} and pipeline_parallel_degree: {self.pipeline_parallel_degree}."
 
-            self.data_parallel_degree = world_size // (self.tensor_parallel_degree * self.pipeline_parallel_degree)
-
             if self.sharding_parallel_degree == -1:
                 if len(self.sharding) > 0:
-                    self.sharding_parallel_degree = self.data_parallel_degree
+                    self.sharding_parallel_degree = world_size // (
+                        self.tensor_parallel_degree * self.sep_parallel_degree * self.pipeline_parallel_degree
+                    )
 
             self.sharding_parallel_degree = max(self.sharding_parallel_degree, 1)
             if self.sharding_parallel_degree == 1 and len(self.sharding) > 0:
                 logger.warning("sharding_parallel_degree=1 means no sharding, please set sharding to empty!")
                 self.sharding = []
 
+            self.data_parallel_degree = world_size // (
+                self.sharding_parallel_degree
+                * self.tensor_parallel_degree
+                * self.sep_parallel_degree
+                * self.pipeline_parallel_degree
+            )
+
             if ShardingOption.OFFLOAD in self.sharding:
                 warnings.warn("`offload` is not supported NOW!")
 
             strategy = fleet.auto.Strategy()
-            if self.data_parallel_degree > 1:
+            if self.dataset_world_size > 1:
                 data_parallel_config = set(self.data_parallel_config.split(" "))
                 for x in data_parallel_config:
                     if len(x) > 0:
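For reference, a minimal standalone sketch of the degree arithmetic this hunk introduces, using illustrative values that are not part of the PR: the sharding degree now absorbs whatever is left of `world_size` after tensor/sep/pipeline parallelism, and `data_parallel_degree` is whatever remains after sharding as well.

```python
# Sketch of the reworked degree derivation (illustrative values only, not PR code).
world_size = 16
tensor_parallel_degree = 2
sep_parallel_degree = 1
pipeline_parallel_degree = 2
sharding = ["stage1"]          # non-empty list means sharding is enabled
sharding_parallel_degree = -1  # -1 means "derive it from world_size"

if sharding_parallel_degree == -1 and len(sharding) > 0:
    sharding_parallel_degree = world_size // (
        tensor_parallel_degree * sep_parallel_degree * pipeline_parallel_degree
    )  # 16 // (2 * 1 * 2) = 4

sharding_parallel_degree = max(sharding_parallel_degree, 1)

# Pure data parallelism is whatever is left after every other axis.
data_parallel_degree = world_size // (
    sharding_parallel_degree
    * tensor_parallel_degree
    * sep_parallel_degree
    * pipeline_parallel_degree
)  # 16 // (4 * 2 * 1 * 2) = 1

print(sharding_parallel_degree, data_parallel_degree)  # 4 1
```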
@@ -1356,10 +1364,10 @@ def is_segment_parallel_supported():
             self.strategy = strategy
             if self.hybrid_parallel_topo_order == "pp_first":
                 order = ["pp", "dp", "mp"]
-                degree = [self.pipeline_parallel_degree, self.data_parallel_degree, self.tensor_parallel_degree]
+                degree = [self.pipeline_parallel_degree, self.dataset_world_size, self.tensor_parallel_degree]
             elif self.hybrid_parallel_topo_order == "sharding_first":
                 order = ["dp", "pp", "mp"]
-                degree = [self.data_parallel_degree, self.pipeline_parallel_degree, self.tensor_parallel_degree]
+                degree = [self.dataset_world_size, self.pipeline_parallel_degree, self.tensor_parallel_degree]
             mesh_dims = list(zip(order, degree))
             fleet.auto.create_mesh(mesh_dims)
 
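To illustrate what `fleet.auto.create_mesh` would now receive, here is a small sketch of the resulting `mesh_dims` for the "sharding_first" order, reusing the assumed degrees from the example above (dataset_world_size = 4); the values are illustrative, not taken from the PR.

```python
# Sketch of the mesh dims built by this hunk for "sharding_first" (assumed values).
order = ["dp", "pp", "mp"]
degree = [4, 2, 2]  # dataset_world_size, pipeline_parallel_degree, tensor_parallel_degree
mesh_dims = list(zip(order, degree))
print(mesh_dims)  # [('dp', 4), ('pp', 2), ('mp', 2)]
```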
@@ -1371,7 +1379,7 @@ def is_segment_parallel_supported():
 
             strategy = fleet.DistributedStrategy()
             strategy.hybrid_configs = {
-                "dp_degree": self.data_parallel_degree,
+                "dp_degree": self.dataset_world_size,
                 "mp_degree": self.tensor_parallel_degree,
                 "pp_degree": self.pipeline_parallel_degree,
                 "order": order,
@@ -1526,7 +1534,7 @@ def dataset_world_size(self):
         if self.use_hybrid_parallel:
             return max(self.sharding_parallel_degree, 1) * max(self.data_parallel_degree, 1)
         elif self.enable_auto_parallel:
-            return max(self.data_parallel_degree, 1)
+            return max(self.sharding_parallel_degree, 1) * max(self.data_parallel_degree, 1)
         else:
             return paddle.distributed.get_world_size()
 
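The last hunk keeps `dataset_world_size` consistent with the new bookkeeping: under auto parallel it now multiplies the sharding and data-parallel degrees, matching what the hybrid-parallel branch already returns. A minimal standalone sketch of the updated behaviour, using the assumed degrees from the first example (the helper name and values are illustrative, not from the PR):

```python
# Sketch of the updated dataset_world_size logic under auto parallel (assumed values).
sharding_parallel_degree = 4
data_parallel_degree = 1

def dataset_world_size(enable_auto_parallel: bool) -> int:
    if enable_auto_parallel:
        # After this PR: sharding ranks also consume distinct data-parallel batches.
        return max(sharding_parallel_degree, 1) * max(data_parallel_degree, 1)
    return 1  # other branches of the real property are omitted in this sketch

print(dataset_world_size(True))  # 4
```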