Commit e0c90c5

debugging convergence issue caused by padding tokens
Signed-off-by: tailaim <tailaim@nvidia.com>
1 parent 39708c1 commit e0c90c5

File tree: 8 files changed, +237 -20 lines changed

megatron/core/models/gpt/gpt_model.py

Lines changed: 45 additions & 0 deletions

@@ -37,6 +37,43 @@
 from megatron.core.transformer.transformer_config import TransformerConfig
 from megatron.core.utils import WrappedTensor, deprecate_inference_params

+# #debugmtl
+# def get_debug_hook(layer_name):
+#     """
+#     This is a "hook factory".
+#     Calling it returns a hook function that has already captured layer_name.
+#     """
+#     def hook(grad_output):
+#         # Skip if there is no gradient or the gradient is empty
+#         if grad_output is None:
+#             print(f"[Rank {rank}] [BWD] {layer_name:25s} | grad_output is None")
+#             return
+#
+#         g = grad_output[0]
+#         if g is None:
+#             return
+#         if torch.distributed.is_initialized():
+#             rank = torch.distributed.get_rank()
+#             # if rank == 0:
+#             # Simple statistics
+#             g_float = g.float()
+#             g_max = g_float.max().item()
+#             g_min = g_float.min().item()
+#             g_mean = g_float.mean().item()
+#             g_norm = torch.linalg.vector_norm(g_float, ord=2).item()
+#             has_nan = torch.isnan(g_float).any().item()
+#
+#             # [Key] layer_name can be printed here directly
+#             print(f"[Rank {rank}] [BWD] {layer_name:25s} | "
+#                   f"Max: {g_max:.4e} | Min: {g_min:.4e} | Mean: {g_mean:.4e} | "
+#                   f"Norm: {g_norm:.4e} | NaN: {has_nan}")
+#
+#             # If a NaN shows up, add a breakpoint or raise an error here
+#             # if has_nan:
+#             #     raise RuntimeError(f"NaN found in {layer_name}")
+#
+#     return hook
+

 class GPTModel(LanguageModule):
     """GPT Transformer language model.

@@ -475,6 +512,9 @@ def forward(
             preproc_output[:5]
         )

+        # #debugmtl
+        # decoder_input.register_hook(get_debug_hook("Embedding_Output"))
+
         rotary_pos_cos_sin = preproc_output[5] if len(preproc_output) == 6 else None

         # Run decoder.

@@ -491,6 +531,9 @@ def forward(
             **(extra_block_kwargs or {}),
         )

+        # #debugmtl
+        # hidden_states.register_hook(get_debug_hook("Decoder_Output_Before_Head"))
+
         return self._postprocess(
             hidden_states=hidden_states,
             input_ids=input_ids,

@@ -633,6 +676,8 @@ def _postprocess(
             hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output
         )

+        # #debugmtl
+        # logits.register_hook(get_debug_hook("Logits_Output"))
         # Restore sequence parallel execution to the output layer if necessary.
         if sequence_parallel_override:
             assert (

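Note: the commented-out #debugmtl code above relies on PyTorch's per-tensor gradient hooks. A minimal, self-contained sketch of that pattern outside Megatron-LM (the toy tensors and names below are illustrative only, not part of the commit):

import torch

def make_grad_logger(name):
    """Return a hook that prints simple statistics of the gradient flowing into `name`."""
    def hook(grad):
        g = grad.float()
        print(f"[BWD] {name:20s} | max {g.max():.3e} | min {g.min():.3e} | "
              f"norm {torch.linalg.vector_norm(g):.3e} | nan {torch.isnan(g).any().item()}")
    return hook

x = torch.randn(4, 8, requires_grad=True)
h = torch.tanh(x @ torch.randn(8, 8))
h.register_hook(make_grad_logger("hidden"))  # fires when d(loss)/d(h) is computed
h.sum().backward()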
megatron/core/parallel_state.py

Lines changed: 11 additions & 0 deletions

@@ -970,6 +970,17 @@ def initialize_model_parallel(
         if rank in ranks:
             _HIERARCHICAL_CONTEXT_PARALLEL_GROUPS = hierarchical_groups

+    if hybrid_context_parallel:
+        # PyTorch is performing lazy initialization of the communicator group.
+        # Therefore, we need to perform a nccl call to ensure that the communicator group is created.
+        group_sizes = [2**i for i in range(int(log2(data_parallel_size)))]
+        if group_sizes[-1] * 2 == data_parallel_size:
+            group_sizes.append(data_parallel_size)
+        for group_size in group_sizes:
+            group = get_hybrid_data_context_parallel_groups(group_size=group_size)
+            torch.distributed.barrier(group=group, device_ids=[torch.cuda.current_device()])
+        torch.cuda.synchronize()
+
     # Build the model-parallel groups.
     global _MODEL_PARALLEL_GROUP
     global _MODEL_PARALLEL_GLOBAL_RANKS

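The comment in the hunk above states the intent: NCCL communicators are created lazily, so the loop issues a barrier on each hybrid data/context-parallel group to force communicator creation up front. get_hybrid_data_context_parallel_groups is Megatron-internal; a standalone sketch of the same eager-initialization idea with plain torch.distributed, assuming the default NCCL process group is already initialized:

import torch
import torch.distributed as dist

def eagerly_create_subgroup(ranks):
    """new_group only registers the subgroup; the first collective actually builds
    the NCCL communicator, so issue a cheap barrier on it right away."""
    group = dist.new_group(ranks=ranks)  # must be called by every rank
    if dist.get_rank() in ranks:
        dist.barrier(group=group, device_ids=[torch.cuda.current_device()])
    return group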
megatron/core/pipeline_parallel/data_schedule.py

Lines changed: 49 additions & 11 deletions

@@ -329,6 +329,27 @@ def _broadcast(item):

     groups, sample_id_groups = scheduler.get_groups_and_subsamples(global_id_seqlens, config)

+    # #debugmtl
+    # if parallel_state.get_data_parallel_rank(with_context_parallel=True) == 0:
+    #     k = 0
+    #     for group in sample_id_groups:
+    #         print(f"group {k}: ", end="")
+    #         for i in range(len(group)):
+    #             print(f"GPU-{i}: [", end="")
+    #             for j in range(len(group[i])):
+    #                 print(f"{group[i][j]}-{global_id_seqlens[group[i][j]][1]}, ", end=" ")
+    #             print(f"], ")
+    #         k += 1
+    #         print()
+
+    # debugmtl
+    # set_gbs = set()
+    # for group in sample_id_groups:
+    #     for sub in group:
+    #         set_gbs.update(sub)
+    # assert len(set_gbs) == len(global_id_seqlens), \
+    #     f"set_gbs length: {len(set_gbs)} != global_ids_this_rank length: {len(global_id_seqlens)}"
+
     batch = _unpack_batch(batch)
     samples_this_rank_with_id = _reroute_samples_to_hdp_ranks(
         batch,

@@ -415,7 +436,9 @@ def _pack_tensors(tensors):
         new_sample["cu_seqlens"] = cu_seqlens

         new_samples.append(new_sample)
-
+    # #debugmtl
+    # print(f"rank {parallel_state.get_data_parallel_rank
+    # (with_context_parallel=True)} new_samples length: {len(new_samples)}")
     new_data_iterator = RerunDataIterator(iter(new_samples))

     return (

@@ -460,15 +483,28 @@ def get_groups_and_subsamples(self, sample_id_seqlens, config):
         sum_seqlen = 0
         single_microbatch = []

+        # debugmtl use 1 seq per microbatch
         for i in range(len(sample_id_seqlens)):
-            if sum_seqlen + sample_id_seqlens[i][1] <= self.max_seq_len_all_ranks:
-                single_microbatch.append(i)
-                sum_seqlen += sample_id_seqlens[i][1]
-            else:
-                groups.append(single_microbatch)
-                packed_id_groups.append(single_microbatch)
-                single_microbatch = [i]
-                sum_seqlen = sample_id_seqlens[i][1]
+            packed_id_groups.append([i])
+
+        # for i in range(len(sample_id_seqlens)):
+        #     if sum_seqlen + sample_id_seqlens[i][1] <= self.max_seq_len_all_ranks:
+        #         single_microbatch.append(i)
+        #         sum_seqlen += sample_id_seqlens[i][1]
+        #     else:
+        #         packed_id_groups.append(single_microbatch)
+        #         single_microbatch = [i]
+        #         sum_seqlen = sample_id_seqlens[i][1]
+        # if len(single_microbatch) > 0:
+        #     packed_id_groups.append(single_microbatch)
+
+        # debugmtl
+        gbs_sum = 0
+        for i in packed_id_groups:
+            gbs_sum += len(i)
+        assert gbs_sum == len(
+            sample_id_seqlens
+        ), f"gbs_sum: {gbs_sum} != sample_id_seqlens length: {len(sample_id_seqlens)}"

         # we want the number of packed sequences to be multiple of dp_size
         # so we move few samples from previous microbatch

@@ -482,7 +518,7 @@ def get_groups_and_subsamples(self, sample_id_seqlens, config):
             assert i > 0, "Not enough samples to move"
             if len(packed_id_groups[i]) > 1:
                 seq_id = packed_id_groups[i].pop()
-                packed_id_groups[i].append(seq_id)
+                packed_id_groups.append([seq_id])
                 num_to_move -= 1
             else:
                 i -= 1

@@ -493,7 +529,9 @@ def get_groups_and_subsamples(self, sample_id_seqlens, config):
             for j in range(self.cp_size * self.dp_size):
                 seq_id = int(i * self.dp_size + j / self.cp_size)
                 sample_id_groups[i].append(packed_id_groups[seq_id])
-
+        # debugmtl
+        # print(f"rank {parallel_state.get_data_parallel_rank(with_context_parallel=True)} \
+        #     sample_id_groups: {len(sample_id_groups)}")
         return groups, sample_id_groups

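For reference, the loop that the debug change disables is a greedy first-fit packer: it keeps appending sample ids to the current microbatch until the next sample would exceed the per-rank sequence budget. A standalone sketch of that idea (function and variable names are illustrative, not the Megatron API):

def greedy_pack(seqlens, max_len):
    """Pack sample indices into groups whose total length stays within max_len.
    The #debugmtl change above bypasses this and emits one sample per group."""
    groups, current, total = [], [], 0
    for i, length in enumerate(seqlens):
        if current and total + length > max_len:
            groups.append(current)
            current, total = [], 0
        current.append(i)
        total += length
    if current:
        groups.append(current)
    return groups

# greedy_pack([512, 1024, 2048, 256], max_len=2048) -> [[0, 1], [2], [3]]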
megatron/core/pipeline_parallel/schedules.py

Lines changed: 2 additions & 1 deletion

@@ -146,7 +146,8 @@ def deallocate_output_tensor(out, deallocate_pipeline_outputs=False):
     if (out is None) or (not deallocate_pipeline_outputs):
         return
     assert isinstance(out, torch.Tensor), "expected Tensor, found %s." % type(out).__name__
-    assert out._base is None, "counter-productive to free a view of another tensor."
+    # debugmtl
+    # assert out._base is None, "counter-productive to free a view of another tensor."
     out.data = torch.empty((1,), device=out.device, dtype=out.dtype)

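Context for the assert that gets commented out above: deallocate_output_tensor frees a pipeline stage's output activation by swapping its .data for a 1-element buffer while keeping the Python object (and its grad_fn) alive; the assert exists because doing this to a view would not release the base tensor's storage. A minimal sketch of the trick itself:

import torch

def free_storage(t: torch.Tensor) -> None:
    """Release the tensor's large buffer but keep the object usable as a handle."""
    t.data = torch.empty((1,), device=t.device, dtype=t.dtype)

buf = torch.randn(1024, 1024)
free_storage(buf)
print(buf.shape)  # torch.Size([1]) -- the original ~4 MB of storage is gone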
megatron/core/transformer/moe/token_dispatcher.py

Lines changed: 43 additions & 6 deletions

@@ -672,12 +672,49 @@ def token_dispatch(self, permutated_local_input_tokens, permuted_probs):
         self.tokens_per_expert = self._maybe_dtoh_and_synchronize(
             "before_ep_alltoall", self.tokens_per_expert
         )
-        global_input_tokens = all_to_all(
-            self.ep_group, permutated_local_input_tokens, self.output_splits, self.input_splits
-        )
-        global_probs = all_to_all(
-            self.ep_group, permuted_probs, self.output_splits, self.input_splits
-        )
+        # debugmtl
+        # global_input_tokens = all_to_all(
+        #     self.ep_group, permutated_local_input_tokens,
+        #     self.output_splits, self.input_splits
+        # )
+        # global_probs = all_to_all(
+        #     self.ep_group, permuted_probs, self.output_splits,
+        #     self.input_splits
+        # )
+        try:
+            global_input_tokens = all_to_all(
+                self.ep_group, permutated_local_input_tokens, self.output_splits, self.input_splits
+            )
+            global_probs = all_to_all(
+                self.ep_group, permuted_probs, self.output_splits, self.input_splits
+            )
+        except RuntimeError as e:
+            # Get the rank within the EP group (guard against the group not being initialized yet)
+            try:
+                rank = torch.distributed.get_rank(self.ep_group)
+            except Exception:
+                rank = -1
+
+            print(f"[MoE all_to_all error] rank={rank}, err={e}")
+            print(
+                f"[MoE all_to_all debug] "
+                f"tokens_shape={getattr(permutated_local_input_tokens, 'shape', None)}, "
+                f"probs_shape={getattr(permuted_probs, 'shape', None)}"
+            )
+            print(
+                f"[MoE all_to_all debug] "
+                f"input_splits={self.input_splits}, "
+                f"sum={sum(self.input_splits) if self.input_splits is not None else None}, "
+                f"output_splits={self.output_splits}, "
+                f"sum={sum(self.output_splits) if self.output_splits is not None else None}"
+            )
+            print(
+                f"[MoE all_to_all debug] "
+                f"tokens_per_expert={self.tokens_per_expert}, "
+                f"sum={self.tokens_per_expert.sum() if hasattr(self.tokens_per_expert, 'sum') else None}"
+            )
+            raise

         return global_input_tokens, global_probs

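The try/except added above wraps Megatron's all_to_all helper so that a NCCL failure prints the split sizes before re-raising. The invariant it surfaces: sum(input_splits) must equal the number of local tokens, and each rank's output_splits must agree with what its peers actually send. A standalone sketch of the underlying variable-size exchange using plain torch.distributed (the helper itself is Megatron-internal):

import torch
import torch.distributed as dist

def splitwise_all_to_all(tokens, input_splits, output_splits, group=None):
    """Variable-size all-to-all: rank r sends input_splits[p] rows to rank p and
    receives output_splits[p] rows from it. Mismatched splits across ranks are
    exactly the kind of error the debug prints above are meant to expose."""
    out = tokens.new_empty((sum(output_splits),) + tuple(tokens.shape[1:]))
    dist.all_to_all_single(
        out,
        tokens,
        output_split_sizes=output_splits,
        input_split_sizes=input_splits,
        group=group,
    )
    return out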
megatron/training/datasets/sft_dataset.py

Lines changed: 2 additions & 0 deletions

@@ -124,6 +124,8 @@ def __getitem__(self, idx: int) -> Dict[str, Any]:
         num_tokens = len(tokens) + force_eod_length
         if sft_sequence_packing:
             padding_len = self.get_padding_size(num_tokens) - num_tokens
+            # debugmtl
+            # padding_len = max_seq_len - num_tokens
         else:
             padding_len = max_seq_len - num_tokens
         assert padding_len >= 0

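The toggle above is the heart of the padding experiment: with sequence packing, each sample is padded only up to get_padding_size(num_tokens), while the commented-out debug line pads every sample all the way to max_seq_len. A hypothetical sketch of the difference (the rounding rule below is an assumption for illustration, not the actual get_padding_size implementation):

def pad_to_multiple(num_tokens: int, multiple: int = 64) -> int:
    # Assumed behaviour for illustration only: round up to the next multiple of 64.
    return ((num_tokens + multiple - 1) // multiple) * multiple

num_tokens = 1000
packing_padding = pad_to_multiple(num_tokens) - num_tokens  # 24 pad tokens
debug_padding = 4096 - num_tokens                           # 3096 pad tokens if max_seq_len were 4096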
megatron/training/training.py

Lines changed: 69 additions & 1 deletion

@@ -476,7 +476,7 @@ def transformer_flops():
         )
         +
         # Self Attention
-        standard_self_attn_term
+        self_attn_term

     )
     return total_floating_point_operations

@@ -1460,6 +1460,8 @@ def train_step(forward_step_func, data_iterator, model, optimizer, opt_param_sch
                     val,
                     group=mpu.get_data_parallel_group(with_context_parallel=True)
                 )
+                #debugmtl
+                print_rank_0(f"key: {key}, val: {val}")
                 loss_reduced[key] = val[0] / val[1]
             elif val[0].numel() == 1:
                 # legacy behavior, we average over the number of microbatches

@@ -1747,6 +1749,9 @@ def training_log(
             avg = total_loss_dict[key].item() / float(
                 max(1, total_loss_dict[advanced_iters_key])
             )
+            #debugmtl
+            print_rank_0(f"in training_log, key: {key}, avg: {total_loss_dict[key].item()}, \
+                advanced_iters_key: {total_loss_dict[advanced_iters_key]}")
             if avg > 0.0:
                 log_string += ' {}: {:.6E} |'.format(key, avg)
             total_loss_dict[key] = torch.tensor([0.0], dtype=torch.float, device='cuda')

@@ -2070,7 +2075,70 @@ def train(
     """Training function: run train_step desired number of times, run validation, checkpoint."""
     args = get_args()
     timers = get_timers()
+    # debugmtl
+    def get_debug_hook(layer_name):
+        """
+        This is a "hook factory".
+        Calling it returns a hook function that has already captured layer_name.
+        """
+        def hook(module, grad_input, grad_output):
+            # Skip if there is no gradient or the gradient is empty
+            if not grad_output:
+                return
+
+            g = grad_output[0]
+            if g is None:
+                return
+            if torch.distributed.is_initialized():
+                rank = torch.distributed.get_rank()
+                if rank == 0:
+                    # Simple statistics
+                    g_float = g.float()
+                    g_max = g_float.max().item()
+                    g_min = g_float.min().item()
+                    g_mean = g_float.mean().item()
+                    g_norm = torch.linalg.vector_norm(g_float, ord=2).item()
+                    has_nan = torch.isnan(g_float).any().item()
+
+                    # [Key] layer_name can be printed here directly
+                    print(f"[Rank {rank}] [BWD] {layer_name:25s} | "
+                          f"Max: {g_max:.4e} | Min: {g_min:.4e} | Mean: {g_mean:.4e} | "
+                          f"Norm: {g_norm:.4e} | NaN: {has_nan}")

+                    # If a NaN shows up, add a breakpoint or raise an error here
+                    # if has_nan:
+                    #     raise RuntimeError(f"NaN found in {layer_name}")
+
+        return hook
+
+    for chunk_id, model_chunk in enumerate(model):
+        prefix = f"Chunk{chunk_id}"
+        gpt_model = model_chunk.module.module
+        # --- Register the embedding ---
+        if hasattr(gpt_model, 'embedding'):
+            # Pass in the name "Embedding"
+            gpt_model.embedding.register_full_backward_hook(
+                get_debug_hook(f"{prefix}.Embedding")
+            )
+
+        if hasattr(gpt_model, 'output_layer'):
+            # Pass in the name "OutputLayer"
+            gpt_model.output_layer.register_full_backward_hook(
+                get_debug_hook(f"{prefix}.OutputLayer")
+            )

+        # --- Register the decoder layers ---
+        if hasattr(gpt_model, 'decoder') and hasattr(gpt_model.decoder, 'layers'):
+            for i, layer in enumerate(gpt_model.decoder.layers):
+                # Pass in names "Layer_0", "Layer_1", ...
+                layer.register_full_backward_hook(
+                    get_debug_hook(f"{prefix}.Layer_{i}")
+                )
+
+        print_rank_0(f">>> {prefix} backward debug hook registered")
+        print_rank_0(f"model chunk is: {model_chunk.module.module}")
+

     if getattr(args, 'perform_rl_step', False):
         assert has_rl_utils, "RL cannot run without the megatron.rl package"

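The active debug block above registers module-level backward hooks on the embedding, the output layer, and every decoder layer of each model chunk. A minimal standalone sketch of register_full_backward_hook on a toy model, outside Megatron (names are illustrative):

import torch
import torch.nn as nn

def make_module_grad_logger(name):
    """Module-level analogue of the tensor hooks: fires once the gradients
    leaving `name` have been computed during backward."""
    def hook(module, grad_input, grad_output):
        g = grad_output[0]
        if g is not None:
            print(f"[BWD] {name:10s} | norm {g.float().norm():.3e} | "
                  f"nan {torch.isnan(g).any().item()}")
    return hook

net = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 1))
for i, layer in enumerate(net):
    layer.register_full_backward_hook(make_module_grad_logger(f"layer_{i}"))
net(torch.randn(4, 16)).sum().backward()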
pretrain_gpt.py

Lines changed: 16 additions & 1 deletion

@@ -58,7 +58,6 @@ def get_batch(data_iterator, vp_stage: Optional[int] = None):
     """Generate a batch."""
     args = get_args()
     config = core_transformer_config_from_args(args)
-    args = get_args()

     # TODO: this is pretty hacky, find a better way
     if not is_first_or_last_pipeline_stage(vp_stage) and (

@@ -83,6 +82,22 @@ def get_batch(data_iterator, vp_stage: Optional[int] = None):
             cu_seqlens_padded, max_seqlen, local_cp_size=local_cp_size)

     else:
+        # #debugmtl
+        # sample_length = batch['tokens'].shape[1]
+        # if args.sft:
+        #     packed_seq_params = PackedSeqParams(
+        #         qkv_format="sbhd",
+        #         cu_seqlens_q=torch.tensor([0, sample_length], device="cuda", pin_memory=True),
+        #         cu_seqlens_kv=torch.tensor([0, sample_length], device="cuda", pin_memory=True),
+        #         cu_seqlens_q_padded=torch.tensor([0, sample_length], device="cuda", pin_memory=True),
+        #         cu_seqlens_kv_padded=torch.tensor([0, sample_length], device="cuda", pin_memory=True),
+        #         max_seqlen_q=sample_length,
+        #         max_seqlen_kv=sample_length,
+        #         local_cp_size=None,
+        #         cp_group=None,
+        #     )
+        # else:
+        #     packed_seq_params = None
         # slice batch along sequence dimension for context parallelism
         batch = get_batch_on_this_cp_rank(batch)  # The implementation of this function is in MCore
         packed_seq_params = None

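The commented-out block above builds a PackedSeqParams whose cu_seqlens tensors are just [0, sample_length], i.e. it appears to treat the whole (padded) batch row as a single packed segment. For reference, cu_seqlens are cumulative sequence offsets with a leading zero; a small sketch of how they are derived from per-sample lengths (plain PyTorch, illustrative only):

import torch

def build_cu_seqlens(seq_lengths, device="cpu"):
    """Cumulative offsets with a leading 0: sample i occupies [cu[i], cu[i+1])."""
    lengths = torch.as_tensor(seq_lengths, dtype=torch.int32, device=device)
    return torch.nn.functional.pad(torch.cumsum(lengths, dim=0), (1, 0))

print(build_cu_seqlens([4096]))       # -> [0, 4096], as in the debug code above
print(build_cu_seqlens([512, 1024]))  # -> [0, 512, 1536]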