
Commit 57ce415

wawltor and ZeyuChen authored
add the time cost counter for the bert (#544)
Co-authored-by: Zeyu Chen <[email protected]>
1 parent 840149d commit 57ce415

File tree

3 files changed: +63 −28 lines

examples/language_model/bert/run_pretrain.py

Lines changed: 16 additions & 14 deletions
@@ -31,6 +31,7 @@
 from paddle.io import DataLoader, Dataset
 
 from paddlenlp.data import Stack, Tuple, Pad
+from paddlenlp.utils.tools import TimeCostAverage
 from paddlenlp.transformers import BertForPretraining, BertModel, BertPretrainingCriterion
 from paddlenlp.transformers import ErnieForPretraining, ErnieModel, ErniePretrainingCriterion
 from paddlenlp.transformers import BertTokenizer, ErnieTokenizer
@@ -377,13 +378,13 @@ def do_train(args):
             dataset_future = pool.submit(create_pretraining_dataset, data_file,
                                          args.max_predictions_per_seq,
                                          shared_file_list, args, worker_init)
-            train_reader_cost = 0.0
-            train_run_cost = 0.0
+            train_cost_avg = TimeCostAverage()
+            reader_cost_avg = TimeCostAverage()
             total_samples = 0
-            reader_start = time.time()
+            batch_start = time.time()
             for step, batch in enumerate(train_data_loader):
-                train_reader_cost += time.time() - reader_start
-                train_start = time.time()
+                train_reader_cost = time.time() - batch_start
+                reader_cost_avg.record(train_reader_cost)
                 global_step += 1
                 (input_ids, segment_ids, input_mask, masked_lm_positions,
                  masked_lm_labels, next_sentence_labels,
@@ -407,22 +408,23 @@ def do_train(args):
                 optimizer.step()
                 lr_scheduler.step()
                 optimizer.clear_grad()
-                train_run_cost += time.time() - train_start
                 total_samples += args.batch_size
+                train_run_cost = time.time() - batch_start
+                train_cost_avg.record(train_run_cost)
                 if global_step % args.logging_steps == 0:
                     if paddle.distributed.get_rank() == 0:
                         logger.info(
                             "global step: %d, epoch: %d, batch: %d, loss: %f, "
                             "avg_reader_cost: %.5f sec, avg_batch_cost: %.5f sec, avg_samples: %.5f, ips: %.5f sequences/sec"
                             % (global_step, epoch, step, loss,
-                               train_reader_cost / args.logging_steps,
-                               (train_reader_cost + train_run_cost) /
-                               args.logging_steps, total_samples /
-                               args.logging_steps, total_samples /
-                               (train_reader_cost + train_run_cost)))
-                        train_reader_cost = 0.0
-                        train_run_cost = 0.0
+                               reader_cost_avg.get_average(),
+                               train_cost_avg.get_average(), total_samples /
+                               args.logging_steps, total_samples / (
+                                   args.logging_steps *
+                                   train_cost_avg.get_average())))
                         total_samples = 0
+                        train_cost_avg.reset()
+                        reader_cost_avg.reset()
                 if global_step % args.save_steps == 0:
                     if paddle.distributed.get_rank() == 0:
                         output_dir = os.path.join(args.output_dir,
@@ -440,7 +442,7 @@ def do_train(args):
                 if global_step >= args.max_steps:
                     del train_data_loader
                     return
-                reader_start = time.time()
+                batch_start = time.time()
 
             del train_data_loader
             train_data_loader, data_file = dataset_future.result(timeout=None)
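Worth noting in the diff above: `batch_start` is now taken once per iteration, the reader cost is recorded as soon as the batch arrives, and the batch cost is recorded after the optimizer step from the same start point, so `avg_batch_cost` covers the reader wait plus the compute. A minimal runnable sketch of the same pattern, with a stub loader and a hypothetical train_one_batch standing in for the forward/backward/optimizer work:

import time

from paddlenlp.utils.tools import TimeCostAverage

# Stand-ins for illustration only: the commit uses a paddle.io.DataLoader
# and the BERT pretraining step here.
train_data_loader = range(8)

def train_one_batch(batch):
    time.sleep(0.01)  # pretend forward/backward/optimizer work

reader_cost_avg = TimeCostAverage()
train_cost_avg = TimeCostAverage()

batch_start = time.time()
for step, batch in enumerate(train_data_loader):
    # Reader cost: time from the end of the previous step until the batch arrives.
    reader_cost_avg.record(time.time() - batch_start)
    train_one_batch(batch)
    # Batch cost is measured from the same start point, so it includes the reader cost.
    train_cost_avg.record(time.time() - batch_start)
    batch_start = time.time()

print("avg_reader_cost: %.5f sec, avg_batch_cost: %.5f sec"
      % (reader_cost_avg.get_average(), train_cost_avg.get_average()))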

examples/language_model/bert/static/run_pretrain.py

Lines changed: 16 additions & 14 deletions
@@ -29,6 +29,7 @@
 import paddle.distributed.fleet as fleet
 from paddle.io import DataLoader, Dataset
 
+from paddlenlp.utils.tools import TimeCostAverage
 from paddlenlp.transformers import BertForPretraining, BertModel, BertPretrainingCriterion
 from paddlenlp.transformers import BertTokenizer
 from paddlenlp.transformers import LinearDecayWithWarmup
@@ -367,34 +368,35 @@ def do_train(args):
                 data_holders, worker_init,
                 paddle.static.cuda_places())
 
-            train_reader_cost = 0.0
-            train_run_cost = 0.0
+            train_cost_avg = TimeCostAverage()
+            reader_cost_avg = TimeCostAverage()
             total_samples = 0
-            reader_start = time.time()
+            batch_start = time.time()
             for step, batch in enumerate(train_data_loader):
-                train_reader_cost += time.time() - reader_start
+                train_reader_cost = time.time() - batch_start
+                reader_cost_avg.record(train_reader_cost)
                 global_step += 1
                 train_start = time.time()
                 loss_return = exe.run(main_program,
                                       feed=batch,
                                       fetch_list=[loss])
-                train_run_cost += time.time() - train_start
                 total_samples += args.batch_size
                 # In the new 2.0 api, must call this function to change the learning_rate
                 lr_scheduler.step()
+                train_run_cost = time.time() - batch_start
+                train_cost_avg.record(train_run_cost)
                 if global_step % args.logging_steps == 0:
                     print(
                         "global step: %d, epoch: %d, batch: %d, loss: %f, "
                         "avg_reader_cost: %.5f sec, avg_batch_cost: %.5f sec, avg_samples: %.5f, ips: %.5f sequences/sec"
-                        %
-                        (global_step, epoch, step, loss_return[0],
-                         train_reader_cost / args.logging_steps,
-                         (train_reader_cost + train_run_cost) /
-                         args.logging_steps, total_samples / args.logging_steps,
-                         total_samples / (train_reader_cost + train_run_cost)))
-                    train_reader_cost = 0.0
-                    train_run_cost = 0.0
+                        % (global_step, epoch, step, loss_return[0],
+                           reader_cost_avg.get_average(),
+                           train_cost_avg.get_average(),
+                           total_samples / args.logging_steps, total_samples /
+                           (args.logging_steps * train_cost_avg.get_average())))
                     total_samples = 0
+                    train_cost_avg.reset()
+                    reader_cost_avg.reset()
                 if global_step % args.save_steps == 0:
                     if worker_index == 0:
                         output_dir = os.path.join(args.output_dir,
@@ -410,7 +412,7 @@ def do_train(args):
                     reader_start = time.time()
                     del train_data_loader
                     return
-                reader_start = time.time()
+                batch_start = time.time()
             del train_data_loader
             train_data_loader, data_file = dataset_future.result(timeout=None)
         epoch += 1
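One consequence of the new formula: since total_samples accumulates args.batch_size once per step and is reset every args.logging_steps steps, total_samples / (args.logging_steps * train_cost_avg.get_average()) reduces to batch_size / avg_batch_cost. With illustrative numbers, say batch_size = 32, logging_steps = 20, and an average batch cost of 0.25 sec: total_samples = 640 and ips = 640 / (20 × 0.25) = 128 sequences/sec, the same as 32 / 0.25. And because each recorded batch cost spans from batch_start, i.e. includes the reader wait, logging_steps × avg_batch_cost is the window's wall time, matching the old train_reader_cost + train_run_cost denominator.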

paddlenlp/utils/tools.py

Lines changed: 31 additions & 0 deletions
@@ -57,3 +57,34 @@ def dygraph_params_to_static(model, dygraph_tensor_dict):
         ret_dict[parm.name] = dygraph_tensor_dict[name]
 
     return ret_dict
+
+
+class TimeCostAverage(object):
+    """
+    Simple tool for calculating the average time cost during training and inference.
+    """
+
+    def __init__(self):
+        self.reset()
+
+    def reset(self):
+        """
+        Reset the recorder state and set `cnt` back to zero.
+        """
+        self.cnt = 0
+        self.total_time = 0
+
+    def record(self, usetime):
+        """
+        Record the time cost of the current step and increment `cnt`.
+        """
+        self.cnt += 1
+        self.total_time += usetime
+
+    def get_average(self):
+        """
+        Return the average time cost since the last reset.
+        """
+        if self.cnt == 0:
+            return 0
+        return self.total_time / self.cnt
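A quick usage sketch of the class on its own (values are illustrative); note the zero-count guard in get_average, which makes it safe to query before anything has been recorded:

from paddlenlp.utils.tools import TimeCostAverage

avg = TimeCostAverage()
print(avg.get_average())  # 0 -- the cnt == 0 guard avoids division by zero

avg.record(0.25)          # time costs in seconds
avg.record(0.75)
print(avg.get_average())  # 0.5 == (0.25 + 0.75) / 2

avg.reset()               # start a new measurement window
print(avg.get_average())  # 0 again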
