@@ -126,8 +126,9 @@ def do_train(args):
         "sharding_degree": args.sharding_degree
     }
 
+    accumulate_steps = args.local_batch_size // args.micro_batch_size
     strategy.pipeline_configs = {
-        "accumulate_steps": args.local_batch_size // args.micro_batch_size,
+        "accumulate_steps": accumulate_steps,
         "micro_batch_size": args.micro_batch_size
     }
 
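Aside, not part of the patch: the new accumulate_steps is simply the number of micro batches that make up one local batch, so local_batch_size must be divisible by micro_batch_size. A minimal sketch with made-up sizes, assuming local_batch_size is the per-replica share of the global batch:

# Illustrative numbers only; the real values come from the script's arguments.
dp_degree = 4                                       # hypothetical data-parallel degree
global_batch_size = 64                              # samples per optimizer step, all replicas
local_batch_size = global_batch_size // dp_degree   # per-replica samples per step -> 16
micro_batch_size = 4                                # samples per forward/backward pass

assert local_batch_size % micro_batch_size == 0
accumulate_steps = local_batch_size // micro_batch_size
print(accumulate_steps)                             # -> 4 micro batches per optimizer step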
@@ -160,8 +161,8 @@ def do_train(args):
     # Define log writer
     log_writer_path = os.path.join(
         args.output_dir, "train_log",
-        "{}_globalbsz_{}_amp_{}_recompute_{}_card_{}".format(
-            args.model_name_or_path, args.global_batch_size, args.use_amp,
+        "{}_globalbsz_{}_pure_fp16_{}_recompute_{}_card_{}".format(
+            args.model_name_or_path, args.global_batch_size, args.use_pure_fp16,
             False, global_rank).lower())
 
     if os.path.exists(log_writer_path):
@@ -246,16 +247,25 @@ def do_train(args):
         parameters=model.parameters(),
         weight_decay=args.weight_decay,
         grad_clip=clip,
-        apply_decay_param_fun=lambda x: x in decay_params)
+        apply_decay_param_fun=lambda x: x in decay_params,
+        # TODO: remove 'multi_precision' from the optimizer definition
+        # and pass it to 'paddle.amp.decorate' instead
+        multi_precision=args.use_pure_fp16)
+
+    if args.use_pure_fp16:
+        scaler = paddle.amp.GradScaler(init_loss_scaling=args.scale_loss)
+        scaler = fleet.distributed_scaler(scaler)
+        # level O2 means converting the network to FP16
+        model, optimizer = paddle.amp.decorate(
+            models=model,
+            optimizers=optimizer,
+            level='O2',
+            save_dtype='float32')
 
     if paddle.distributed.get_world_size() > 1:
         model = fleet.distributed_model(model)
         optimizer = fleet.distributed_optimizer(optimizer)
 
-    if args.use_amp:
-        scaler = paddle.amp.GradScaler(init_loss_scaling=args.scale_loss)
-        scaler = fleet.distributed_scaler(scaler)
-
     if args.model_name_or_path not in pretrained_models_list:
         logger.info("Try to load checkpoint from %s " % args.model_name_or_path)
         opt_path = os.path.join(args.model_name_or_path, "model_state.pdopt")
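Aside, not part of the patch: the added block follows Paddle's pure-FP16 (O2) recipe: create a GradScaler, wrap it with fleet for multi-card runs, and let paddle.amp.decorate cast the network to FP16 while saved weights stay FP32. A minimal single-card sketch of the same flow with a toy layer (names and sizes are invented; meant for a GPU place, and the fleet wrapping is omitted):

import paddle

model = paddle.nn.Linear(8, 2)                     # toy stand-in for the GPT model
optimizer = paddle.optimizer.AdamW(
    learning_rate=1e-4,
    parameters=model.parameters(),
    multi_precision=True)                          # keep FP32 master weights
scaler = paddle.amp.GradScaler(init_loss_scaling=32768)

# level='O2' casts the network to FP16; save_dtype keeps saved weights FP32.
model, optimizer = paddle.amp.decorate(
    models=model, optimizers=optimizer, level='O2', save_dtype='float32')

x = paddle.randn([4, 8])
with paddle.amp.auto_cast(True, level='O2'):
    loss = model(x).mean()

scaler.scale(loss).backward()                      # scale to avoid FP16 underflow
scaler.minimize(optimizer, loss)                   # unscale, skip the step on inf/nan
optimizer.clear_grad()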
@@ -294,23 +304,36 @@ def do_train(args):
                 position_ids.stop_gradient = True
 
                 if args.pp_degree == 1:
-                    with paddle.amp.auto_cast(
-                            args.use_amp,
-                            custom_white_list=[
-                                "layer_norm", "softmax", "gelu"
-                            ],
-                            custom_black_list=[
-                                "reduce_sum", "c_softmax_with_cross_entropy",
-                                "c_embedding"
-                            ]):
-                        preds = model(tokens, position_ids)
-                        loss = criterion(preds, labels, loss_mask)
-
-                    if args.use_amp:
-                        scaler.scale(loss).backward()
+                    # In DataParallel mode, 'no_sync' can be used to improve
+                    # performance when gradients are accumulated.
+                    loss = 0.0
+                    for i in range(accumulate_steps):
+                        start_index = i * args.micro_batch_size
+                        end_index = start_index + args.micro_batch_size
+                        with paddle.amp.auto_cast(
+                                args.use_pure_fp16,
+                                custom_black_list=[
+                                    "reduce_sum",
+                                    "c_softmax_with_cross_entropy",
+                                    "elementwise_div"
+                                ],
+                                level='O2'):
+                            preds = model(
+                                tokens[start_index:end_index, :],
+                                position_ids[start_index:end_index, :])
+                            loss_mbs = criterion(
+                                preds, labels[start_index:end_index, :],
+                                loss_mask[start_index:end_index, :])
+                        loss_mbs = loss_mbs / accumulate_steps
+                        if args.use_pure_fp16:
+                            scaler.scale(loss_mbs).backward()
+                        else:
+                            loss_mbs.backward()
+                        loss = loss + loss_mbs
+
+                    if args.use_pure_fp16:
                         scaler.minimize(optimizer, loss)
                     else:
-                        loss.backward()
                         optimizer.step()
 
                     if lr_scheduler is not None:
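Aside, not part of the patch: the rewritten branch splits each local batch into accumulate_steps micro batches, divides every micro-batch loss by accumulate_steps, and calls backward once per micro batch so gradients add up before the single optimizer step. A small self-contained check (toy linear layer and invented data) that this scaling reproduces the full-batch gradient:

import paddle

paddle.seed(0)
features = paddle.randn([8, 4])
targets = paddle.randn([8, 1])

def accumulated_grad(batches, scale):
    # Fresh layer with deterministic weights and no bias, so the two runs are comparable.
    layer = paddle.nn.Linear(
        4, 1,
        weight_attr=paddle.ParamAttr(initializer=paddle.nn.initializer.Constant(0.1)),
        bias_attr=False)
    for xb, yb in batches:
        loss = ((layer(xb) - yb) ** 2).mean() * scale
        loss.backward()                            # gradients accumulate across calls
    return layer.weight.grad.numpy()

g_full = accumulated_grad([(features, targets)], 1.0)
micro = [(features[i:i + 2], targets[i:i + 2]) for i in range(0, 8, 2)]
g_accum = accumulated_grad(micro, 1.0 / len(micro))   # 4 micro batches, each scaled by 1/4
print(abs(g_full - g_accum).max())                    # ~0: the two gradients agree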
@@ -320,19 +343,17 @@ def do_train(args):
                 else:
                     data = [(tokens, position_ids), (labels, loss_mask)]
                     with paddle.amp.auto_cast(
-                            args.use_amp,
-                            custom_white_list=[
-                                "layer_norm", "softmax", "gelu"
-                            ],
+                            args.use_pure_fp16,
                             custom_black_list=[
                                 "reduce_sum", "c_softmax_with_cross_entropy",
-                                "c_embedding"
-                            ]):
+                                "elementwise_div"
+                            ],
+                            level='O2'):
                         loss = model.train_batch(
                             data,
                             optimizer=optimizer,
                             lr_scheduler=lr_scheduler,
-                            scaler=scaler if args.use_amp else None)
+                            scaler=scaler if args.use_pure_fp16 else None)
 
                 if global_step % args.logging_freq == 0:
                     avg_loss = loss.numpy()