Skip to content
Draft
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 18 additions & 8 deletions megatron/core/pipeline_parallel/schedules.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,9 +399,9 @@ def forward_step(
context_manager = contextlib.nullcontext()
with context_manager:
if checkpoint_activations_microbatch is None:
output_tensor, loss_func = forward_step_func(data_iterator, model)
output_tensor, loss_func, num_empty_bins = forward_step_func(data_iterator, model)
else:
output_tensor, loss_func = forward_step_func(
output_tensor, loss_func, num_empty_bins = forward_step_func(
data_iterator, model, checkpoint_activations_microbatch
)
output_tensor, num_tokens = forward_step_calc_loss(
Expand All @@ -418,8 +418,8 @@ def forward_step(
)

if unwrap_output_tensor:
return output_tensor, num_tokens
return [output_tensor], num_tokens
return output_tensor, num_tokens, num_empty_bins
return [output_tensor], num_tokens, num_empty_bins


def backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config):
Expand Down Expand Up @@ -573,6 +573,7 @@ def forward_backward_no_pipelining(
total_num_tokens = torch.zeros([], dtype=torch.int, device="cuda")

if config.overlap_moe_expert_parallel_comm and not forward_only:
num_empty_bins = 0
forward_data_store, total_num_tokens = combined_1f1b_schedule_for_no_pipelining(
forward_step_func,
data_iterator,
Expand All @@ -592,7 +593,7 @@ def forward_backward_no_pipelining(
else:
with no_sync_func():
for i in range(num_microbatches - 1):
output_tensor, num_tokens = forward_step(
output_tensor, num_tokens, num_empty_bins = forward_step(
forward_step_func,
data_iterator,
model,
Expand All @@ -612,7 +613,7 @@ def forward_backward_no_pipelining(
)
# Run computation for last microbatch out of context handler (want to
# synchronize gradients).
output_tensor, num_tokens = forward_step(
output_tensor, num_tokens, num_empty_bins = forward_step(
forward_step_func,
data_iterator,
model,
Expand Down Expand Up @@ -652,6 +653,8 @@ def forward_backward_no_pipelining(
):
create_cudagraphs()

forward_data_store.append(num_empty_bins)

return forward_data_store


Expand Down Expand Up @@ -1348,6 +1351,7 @@ def forward_backward_helper_wrapper(
return forward_output_tensor, backward_input_tensor_grad

# ==============================main logic=========================================
num_empty_bins = 0
_is_vp_first_stage = partial(
is_vp_first_stage, vp_size=config.virtual_pipeline_model_parallel_size
)
Expand Down Expand Up @@ -1917,6 +1921,8 @@ def pp_post_backward(input_tensor_grad, vp_stage=None):
create_cudagraphs()
nvtx_range_pop(suffix="misc")

forward_data_store.append(num_empty_bins)

return forward_data_store


Expand Down Expand Up @@ -1977,6 +1983,8 @@ def forward_backward_pipelining_without_interleaving(
data_iterator = data_iterator[0]

config = get_model_config(model)
num_empty_bins = 0

if config.overlap_p2p_comm:
raise ValueError(
"Non-interleaved pipeline parallelism does not support overlapping p2p communication"
Expand Down Expand Up @@ -2132,7 +2140,7 @@ def enable_grad_sync():
input_tensor = p2p_communicator.recv_forward(
recv_tensor_shapes, is_pp_first_stage(p2p_communicator.pp_group)
)
output_tensor, num_tokens = forward_step(
output_tensor, num_tokens, num_empty_bins = forward_step(
forward_step_func,
data_iterator,
model,
Expand Down Expand Up @@ -2175,7 +2183,7 @@ def enable_grad_sync():
else:
checkpoint_activations_microbatch = None

output_tensor, num_tokens = forward_step(
output_tensor, num_tokens, num_empty_bins = forward_step(
forward_step_func,
data_iterator,
model,
Expand Down Expand Up @@ -2300,4 +2308,6 @@ def enable_grad_sync():
):
create_cudagraphs()

forward_data_store.append(num_empty_bins)

return forward_data_store
8 changes: 8 additions & 0 deletions megatron/rl/rl_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1574,6 +1574,13 @@ def prepare_data_for_update(
bin_idx = current_bins + i
packing_info['bin_seq_indices'].append(entry['bin_seq_indices'])
packing_info['seq_starts'][bin_idx] = entry['seq_starts']
else:
num_empty_bins = 0

# Sum the total number of empty bins across the ranks
empty_bin_count = torch.tensor([num_empty_bins], device='cuda')
torch.distributed.all_reduce(empty_bin_count, op=torch.distributed.ReduceOp.SUM)
num_empty_bins = empty_bin_count[0].cpu().numpy()

packing_context['packing_info'] = packing_info
packing_context['original_generation_masks'] = generation_masks
Expand Down Expand Up @@ -1991,6 +1998,7 @@ def prepare_data_for_update(
'num_sequences': len(packing_info['seq_lengths']),
'avg_seqs_per_bin': global_avg_seqs_per_bin,
'avg_seqs_per_bin_this_rank': actual_seqs_per_bin_this_rank,
'num_empty_bins': num_empty_bins, # summed across ranks
}

if args.micro_batch_size != 1:
Expand Down
2 changes: 2 additions & 0 deletions megatron/training/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -1968,6 +1968,8 @@ def _add_logging_args(parser):
help='Path to save the wandb results locally.')
group.add_argument('--logging-level', type=int, default=None,
help='Set default logging level')
group.add_argument('--log-tokens-per-second', default=False, action="store_true",
help='Whether to log tokens per second.')
return parser


Expand Down
14 changes: 13 additions & 1 deletion megatron/training/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -1434,10 +1434,13 @@ def train_step(forward_step_func, data_iterator, model, optimizer, opt_param_sch
if args.empty_unused_memory_level >= 2:
torch.cuda.empty_cache()

num_empty_bins = 0
if mpu.is_pipeline_last_stage(ignore_virtual=True):
# Average loss across microbatches.
loss_reduced = {}

num_empty_bins = losses_reduced.pop()

for key in losses_reduced[0].keys():
val = [x[key].view(-1) for x in losses_reduced]
if val[0].numel() == 2:
Expand Down Expand Up @@ -1478,8 +1481,9 @@ def train_step(forward_step_func, data_iterator, model, optimizer, opt_param_sch
grad_norm,
num_zeros_in_grad,
log_max_attention_logit,
num_empty_bins
)
return {}, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad, log_max_attention_logit
return {}, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad, log_max_attention_logit, num_empty_bins


def training_log(
Expand All @@ -1494,6 +1498,7 @@ def training_log(
params_norm,
num_zeros_in_grad,
max_attention_logit,
num_empty_bins,
):
"""Log training information such as losses, timing, ...."""
args = get_args()
Expand Down Expand Up @@ -1765,6 +1770,8 @@ def training_log(
total_loss_dict[skipped_iters_key]
)
log_string += ' number of nan iterations: {:3d} |'.format(total_loss_dict[nan_iters_key])
if args.log_tokens_per_second:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This sounds like an extremely useful metric. Should we add it to the W&B metrics too?

log_string += f' tokens per second: {((batch_size - num_empty_bins) * args.seq_length) / elapsed_time_per_iteration}'
total_loss_dict[advanced_iters_key] = 0
total_loss_dict[skipped_iters_key] = 0
total_loss_dict[nan_iters_key] = 0
Expand Down Expand Up @@ -2432,6 +2439,7 @@ def get_e2e_base_metrics():
grad_norm,
num_zeros_in_grad,
max_attention_logit,
num_empty_bins
) = train_step(
forward_step_func, train_data_iterator, model, optimizer, opt_param_scheduler, config, forward_backward_func
)
Expand Down Expand Up @@ -2530,6 +2538,7 @@ def get_e2e_base_metrics():
params_norm,
num_zeros_in_grad,
max_attention_logit,
num_empty_bins
)

# Evaluation.
Expand Down Expand Up @@ -2707,6 +2716,9 @@ def evaluate(
torch.cuda.empty_cache()

if mpu.is_pipeline_last_stage(ignore_virtual=True):

_ = loss_dicts.pop()

# Reduce across processes.
for key in loss_dicts[0].keys():
if key not in total_loss_dict:
Expand Down
3 changes: 2 additions & 1 deletion pretrain_gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def forward_step(data_iterator, model: GPTModel, return_schedule_plan: bool = Fa
"""
args = get_args()
timers = get_timers()
num_empty_bins = 0 # Number of padding bins from the data loader

# Get the batch.
timers('batch-generator', log_level=2).start()
Expand All @@ -155,7 +156,7 @@ def forward_step(data_iterator, model: GPTModel, return_schedule_plan: bool = Fa
)

# [ModelOpt]: model is needed to access ModelOpt distillation losses
return output_tensor, partial(loss_func, loss_mask, model=model)
return output_tensor, partial(loss_func, loss_mask, model=model), num_empty_bins


def is_dataset_built_on_rank(vp_stage=None):
Expand Down
7 changes: 6 additions & 1 deletion train_rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,11 @@ def forward_step(data_iterator, model: GPTModel, loss_only: bool = False):
)
)

if runtime_state.sequence_packing_metadata:
num_empty_bins = runtime_state.sequence_packing_metadata['num_empty_bins']
else:
num_empty_bins = None

# loss_mask will not be applied to 0th token as we do not have a logprob for it.
return loss, partial(
loss_func,
Expand All @@ -338,7 +343,7 @@ def forward_step(data_iterator, model: GPTModel, loss_only: bool = False):
entropy_term,
truncated_from_above,
truncated_from_below,
)
), num_empty_bins


def train_valid_test_datasets_provider(train_val_test_num_samples):
Expand Down
Loading