
Commit 91d81c9

zhaoyinglia and ZHUI authored
add sharding for gpt3 (#1064)
* add sharding for gpt-3
* del debug
* add sharding save model
* update model save
* fix seed func
* set control in tensor parallel

Co-authored-by: Zhong Hui <[email protected]>
1 parent e05aed8 commit 91d81c9

File tree

3 files changed: +62 -32 lines changed

examples/language_model/gpt-3/dygraph/args.py

Lines changed: 3 additions & 2 deletions
@@ -26,9 +26,10 @@ def process_batch_size(args):
         "global_batch_size[{}] should be divided by local_batch_size[{}] when dp_degree is [{}]"\
         .format(args.global_batch_size, args.local_batch_size, args.dp_degree)
     elif args.global_batch_size is not None and args.local_batch_size is None:
-        args.local_batch_size = args.global_batch_size // args.dp_degree
+        args.local_batch_size = args.global_batch_size // (args.dp_degree *
+                                                           args.sharding_degree)
     else:
-        args.global_batch_size = args.local_batch_size * args.dp_degree
+        args.global_batch_size = args.local_batch_size * args.dp_degree * args.sharding_degree
     assert args.local_batch_size % args.micro_batch_size == 0

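The batch-size relationship changed here can be checked with plain arithmetic. The sketch below uses hypothetical values and only illustrates the formula in process_batch_size; it is not part of the commit.

# Hypothetical values, only to illustrate the relationship used in process_batch_size.
global_batch_size = 16
dp_degree = 2
sharding_degree = 2
micro_batch_size = 2

# With sharding, the effective data-parallel world is dp_degree * sharding_degree ranks.
local_batch_size = global_batch_size // (dp_degree * sharding_degree)  # -> 4
assert local_batch_size % micro_batch_size == 0
accumulate_steps = local_batch_size // micro_batch_size                # -> 2
print(local_batch_size, accumulate_steps)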

examples/language_model/gpt-3/dygraph/run.sh

Lines changed: 1 addition & 0 deletions
@@ -20,6 +20,7 @@ python -m paddle.distributed.launch --log_dir $log_dir --gpus "0,1,2,3,4,5,6,7"
     --dp_degree 2\
     --mp_degree 2\
     --pp_degree 2\
+    --sharding_degree 1\
     --use_amp True\
     --use_recompute False
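As a sanity check (a sketch with values taken from the flags above, not part of run.sh), the product of the parallel degrees should cover exactly the launched GPUs: 2 * 2 * 2 * 1 = 8 matches the eight devices passed to paddle.distributed.launch.

# Sketch: verify the hybrid-parallel degrees against the launched GPU count.
# Values mirror run.sh above; "ngpus" is an assumed helper variable.
dp_degree, mp_degree, pp_degree, sharding_degree = 2, 2, 2, 1
ngpus = 8  # --gpus "0,1,2,3,4,5,6,7"
assert dp_degree * mp_degree * pp_degree * sharding_degree == ngpus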

examples/language_model/gpt-3/dygraph/run_pretrain.py

Lines changed: 58 additions & 30 deletions
@@ -30,23 +30,24 @@
 import lr
 from paddle.distributed import fleet
 from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
+from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer import DygraphShardingOptimizer

 MODEL_CLASSES = {
     "gpt": (GPTForPretraining, GPTTokenizer),
     "gpt-cn": (GPTForPretraining, GPTChineseTokenizer),
 }


-def set_hyrbid_parallel_seed(basic_seed, dp_rank, mp_rank, pp_rank):
+def set_hyrbid_parallel_seed(basic_seed, data_world_rank, mp_rank, pp_rank):
     assert args.device != "cpu"

-    random.seed(basic_seed + dp_rank)
-    np.random.seed(basic_seed + dp_rank)
-    paddle.seed(basic_seed + dp_rank)
+    random.seed(basic_seed + data_world_rank)
+    np.random.seed(basic_seed + data_world_rank)
+    paddle.seed(basic_seed + data_world_rank)

     # local_seed/ global_seed is used to control dropout in ModelParallel
     local_seed = basic_seed + 123 + mp_rank * 10 + pp_rank * 1000
-    global_seed = basic_seed + dp_rank
+    global_seed = basic_seed + data_world_rank
     tracker = get_rng_state_tracker()
     tracker.add('global_seed', global_seed)
     tracker.add('local_seed', local_seed)
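The seeding scheme can be illustrated standalone; the sketch below reproduces only the arithmetic from set_hyrbid_parallel_seed with hypothetical rank values and is not part of the commit.

# Sketch of the seed layout (hypothetical ranks, same formulas as the diff above):
# local_seed varies with the tensor-/pipeline-parallel position (so dropout differs
# across mp/pp ranks), while global_seed varies only with the data-parallel/sharding
# position, keeping replicas of the same shard consistent.
basic_seed = 1234
data_world_rank, mp_rank, pp_rank = 3, 1, 0

local_seed = basic_seed + 123 + mp_rank * 10 + pp_rank * 1000   # per mp/pp rank
global_seed = basic_seed + data_world_rank                      # per data shard
print(local_seed, global_seed)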
@@ -92,14 +93,18 @@ def do_train(args):
     strategy.hybrid_configs = {
         "dp_degree": args.dp_degree,
         "mp_degree": args.mp_degree,
-        "pp_degree": args.pp_degree
+        "pp_degree": args.pp_degree,
+        "sharding_degree": args.sharding_degree
     }

     strategy.pipeline_configs = {
         "accumulate_steps": args.local_batch_size // args.micro_batch_size,
         "micro_batch_size": args.micro_batch_size
     }

+    # set control in tensor parallel
+    strategy.tensor_parallel_configs = {"tensor_init_seed": args.seed}
+
     fleet.init(is_collective=True, strategy=strategy)

     # obtain rank message of hybrid parallel
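For context, a minimal standalone version of this strategy setup could look like the sketch below. It mirrors the keys used in the diff, assumes Paddle's fleet API as shown there, uses hypothetical degree values, and only runs when launched under paddle.distributed.launch; it is an illustration, not the example's actual code.

# Minimal sketch (assumed values; run under paddle.distributed.launch).
from paddle.distributed import fleet

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "dp_degree": 2,
    "mp_degree": 2,
    "pp_degree": 2,
    "sharding_degree": 1,
}
# accumulate_steps = local_batch_size // micro_batch_size, e.g. 8 // 2.
strategy.pipeline_configs = {"accumulate_steps": 4, "micro_batch_size": 2}
# Per the commit message, tensor_init_seed adds seed control in tensor parallel.
strategy.tensor_parallel_configs = {"tensor_init_seed": 1234}
fleet.init(is_collective=True, strategy=strategy)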
@@ -108,10 +113,15 @@ def do_train(args):
     mp_rank = hcg.get_model_parallel_rank()
     pp_rank = hcg.get_stage_id()
     dp_rank = hcg.get_data_parallel_rank()
+    sharding_rank = hcg.get_sharding_parallel_rank()
+
+    sharding_size = hcg.get_sharding_parallel_world_size()
+    data_world_rank = dp_rank * sharding_size + sharding_rank
+    data_world_size = args.dp_degree * args.sharding_degree
     local_rank = int(os.getenv("PADDLE_RANK_IN_NODE", 0))

     # seed control in hybrid parallel
-    set_hyrbid_parallel_seed(args.seed, dp_rank, mp_rank, pp_rank)
+    set_hyrbid_parallel_seed(args.seed, data_world_rank, mp_rank, pp_rank)

     default_global_tokens_num = args.global_batch_size * args.max_seq_len

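The dp/sharding flattening above can be seen with concrete degrees: every (dp_rank, sharding_rank) pair maps to a unique data_world_rank, so each rank consumes a distinct slice of the dataset. The values below are hypothetical.

# Sketch: flatten (dp_rank, sharding_rank) into one data-parallel index,
# exactly as done above; degrees are hypothetical.
dp_degree, sharding_degree = 2, 2
data_world_size = dp_degree * sharding_degree  # 4

for dp_rank in range(dp_degree):
    for sharding_rank in range(sharding_degree):
        data_world_rank = dp_rank * sharding_degree + sharding_rank
        print(dp_rank, sharding_rank, "->", data_world_rank)
# Prints indices 0..3, one unique index per (dp_rank, sharding_rank) pair.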
@@ -183,15 +193,31 @@ def do_train(args):
         if not any(nd in n for nd in ["bias", "norm"])
     ]

-    optimizer = paddle.optimizer.AdamW(
-        learning_rate=lr_scheduler if lr_scheduler is not None else args.max_lr,
-        beta1=args.adam_beta1,
-        beta2=args.adam_beta2,
-        epsilon=args.adam_epsilon,
-        parameters=model.parameters(),
-        weight_decay=args.weight_decay,
-        grad_clip=clip,
-        apply_decay_param_fun=lambda x: x in decay_params)
+    if args.sharding_degree > 1:
+        optimizer = DygraphShardingOptimizer(
+            hcg=fleet.get_hybrid_communicate_group(),
+            user_defined_strategy=strategy,
+            params=model.parameters(),
+            inner_optimizer_class=paddle.optimizer.AdamW,
+            learning_rate=lr_scheduler
+            if lr_scheduler is not None else args.max_lr,
+            beta1=args.adam_beta1,
+            beta2=args.adam_beta2,
+            epsilon=args.adam_epsilon,
+            weight_decay=args.weight_decay,
+            grad_clip=clip,
+            apply_decay_param_fun=lambda x: x in decay_params)
+    else:
+        optimizer = paddle.optimizer.AdamW(
+            learning_rate=lr_scheduler
+            if lr_scheduler is not None else args.max_lr,
+            beta1=args.adam_beta1,
+            beta2=args.adam_beta2,
+            epsilon=args.adam_epsilon,
+            parameters=model.parameters(),
+            weight_decay=args.weight_decay,
+            grad_clip=clip,
+            apply_decay_param_fun=lambda x: x in decay_params)

     if paddle.distributed.get_world_size() > 1:
         model = fleet.distributed_model(model)
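When sharding_degree > 1, DygraphShardingOptimizer wraps the inner AdamW so that optimizer state is partitioned across the sharding group instead of being replicated; otherwise plain AdamW is used. The sketch below is only a conceptual illustration of such a partitioning in plain Python (all names such as fake_params are made up); it is not the DygraphShardingOptimizer implementation.

# Conceptual sketch of optimizer-state sharding (NOT the Paddle implementation):
# parameters are partitioned across the sharding group, and each rank keeps
# optimizer state (e.g. Adam moments) only for its own partition.
sharding_degree = 4
fake_params = [("layer%d.weight" % i, 10 * (i + 1)) for i in range(10)]  # (name, numel)

partitions = {rank: [] for rank in range(sharding_degree)}
sizes = {rank: 0 for rank in range(sharding_degree)}
for name, numel in sorted(fake_params, key=lambda p: -p[1]):
    rank = min(sizes, key=sizes.get)   # greedy: assign to the least-loaded rank
    partitions[rank].append(name)
    sizes[rank] += numel

for rank, names in partitions.items():
    print("rank", rank, "keeps optimizer state for", names)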
@@ -227,8 +253,8 @@ def do_train(args):
            args,
            data_file,
            local_rank=local_rank,
-            data_world_size=args.dp_degree,
-            data_world_rank=dp_rank,
+            data_world_size=data_world_size,
+            data_world_rank=data_world_rank,
            eos_id=tokenizer.eos_token_id)
        # Bug fix, if not call valid_data_loader, the enumerate will call valid_data_loader
        # many times. and start a new random dataloader.
@@ -309,6 +335,7 @@ def do_train(args):
                                 args.eval_iters, log_writer, global_step,
                                 epoch, "valid")

+                # TODO: 1. merge paramters while saving model. 2. ensure that the model is saved and loaded correctly
                # only dp_rank = 0 save model
                if (global_step % args.save_steps == 0 or
                        global_step >= args.max_steps) and dp_rank == 0:
@@ -322,24 +349,25 @@ def do_train(args):
                    logger.info("Save model to %s" % output_dir)

                    if args.pp_degree > 1:
-                        model_to_save.save_state_dict(output_dir)
-                        if mp_rank * pp_rank == 1:
+                        if mp_rank == 0 and sharding_rank == 0 and pp_rank == 0:
                            tokenizer.save_pretrained(output_dir)
+                        model_to_save.save_state_dict(output_dir)
                        paddle.save(
                            optimizer.state_dict(),
                            os.path.join(
                                output_dir,
-                                "model_state_mp_{:0>2d}_pp_{:0>2d}.pdopt".
-                                format(mp_rank, pp_rank)))
+                                "model_state_mp_{:0>2d}_sharding_{:0>2d}_pp_{:0>2d}.pdopt".
+                                format(mp_rank, sharding_rank, pp_rank)))
                    else:
-                        path = os.path.join(output_dir,
-                                            'model_{:0>2d}'.format(mp_rank))
-                        os.makedirs(path, exist_ok=True)
-                        model_to_save.save_pretrained(path)
-
-                        paddle.save(optimizer.state_dict(),
-                                    os.path.join(path, "model_state.pdopt"))
-                        tokenizer.save_pretrained(path)
+                        if mp_rank == 0 and sharding_rank == 0:
+                            tokenizer.save_pretrained(output_dir)
+                        model_to_save.save_pretrained(output_dir)
+                        paddle.save(
+                            optimizer.state_dict(),
+                            os.path.join(
+                                output_dir,
+                                "model_state_mp_{:0>2d}_sharding_{:0>2d}.pdopt".
+                                format(mp_rank, sharding_rank)))

                if global_step >= args.max_steps:
                    run_evaluate(args, test_data_loader, model, criterion,
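The new checkpoint naming (now including the sharding rank) can be previewed with plain string formatting; the rank values below are hypothetical and this snippet is not part of the commit.

# Sketch of the optimizer-state file names written above (hypothetical ranks).
mp_rank, sharding_rank, pp_rank = 1, 0, 2

with_pp = "model_state_mp_{:0>2d}_sharding_{:0>2d}_pp_{:0>2d}.pdopt".format(
    mp_rank, sharding_rank, pp_rank)      # pp_degree > 1 branch
without_pp = "model_state_mp_{:0>2d}_sharding_{:0>2d}.pdopt".format(
    mp_rank, sharding_rank)               # pp_degree == 1 branch
print(with_pp)     # model_state_mp_01_sharding_00_pp_02.pdopt
print(without_pp)  # model_state_mp_01_sharding_00.pdopt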
