
Commit f04322f

[GPT-3] Add support for pure fp16 in dygraph mode (#1176)
1 parent 41bad63 commit f04322f

File tree

3 files changed: 62 additions & 37 deletions

examples/language_model/gpt-3/dygraph/args.py

Lines changed: 10 additions & 6 deletions
@@ -22,10 +22,14 @@ def process_batch_size(args):
    if args.global_batch_size is None and args.local_batch_size is None:
        raise ValueError("global_batch_size or local_batch_size should be set.")
    elif args.global_batch_size is not None and args.local_batch_size is not None:
-        assert args.global_batch_size // args.local_batch_size == args.dp_degree, \
-            "global_batch_size[{}] should be divided by local_batch_size[{}] when dp_degree is [{}]"\
-            .format(args.global_batch_size, args.local_batch_size, args.dp_degree)
+        assert args.global_batch_size // args.local_batch_size == (args.dp_degree *
+            args.sharding_degree), "global_batch_size[{}] should be divided by local_batch_size[{}] "\
+            "when dp_degree is [{}] and sharding_degree is [{}]".format(args.global_batch_size,
+            args.local_batch_size, args.dp_degree, args.sharding_degree)
    elif args.global_batch_size is not None and args.local_batch_size is None:
+        assert args.global_batch_size % (args.dp_degree * args.sharding_degree) == 0, \
+            "global_batch_size[{}] should be divided by dp_degree[{}] times sharding_degree[{}]"\
+            .format(args.global_batch_size, args.dp_degree, args.sharding_degree)
        args.local_batch_size = args.global_batch_size // (args.dp_degree *
                                                           args.sharding_degree)
    else:
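
Note: the check above generalizes the batch-size constraint from data parallelism alone to the product of the data-parallel and sharding degrees. A minimal sketch of the arithmetic these assertions enforce, using illustrative values that are not part of the commit:

# Illustrative only: how global, local and micro batch sizes relate.
global_batch_size = 64                    # samples per optimizer step, across all replicas
dp_degree, sharding_degree = 2, 2         # data-parallel and sharding ways
micro_batch_size = 4                      # samples per forward/backward pass

# Samples handled by each dp/sharding replica per optimizer step.
local_batch_size = global_batch_size // (dp_degree * sharding_degree)   # 16

# Micro-batches accumulated before one optimizer step (see run_pretrain.py below).
accumulate_steps = local_batch_size // micro_batch_size                 # 4

assert global_batch_size == local_batch_size * dp_degree * sharding_degree
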
@@ -220,13 +224,13 @@ def parse_args(MODEL_CLASSES):
        const=False,
        help="Using the recompute to save the memory.")

-    # AMP config
+    # Pure FP16 config
    parser.add_argument(
-        "--use_amp",
+        "--use_pure_fp16",
        type=str2bool,
        nargs='?',
        const=False,
-        help="Enable mixed precision training.")
+        help="Enable pure fp16 precision training.")

    parser.add_argument(
        "--scale_loss",

examples/language_model/gpt-3/dygraph/run.sh

Lines changed: 1 addition & 1 deletion
@@ -21,6 +21,6 @@ python -m paddle.distributed.launch --log_dir $log_dir --gpus "0,1,2,3,4,5,6,7"
    --mp_degree 2\
    --pp_degree 2\
    --sharding_degree 1\
-    --use_amp True\
+    --use_pure_fp16 True\
    --use_recompute False

examples/language_model/gpt-3/dygraph/run_pretrain.py

Lines changed: 51 additions & 30 deletions
@@ -126,8 +126,9 @@ def do_train(args):
        "sharding_degree": args.sharding_degree
    }

+    accumulate_steps = args.local_batch_size // args.micro_batch_size
    strategy.pipeline_configs = {
-        "accumulate_steps": args.local_batch_size // args.micro_batch_size,
+        "accumulate_steps": accumulate_steps,
        "micro_batch_size": args.micro_batch_size
    }

@@ -160,8 +161,8 @@ def do_train(args):
    # Define log writer
    log_writer_path = os.path.join(
        args.output_dir, "train_log",
-        "{}_globalbsz_{}_amp_{}_recompute_{}_card_{}".format(
-            args.model_name_or_path, args.global_batch_size, args.use_amp,
+        "{}_globalbsz_{}_pure_fp16_{}_recompute_{}_card_{}".format(
+            args.model_name_or_path, args.global_batch_size, args.use_pure_fp16,
            False, global_rank).lower())

    if os.path.exists(log_writer_path):
@@ -246,16 +247,25 @@ def do_train(args):
        parameters=model.parameters(),
        weight_decay=args.weight_decay,
        grad_clip=clip,
-        apply_decay_param_fun=lambda x: x in decay_params)
+        apply_decay_param_fun=lambda x: x in decay_params,
+        # TODO: remove 'multi_precision' in definition of optimizer
+        # and add it to 'paddle.amp.decorate'
+        multi_precision=args.use_pure_fp16)
+
+    if args.use_pure_fp16:
+        scaler = paddle.amp.GradScaler(init_loss_scaling=args.scale_loss)
+        scaler = fleet.distributed_scaler(scaler)
+        # level O2 means converting the network to FP16
+        model, optimizer = paddle.amp.decorate(
+            models=model,
+            optimizers=optimizer,
+            level='O2',
+            save_dtype='float32')

    if paddle.distributed.get_world_size() > 1:
        model = fleet.distributed_model(model)
        optimizer = fleet.distributed_optimizer(optimizer)

-    if args.use_amp:
-        scaler = paddle.amp.GradScaler(init_loss_scaling=args.scale_loss)
-        scaler = fleet.distributed_scaler(scaler)
-
    if args.model_name_or_path not in pretrained_models_list:
        logger.info("Try to load checkpoint from %s " % args.model_name_or_path)
        opt_path = os.path.join(args.model_name_or_path, "model_state.pdopt")
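
For readers unfamiliar with Paddle's pure fp16 mode, here is a self-contained sketch of the O2 setup pattern added above, with a toy linear layer standing in for GPT-3; the model and hyperparameter values are illustrative assumptions, not code from this commit:

import paddle

model = paddle.nn.Linear(16, 16)            # illustrative stand-in for the GPT model
optimizer = paddle.optimizer.AdamW(
    learning_rate=1e-4,
    parameters=model.parameters(),
    multi_precision=True)                   # keep fp32 master weights next to fp16 params

# Loss scaling guards fp16 gradients against underflow.
scaler = paddle.amp.GradScaler(init_loss_scaling=32768)

# level='O2' casts the network weights to fp16; save_dtype keeps checkpoints in fp32.
model, optimizer = paddle.amp.decorate(
    models=model, optimizers=optimizer, level='O2', save_dtype='float32')

x = paddle.randn([8, 16])
with paddle.amp.auto_cast(True, level='O2'):
    loss = model(x).mean()
scaler.scale(loss).backward()
scaler.minimize(optimizer, loss)
optimizer.clear_grad()

In the distributed path of the commit, the scaler is additionally wrapped with fleet.distributed_scaler; the single-process sketch above omits that step.
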
@@ -294,23 +304,36 @@ def do_train(args):
            position_ids.stop_gradient = True

            if args.pp_degree == 1:
-                with paddle.amp.auto_cast(
-                        args.use_amp,
-                        custom_white_list=[
-                            "layer_norm", "softmax", "gelu"
-                        ],
-                        custom_black_list=[
-                            "reduce_sum", "c_softmax_with_cross_entropy",
-                            "c_embedding"
-                        ]):
-                    preds = model(tokens, position_ids)
-                    loss = criterion(preds, labels, loss_mask)
-
-                if args.use_amp:
-                    scaler.scale(loss).backward()
+                # In ParallelMode of DataParallel, 'no_sync' can be used for improving
+                # performance of model by gradient accumulation.
+                loss = 0.0
+                for i in range(accumulate_steps):
+                    start_index = i * args.micro_batch_size
+                    end_index = start_index + args.micro_batch_size
+                    with paddle.amp.auto_cast(
+                            args.use_pure_fp16,
+                            custom_black_list=[
+                                "reduce_sum",
+                                "c_softmax_with_cross_entropy",
+                                "elementwise_div"
+                            ],
+                            level='O2'):
+                        preds = model(
+                            tokens[start_index:end_index, :],
+                            position_ids[start_index:end_index, :])
+                        loss_mbs = criterion(
+                            preds, labels[start_index:end_index, :],
+                            loss_mask[start_index:end_index, :])
+                    loss_mbs = loss_mbs / accumulate_steps
+                    if args.use_pure_fp16:
+                        scaler.scale(loss_mbs).backward()
+                    else:
+                        loss_mbs.backward()
+                    loss = loss + loss_mbs
+
+                if args.use_pure_fp16:
                    scaler.minimize(optimizer, loss)
                else:
-                    loss.backward()
                    optimizer.step()

                if lr_scheduler is not None:
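
The loop above replaces one full-batch forward/backward with micro-batch gradient accumulation. A minimal single-process sketch of that pattern, with toy tensors and no AMP, showing why each micro-batch loss is divided by accumulate_steps; everything here is illustrative, not code from this commit:

import paddle

model = paddle.nn.Linear(16, 1)             # illustrative stand-in for the GPT model
optimizer = paddle.optimizer.SGD(learning_rate=0.1, parameters=model.parameters())

tokens = paddle.randn([16, 16])             # one local batch
targets = paddle.randn([16, 1])
micro_batch_size, accumulate_steps = 4, 4

loss = 0.0
for i in range(accumulate_steps):
    start, end = i * micro_batch_size, (i + 1) * micro_batch_size
    preds = model(tokens[start:end, :])
    # Dividing by accumulate_steps makes the summed gradients equal one
    # step over the whole local batch.
    loss_mbs = ((preds - targets[start:end, :]) ** 2).mean() / accumulate_steps
    loss_mbs.backward()                      # gradients accumulate across micro-batches
    loss = loss + loss_mbs

optimizer.step()                             # one optimizer update per local batch
optimizer.clear_grad()
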
@@ -320,19 +343,17 @@ def do_train(args):
            else:
                data = [(tokens, position_ids), (labels, loss_mask)]
                with paddle.amp.auto_cast(
-                        args.use_amp,
-                        custom_white_list=[
-                            "layer_norm", "softmax", "gelu"
-                        ],
+                        args.use_pure_fp16,
                        custom_black_list=[
                            "reduce_sum", "c_softmax_with_cross_entropy",
-                            "c_embedding"
-                        ]):
+                            "elementwise_div"
+                        ],
+                        level='O2'):
                    loss = model.train_batch(
                        data,
                        optimizer=optimizer,
                        lr_scheduler=lr_scheduler,
-                        scaler=scaler if args.use_amp else None)
+                        scaler=scaler if args.use_pure_fp16 else None)

            if global_step % args.logging_freq == 0:
                avg_loss = loss.numpy()
