
Commit 7b493a8

Support sharding for auto_trainer (#8164)
* add
* fix
* refine code
* refine
* fix
* fix
* fix
* refine
1 parent 4d661bc commit 7b493a8

5 files changed: +41 -55 lines changed

llm/llama/auto_parallel/run_pretrain_auto.py

Lines changed: 4 additions & 4 deletions
@@ -275,14 +275,14 @@ def create_pretrained_dataset(
 
     train_val_test_num_samples = [
         training_args.per_device_train_batch_size
-        * training_args.data_parallel_degree
+        * training_args.dataset_world_size
         * training_args.max_steps
         * training_args.gradient_accumulation_steps,
         training_args.per_device_eval_batch_size
-        * training_args.data_parallel_degree
+        * training_args.dataset_world_size
         * training_args.eval_iters
         * (training_args.max_steps // training_args.eval_steps + 1),
-        training_args.per_device_eval_batch_size * training_args.data_parallel_degree * training_args.test_iters,
+        training_args.per_device_eval_batch_size * training_args.dataset_world_size * training_args.test_iters,
     ]
 
     print_rank_0(" > datasets target sizes (minimum size):")
@@ -411,7 +411,7 @@ def init_seed(seed: int = 1234, args=None):
         topo = Topology(
             dist.get_rank(),
             dist.get_world_size(),
-            dp_degree=args.data_parallel_degree,
+            dp_degree=args.dataset_world_size,
             pp_degree=args.pipeline_parallel_degree,
             mp_degree=args.tensor_parallel_degree,
             sharding_degree=1,  # auto_parallel's sharding is not orthogonal with dp, mp and pp
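In this script the dataset-sizing arithmetic simply swaps `data_parallel_degree` for `dataset_world_size`, which under auto parallel now also folds in the sharding degree. A rough sketch of that computation with made-up numbers (none of these values come from the commit):

# Hypothetical values, for illustration only.
per_device_train_batch_size = 2
per_device_eval_batch_size = 4
dataset_world_size = 8          # sharding_parallel_degree * data_parallel_degree under auto parallel
max_steps, gradient_accumulation_steps = 1000, 4
eval_iters, eval_steps, test_iters = 10, 100, 5

train_val_test_num_samples = [
    per_device_train_batch_size * dataset_world_size * max_steps * gradient_accumulation_steps,    # 64000
    per_device_eval_batch_size * dataset_world_size * eval_iters * (max_steps // eval_steps + 1),  # 3520
    per_device_eval_batch_size * dataset_world_size * test_iters,                                  # 160
]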

llm/llama/auto_parallel/run_pretrain_auto_static.py

Lines changed: 5 additions & 7 deletions
@@ -274,14 +274,14 @@ def create_pretrained_dataset(
 
     train_val_test_num_samples = [
         training_args.per_device_train_batch_size
-        * training_args.data_parallel_degree
+        * training_args.dataset_world_size
         * training_args.max_steps
         * training_args.gradient_accumulation_steps,
         training_args.per_device_eval_batch_size
-        * training_args.data_parallel_degree
+        * training_args.dataset_world_size
         * training_args.eval_iters
         * (training_args.max_steps // training_args.eval_steps + 1),
-        training_args.per_device_eval_batch_size * training_args.data_parallel_degree * training_args.test_iters,
+        training_args.per_device_eval_batch_size * training_args.dataset_world_size * training_args.test_iters,
     ]
 
     print_rank_0(" > datasets target sizes (minimum size):")
@@ -421,7 +421,7 @@ def init_seed(seed: int = 1234, args=None):
         topo = Topology(
             dist.get_rank(),
             dist.get_world_size(),
-            dp_degree=args.data_parallel_degree,
+            dp_degree=args.dataset_world_size,
             pp_degree=args.pipeline_parallel_degree,
             mp_degree=args.tensor_parallel_degree,
             sharding_degree=1,  # auto_parallel's sharding is not orthogonal with dp, mp and pp
@@ -600,9 +600,7 @@ def fn(layer):
     def loss_func(loss, outputs):
        return loss
 
-    total_train_batch_size_per_acc_step = (
-        training_args.per_device_train_batch_size * training_args.data_parallel_degree
-    )
+    total_train_batch_size_per_acc_step = training_args.per_device_train_batch_size * training_args.dataset_world_size
     total_train_batch_size = total_train_batch_size_per_acc_step * training_args.gradient_accumulation_steps
 
     print_config(training_args)
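The static-graph script gets the same dataset-sizing fix plus a simplified global-batch calculation. A quick illustration with hypothetical values (not taken from the commit):

# Hypothetical values, for illustration only.
per_device_train_batch_size = 2
dataset_world_size = 8
gradient_accumulation_steps = 4

total_train_batch_size_per_acc_step = per_device_train_batch_size * dataset_world_size          # 16
total_train_batch_size = total_train_batch_size_per_acc_step * gradient_accumulation_steps      # 64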

paddlenlp/trainer/auto_trainer.py

Lines changed: 8 additions & 1 deletion
@@ -32,6 +32,7 @@
 from .trainer_callback import TrainerState
 from .trainer_utils import (  # set_hyrbid_parallel_seed,
     PREFIX_CHECKPOINT_DIR,
+    ShardingOption,
     TrainOutput,
     _exec_mode_guard,
     get_last_checkpoint,
@@ -111,6 +112,13 @@ def _wrap_for_dist_loader(self, train_dataloader):
     def _wrap_for_auto(self, model, train_dataloader):
         dist_loader = self._wrap_for_dist_loader(train_dataloader)
 
+        if ShardingOption.SHARD_OP in self.args.sharding:
+            self.optimizer = dist.shard_optimizer(self.optimizer, dist.ShardingStage1())
+        elif ShardingOption.SHARD_GRAD_OP in self.args.sharding:
+            self.optimizer = dist.shard_optimizer(self.optimizer, dist.ShardingStage2())
+        elif ShardingOption.FULL_SHARD in self.args.sharding:
+            self.optimizer = dist.shard_optimizer(self.optimizer, dist.ShardingStage3())
+
         if self.args.to_static:
             unified_strategy = dist.Strategy()
             unified_strategy._from_legacy_strategy(self.args.strategy)
@@ -119,7 +127,6 @@ def _wrap_for_auto(self, model, train_dataloader):
                 dist_loader,
             )
         else:
-            self.optimizer = dist.shard_optimizer(self.optimizer)
             return model, dist_loader
 
     def _wrap_amp_model(self, args, model):
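The core of this change is the mapping from the trainer's `sharding` options to Paddle's auto-parallel sharding stages, applied before the model is optionally converted to static graph. A standalone sketch of that mapping is below; the helper name `select_shard_fn` is mine, not part of the commit, and the usage comment only mirrors what the diff above does:

import paddle.distributed as dist

from paddlenlp.trainer.trainer_utils import ShardingOption


def select_shard_fn(sharding):
    """Map trainer sharding options to an auto-parallel sharding stage (sketch)."""
    if ShardingOption.SHARD_OP in sharding:        # "stage1": shard optimizer states
        return dist.ShardingStage1()
    if ShardingOption.SHARD_GRAD_OP in sharding:   # "stage2": also shard gradients
        return dist.ShardingStage2()
    if ShardingOption.FULL_SHARD in sharding:      # "stage3": also shard parameters
        return dist.ShardingStage3()
    return None


# Inside _wrap_for_auto, the equivalent of the added lines would then read roughly:
#     shard_fn = select_shard_fn(self.args.sharding)
#     if shard_fn is not None:
#         self.optimizer = dist.shard_optimizer(self.optimizer, shard_fn)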

paddlenlp/trainer/trainer.py

Lines changed: 8 additions & 35 deletions
@@ -40,9 +40,6 @@
 import paddle.nn as nn
 from packaging import version
 from paddle.distributed import fleet
-from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import (
-    DygraphShardingOptimizer,
-)
 from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.hybrid_parallel_optimizer import (
     HybridParallelOptimizer,
 )
@@ -1538,38 +1535,14 @@ def apply_decay_param_fun(x):
             if hasattr(optimizer_cls, "_create_master_weight") and self.args.fp16_opt_level == "O2":
                 optimizer_kwargs["multi_precision"] = True
 
-            def is_new_version_sharding_stage1_optimizer():
-                signature_keys = set(inspect.signature(DygraphShardingOptimizer).parameters.keys())
-                return "inner_optimizer_class" not in signature_keys
-
-            if ShardingOption.SHARD_OP in self.args.sharding and not is_new_version_sharding_stage1_optimizer():
-                # for backward compatibility.
-                # this call will raise, if sharding stage1 is supported in HybridParallelOptimizer,
-                # in which case, the logic follows will handle it
-                self.optimizer = DygraphShardingOptimizer(
-                    hcg=fleet.get_hybrid_communicate_group(),
-                    user_defined_strategy=None,
-                    params=params,
-                    inner_optimizer_class=optimizer_cls,
-                    learning_rate=self.lr_scheduler if lr_scheduler is None else lr_scheduler,
-                    apply_decay_param_fun=apply_decay_param_fun,
-                    weight_decay=self.args.weight_decay,
-                    grad_clip=nn.ClipGradByGlobalNorm(self.args.max_grad_norm)
-                    if self.args.max_grad_norm > 0
-                    else None,
-                    **optimizer_kwargs,
-                )
-            else:
-                self.optimizer = optimizer_cls(
-                    learning_rate=self.lr_scheduler if lr_scheduler is None else lr_scheduler,
-                    apply_decay_param_fun=apply_decay_param_fun,
-                    parameters=params,
-                    weight_decay=self.args.weight_decay,
-                    grad_clip=nn.ClipGradByGlobalNorm(self.args.max_grad_norm)
-                    if self.args.max_grad_norm > 0
-                    else None,
-                    **optimizer_kwargs,
-                )
+            self.optimizer = optimizer_cls(
+                learning_rate=self.lr_scheduler if lr_scheduler is None else lr_scheduler,
+                apply_decay_param_fun=apply_decay_param_fun,
+                parameters=params,
+                weight_decay=self.args.weight_decay,
+                grad_clip=nn.ClipGradByGlobalNorm(self.args.max_grad_norm) if self.args.max_grad_norm > 0 else None,
+                **optimizer_kwargs,
+            )
 
         return self.optimizer
 
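With the legacy `DygraphShardingOptimizer` branch removed, the optimizer is always constructed directly from `optimizer_cls`. Below is a minimal standalone sketch of that construction pattern, assuming `AdamW` as the optimizer class and a toy linear model; both are my assumptions for illustration, not prescribed by the commit:

import paddle
import paddle.nn as nn

model = nn.Linear(8, 8)                   # toy model, illustration only
max_grad_norm, weight_decay = 1.0, 0.01   # hypothetical trainer arguments


def apply_decay_param_fun(param_name):
    # The real trainer only decays parameters in a computed allow-list
    # (biases and norm weights are excluded); this sketch decays everything.
    return True


optimizer = paddle.optimizer.AdamW(
    learning_rate=3e-4,
    parameters=model.parameters(),
    weight_decay=weight_decay,
    apply_decay_param_fun=apply_decay_param_fun,
    grad_clip=nn.ClipGradByGlobalNorm(max_grad_norm) if max_grad_norm > 0 else None,
)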

paddlenlp/trainer/training_args.py

Lines changed: 16 additions & 8 deletions
@@ -1194,28 +1194,36 @@ def is_segment_parallel_supported():
 
         elif self.enable_auto_parallel:
             self.tensor_parallel_degree = max(self.tensor_parallel_degree, 1)
+            self.sep_parallel_degree = max(self.sep_parallel_degree, 1)
             self.pipeline_parallel_degree = max(self.pipeline_parallel_degree, 1)
 
             assert (
                 world_size % (self.tensor_parallel_degree * self.pipeline_parallel_degree) == 0
             ), f"Total world_size:{world_size} shoule be devided by tensor_parallel_degree: {self.tensor_parallel_degree} and pipeline_parallel_degree: {self.pipeline_parallel_degree}."
 
-            self.data_parallel_degree = world_size // (self.tensor_parallel_degree * self.pipeline_parallel_degree)
-
             if self.sharding_parallel_degree == -1:
                 if len(self.sharding) > 0:
-                    self.sharding_parallel_degree = self.data_parallel_degree
+                    self.sharding_parallel_degree = world_size // (
+                        self.tensor_parallel_degree * self.sep_parallel_degree * self.pipeline_parallel_degree
+                    )
 
             self.sharding_parallel_degree = max(self.sharding_parallel_degree, 1)
             if self.sharding_parallel_degree == 1 and len(self.sharding) > 0:
                 logger.warning("sharding_parallel_degree=1 means no sharding, please set sharding to empty!")
                 self.sharding = []
 
+            self.data_parallel_degree = world_size // (
+                self.sharding_parallel_degree
+                * self.tensor_parallel_degree
+                * self.sep_parallel_degree
+                * self.pipeline_parallel_degree
+            )
+
             if ShardingOption.OFFLOAD in self.sharding:
                 warnings.warn("`offload` is not supported NOW!")
 
             strategy = fleet.auto.Strategy()
-            if self.data_parallel_degree > 1:
+            if self.dataset_world_size > 1:
                 data_parallel_config = set(self.data_parallel_config.split(" "))
                 for x in data_parallel_config:
                     if len(x) > 0:
@@ -1356,10 +1364,10 @@ def is_segment_parallel_supported():
             self.strategy = strategy
             if self.hybrid_parallel_topo_order == "pp_first":
                 order = ["pp", "dp", "mp"]
-                degree = [self.pipeline_parallel_degree, self.data_parallel_degree, self.tensor_parallel_degree]
+                degree = [self.pipeline_parallel_degree, self.dataset_world_size, self.tensor_parallel_degree]
             elif self.hybrid_parallel_topo_order == "sharding_first":
                 order = ["dp", "pp", "mp"]
-                degree = [self.data_parallel_degree, self.pipeline_parallel_degree, self.tensor_parallel_degree]
+                degree = [self.dataset_world_size, self.pipeline_parallel_degree, self.tensor_parallel_degree]
             mesh_dims = list(zip(order, degree))
             fleet.auto.create_mesh(mesh_dims)
 
@@ -1371,7 +1379,7 @@ def is_segment_parallel_supported():
 
             strategy = fleet.DistributedStrategy()
             strategy.hybrid_configs = {
-                "dp_degree": self.data_parallel_degree,
+                "dp_degree": self.dataset_world_size,
                 "mp_degree": self.tensor_parallel_degree,
                 "pp_degree": self.pipeline_parallel_degree,
                 "order": order,
@@ -1526,7 +1534,7 @@ def dataset_world_size(self):
         if self.use_hybrid_parallel:
             return max(self.sharding_parallel_degree, 1) * max(self.data_parallel_degree, 1)
         elif self.enable_auto_parallel:
-            return max(self.data_parallel_degree, 1)
+            return max(self.sharding_parallel_degree, 1) * max(self.data_parallel_degree, 1)
         else:
            return paddle.distributed.get_world_size()
 
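Under auto parallel, the degrees are now derived in order: the sharding degree is taken from whatever ranks remain after tensor/sep/pipeline parallelism, `data_parallel_degree` comes from the remainder after sharding, and `dataset_world_size` folds both back together. A worked example of that arithmetic with hypothetical settings (the concrete numbers are mine, not from the commit):

# Hypothetical configuration, for illustration only.
world_size = 32
tensor_parallel_degree = 2
sep_parallel_degree = 1
pipeline_parallel_degree = 2
sharding = ["stage1"]            # non-empty, so sharding is enabled
sharding_parallel_degree = -1    # "derive it automatically"

if sharding_parallel_degree == -1 and len(sharding) > 0:
    sharding_parallel_degree = world_size // (
        tensor_parallel_degree * sep_parallel_degree * pipeline_parallel_degree
    )                            # 32 // 4 = 8

data_parallel_degree = world_size // (
    sharding_parallel_degree * tensor_parallel_degree * sep_parallel_degree * pipeline_parallel_degree
)                                # 32 // 32 = 1

dataset_world_size = max(sharding_parallel_degree, 1) * max(data_parallel_degree, 1)  # 8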
