Commit ae6277f

Replace approximate formula with exact one for throughput (#251)
1 parent: 541b967 · commit: ae6277f

1 file changed: +8 −8 lines

megatron/training.py

Lines changed: 8 additions & 8 deletions
@@ -668,23 +668,23 @@ def add_to_logging(name):
     elapsed_time_per_iteration = elapsed_time / total_iterations
 
     seq_len = args.curriculum_seqlen if args.curriculum_learning else args.seq_length
+    hidden_size = args.hidden_size
+    num_layers = args.num_layers
+    vocab_size = args.padded_vocab_size
 
-    # throughput
+    # Compute throughput.
     samples_per_sec = batch_size / elapsed_time_per_iteration
     samples_per_sec_per_replica = samples_per_sec / args.data_parallel_size
     tokens_per_sec = samples_per_sec * seq_len
     tokens_per_sec_per_replica = tokens_per_sec / args.data_parallel_size
 
-    # general TFLOPs formula
-    # model_size_in_B * 4 * 2 * seqlen * global_batch_size / (time_in_sec_per_interation * total_gpus * 1e3)
-    #
+    # General TFLOPs formula (borrowed from Equation 3 in Section 5.1 of
+    # https://arxiv.org/pdf/2104.04473.pdf).
     # The factor of 4 is when used with activation check-pointing,
     # otherwise it will be 3, but for 200B model, activation check-pointing will always be on.
-    #
-    # here:
-    # model_size_in_B * 4 * 2 * seqlen * batch_size / (time_in_msec_per_interation * total_gpus * 1e3)
     checkpoint_activations_factor = 4 if args.checkpoint_activations else 3
-    tflops = args.parameters_in_billions_no_embedding * checkpoint_activations_factor * 2 * seq_len * batch_size / (elapsed_time_per_iteration * args.world_size * 1e3)
+    flops_per_iteration = (24 * checkpoint_activations_factor * batch_size * seq_len * num_layers * (hidden_size**2)) * (1. + (seq_len / (6. * hidden_size)) + (vocab_size / (16. * num_layers * hidden_size)))
+    tflops = flops_per_iteration / (elapsed_time_per_iteration * args.world_size * (10**12))
 
     # only the last rank process has a non-None _GLOBAL_TENSORBOARD_WRITER
     if writer and is_last_rank():
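
For reference, the new calculation can be exercised in isolation. Below is a minimal, self-contained sketch of the exact formula the commit adopts (Equation 3, Section 5.1 of https://arxiv.org/pdf/2104.04473.pdf); the function name estimate_tflops and the example configuration are hypothetical illustrations, not part of the commit:

# Sketch of the exact per-iteration FLOPs formula adopted above.
# estimate_tflops and the sample numbers are illustrative, not from the repo.
def estimate_tflops(batch_size, seq_len, num_layers, hidden_size, vocab_size,
                    elapsed_time_per_iteration, world_size,
                    checkpoint_activations=True):
    # Activation checkpointing recomputes the forward pass during backprop,
    # so the multiplier is 4 (fwd + bwd + recompute) instead of 3 (fwd + bwd).
    checkpoint_activations_factor = 4 if checkpoint_activations else 3
    # 24 * c * B * s * l * h^2 covers the transformer-layer matmuls; the two
    # correction terms add attention (s / 6h) and output-logit (V / 16lh) FLOPs.
    flops_per_iteration = (
        24 * checkpoint_activations_factor * batch_size * seq_len
        * num_layers * hidden_size**2
    ) * (
        1.0
        + seq_len / (6.0 * hidden_size)
        + vocab_size / (16.0 * num_layers * hidden_size)
    )
    # Achieved TFLOPs per GPU.
    return flops_per_iteration / (elapsed_time_per_iteration * world_size * 10**12)

# Hypothetical GPT-2-XL-scale run on 8 GPUs at 30 s per iteration.
print(estimate_tflops(batch_size=512, seq_len=1024, num_layers=48,
                      hidden_size=1600, vocab_size=50304,
                      elapsed_time_per_iteration=30.0, world_size=8))

Unlike the old approximation, which scaled a non-embedding parameter count by 4 * 2 * seq_len * batch_size, this form accounts explicitly for the attention and output-logit FLOPs via the seq_len / (6 * hidden_size) and vocab_size / (16 * num_layers * hidden_size) correction terms.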
