Commit e23393f

Fix tflops glu computation (#283)
* Fix tflops glu computation
* Explain GLU TFLOPs difference
* Fix typo
* Specify MLP

Co-authored-by: Thomas Wang <[email protected]>
Co-authored-by: Thomas Wang <[email protected]>
1 parent cb48bd2 commit e23393f

1 file changed: 5 additions, 1 deletion

megatron/training.py

@@ -692,7 +692,11 @@ def add_to_logging(name):
         # The factor of 4 is when used with activation check-pointing,
         # otherwise it will be 3, but for 200B model, activation check-pointing will always be on.
         checkpoint_activations_factor = 4 if args.checkpoint_activations else 3
-        flops_per_iteration = (24 * checkpoint_activations_factor * batch_size * seq_len * num_layers * (hidden_size**2)) * (1. + (seq_len / (6. * hidden_size)) + (vocab_size / (16. * num_layers * hidden_size)))
+        # GLU activations double the hidden states in the up-scaling feed-forward of each transformer layer.
+        # This leads to 16bsh^2 instead of 8bsh^2 for the first feed-forward layer in the MLP, so we increase the coefficient by 8.
+        # Refer to https://github.com/bigscience-workshop/Megatron-DeepSpeed/pull/283#issue-1260805063 for more details.
+        coefficient = 32 if args.glu_activation else 24
+        flops_per_iteration = (coefficient * checkpoint_activations_factor * batch_size * seq_len * num_layers * (hidden_size**2)) * (1. + (seq_len / (6. * hidden_size)) + (vocab_size / (16. * num_layers * hidden_size)))
         tflops = flops_per_iteration / (elapsed_time_per_iteration * args.world_size * (10**12))

         # only the last rank process has a non-None _GLOBAL_TENSORBOARD_WRITER
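
For context on the coefficient: in the standard Megatron FLOPs model, each transformer layer contributes roughly 24*b*s*h^2 matmul FLOPs per forward pass (6 for the QKV projections, 2 for the attention output projection, and 8 + 8 for the two MLP matmuls). A GLU activation doubles the width of the first MLP projection, raising that term from 8 to 16 and the per-layer coefficient from 24 to 32. Below is a minimal, self-contained sketch of the patched estimate; the function name and the example hyperparameters are hypothetical and purely illustrative, not part of the commit.

def estimate_tflops(batch_size, seq_len, num_layers, hidden_size, vocab_size,
                    elapsed_time_per_iteration, world_size,
                    checkpoint_activations=True, glu_activation=False):
    # Factor of 4 with activation checkpointing (forward + backward + recompute), else 3.
    checkpoint_activations_factor = 4 if checkpoint_activations else 3
    # 24 = 6 (QKV) + 2 (attention output) + 8 + 8 (MLP matmuls) per b*s*h^2;
    # a GLU activation doubles the first MLP matmul (8 -> 16), giving 32.
    coefficient = 32 if glu_activation else 24
    flops_per_iteration = (
        coefficient * checkpoint_activations_factor * batch_size * seq_len
        * num_layers * hidden_size**2
    ) * (1. + seq_len / (6. * hidden_size)
         + vocab_size / (16. * num_layers * hidden_size))
    return flops_per_iteration / (elapsed_time_per_iteration * world_size * 1e12)

# Example with illustrative GPT-3-scale numbers (not taken from the commit):
# estimate_tflops(batch_size=512, seq_len=2048, num_layers=96, hidden_size=12288,
#                 vocab_size=50257, elapsed_time_per_iteration=60.0, world_size=384)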
