
Commit 155ce98

Add tgs metrics (bigscience-workshop#286)
1 parent 15355af commit 155ce98

File tree

2 files changed: +5 additions, -3 deletions


megatron/training.py

Lines changed: 5 additions & 0 deletions
@@ -1035,13 +1035,17 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
         samples_per_sec_per_replica = samples_per_sec / args.data_parallel_size
         tokens_per_sec = samples_per_sec * seq_len
         tokens_per_sec_per_replica = tokens_per_sec / args.data_parallel_size
+        tokens_per_gpu_per_second = tokens_per_sec / args.world_size
+        tokens_per_gpu_per_second_per_replica = tokens_per_gpu_per_second / args.data_parallel_size
         if wandb is not None and getattr(wandb, 'run', None) is not None:
             tput = {
                 'throughput/iteration-time': elapsed_time_per_iteration,  # 1000 ms / s
                 'throughput/samples_per_sec': samples_per_sec,
                 'throughput/samples_per_sec_per_replica': samples_per_sec_per_replica,
                 'throughput/tokens_per_sec': tokens_per_sec,
                 'throughput/tokens_per_sec_per_replica': tokens_per_sec_per_replica,
+                'throughput/tokens_per_gpu_per_sec': tokens_per_gpu_per_second,
+                'throughput/tokens_per_gpu_per_sec_per_replica': tokens_per_gpu_per_second_per_replica,
                 'throughput/tflops': tflops,
                 'throughput/approx_params_in_billions': approx_parameters_in_billions,
                 'throughput/elapsed_ms_per_iteration': elapsed_time_per_iteration,
@@ -1091,6 +1095,7 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
         log_string += ' number of nan iterations: {:3d} |'.format(
             total_loss_dict[nan_iters_key])
         log_string += ' samples per second: {:.3f} |'.format(samples_per_sec)
+        log_string += ' tokens per gpu per second (tgs): {:.3f} |'.format(tokens_per_gpu_per_second)
         log_string += ' TFLOPs: {:.2f} |'.format(tflops)
         total_loss_dict[advanced_iters_key] = 0
         total_loss_dict[skipped_iters_key] = 0
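
The new metrics are pure arithmetic on quantities training_log already computes: global token throughput divided by the total GPU count, and that figure further divided by the data-parallel size. A minimal sketch of the arithmetic follows; the concrete numbers are illustrative stand-ins for values the real code reads from args, not values from this commit.

```python
# Sketch of the TGS arithmetic added above, with example values standing in
# for samples_per_sec, seq_len, args.world_size, and args.data_parallel_size.
samples_per_sec = 512.0   # global samples processed per second (example)
seq_len = 2048            # tokens per sample (example)
world_size = 64           # total GPUs in the job (example)
data_parallel_size = 8    # number of data-parallel replicas (example)

tokens_per_sec = samples_per_sec * seq_len                 # 1,048,576 tokens/s
tokens_per_gpu_per_second = tokens_per_sec / world_size    # 16,384 tokens/s/GPU
tokens_per_gpu_per_second_per_replica = (
    tokens_per_gpu_per_second / data_parallel_size)        # 2,048

print('tokens per gpu per second (tgs): {:.3f}'.format(tokens_per_gpu_per_second))
# tokens per gpu per second (tgs): 16384.000
```

The two wandb keys and the new ' tokens per gpu per second (tgs)' field in the console log string report exactly these two values.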

megatron/utils.py

Lines changed: 0 additions & 3 deletions
@@ -268,10 +268,7 @@ def get_parameters_in_billions(model):
     return approx_parameters_in_billions*gpus_per_model/(1e9)
 
 def throughput_calculator(model, args, iteration_time, total_iterations):
-    gpus_per_model = torch.distributed.get_world_size(group = mpu.get_model_parallel_group())
     batch_size = args.micro_batch_size * get_num_microbatches() * args.data_parallel_size
-    samples_per_model = batch_size * args.seq_length
-    model_replica_count = torch.distributed.get_world_size() / gpus_per_model
     approx_parameters_in_billions = None if (model is None) else get_parameters_in_billions(model)
     elapsed_time_per_iter = iteration_time/total_iterations
     samples_per_second = batch_size / elapsed_time_per_iter
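
The three deleted locals (gpus_per_model, samples_per_model, model_replica_count) appear to be dead code in the samples-per-second path that remains. A sketch of that surviving computation is below; the numbers are hypothetical stand-ins for args.micro_batch_size, get_num_microbatches(), and args.data_parallel_size, used only to make the arithmetic concrete.

```python
# Sketch of the samples-per-second computation kept in throughput_calculator.
micro_batch_size = 2       # example stand-in for args.micro_batch_size
num_microbatches = 16      # example stand-in for get_num_microbatches()
data_parallel_size = 8     # example stand-in for args.data_parallel_size

batch_size = micro_batch_size * num_microbatches * data_parallel_size  # 256 samples/iter
iteration_time = 102.4     # seconds elapsed over the measured window (example)
total_iterations = 100     # iterations in that window (example)

elapsed_time_per_iter = iteration_time / total_iterations  # 1.024 s/iter
samples_per_second = batch_size / elapsed_time_per_iter    # 250.0
print('{:.1f} samples per second'.format(samples_per_second))
```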
