
Commit 1167117

debugging NaN issue when using FSDP+THD
Signed-off-by: tailaim <tailaim@nvidia.com>
1 parent be8d859 commit 1167117

File tree

14 files changed: +422 -20 lines changed


megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py

Lines changed: 3 additions & 0 deletions
@@ -3342,6 +3342,9 @@ def wait_bucket_ready(self, bucket_id, empty_ok=False):
         # Wait for asynchronous / overlapped NCCL operations to complete.
         param_gather_event, mark_bucket_ready_to_use = self.param_gather_event_map.pop(bucket_id)
         param_gather_event.wait()
+        # debugmtl: make the compute stream wait on the all-gather stream before the bucket is used.
+        if self.ag_stream is not None:
+            torch.cuda.current_stream().wait_stream(self.ag_stream)
         mark_bucket_ready_to_use()

     @torch.no_grad()
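For reference, the added wait_stream call is the standard PyTorch pattern for ordering work between a side stream and the current stream. The sketch below is illustrative only; `ag_stream` and the tensors are stand-ins, not the Megatron-FSDP buffer objects:

    # Minimal stream-ordering sketch (assumes a CUDA device is available).
    import torch

    ag_stream = torch.cuda.Stream()              # stand-in for the all-gather stream
    src = torch.randn(1024, device="cuda")
    gathered = torch.empty_like(src)

    with torch.cuda.stream(ag_stream):
        gathered.copy_(src)                      # stand-in for the overlapped all-gather

    # Make the default (compute) stream wait until work queued on ag_stream finishes,
    # so later kernels never read `gathered` too early.
    torch.cuda.current_stream().wait_stream(ag_stream)
    print(gathered.sum().item())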

megatron/core/models/gpt/gpt_model.py

Lines changed: 59 additions & 1 deletion
@@ -37,6 +37,65 @@
 from megatron.core.transformer.transformer_config import TransformerConfig
 from megatron.core.utils import WrappedTensor, deprecate_inference_params

+# #debugmtl
+# _ACT_GRAD_DUMP_COUNTERS = {}
+
+# def _sanitize_name(name: str) -> str:
+#     return str(name).replace('/', '_').replace('\\', '_').replace('.', '_').replace(' ', '_')
+
+# def _next_act_dump_index(rank: int, layer_name: str) -> int:
+#     key = (rank, layer_name)
+#     cnt = _ACT_GRAD_DUMP_COUNTERS.get(key, 0) + 1
+#     _ACT_GRAD_DUMP_COUNTERS[key] = cnt
+#     return cnt
+
+# def get_debug_hook(layer_name: str):
+#     """
+#     Tensor-level grad hook: save activation grad by (rank, layer_name, index).
+#     """
+#     import os
+#     def hook(grad: torch.Tensor):
+#         if grad is None:
+#             return
+
+#         rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
+
+#         # Change the base directory to wherever you want the dumps written.
+#         base_dir = "/home/tailaim/act_grad_dump"
+#         if not base_dir:
+#             return
+
+#         try:
+#             idx = _next_act_dump_index(rank, layer_name)
+#             layer_dir = os.path.join(
+#                 base_dir,
+#                 f"rank_{rank}",
+#                 _sanitize_name(layer_name),
+#             )
+#             os.makedirs(layer_dir, exist_ok=True)
+#             file_path = os.path.join(layer_dir, f"grad_{idx:06d}.pt")
+
+#             # Only write the first few grads to disk, to avoid too many files.
+#             if idx <= 16:
+#                 torch.save(grad.detach().cpu(), file_path)
+
+#             # Print a single log line only on the first write.
+#             if idx == 1:
+#                 try:
+#                     g_shape = tuple(grad.shape)
+#                     g_dtype = str(grad.dtype)
+#                 except Exception:
+#                     g_shape = "unknown"
+#                     g_dtype = "unknown"
+#                 print(
+#                     f"[Rank {rank}] Saved act grad: layer={layer_name}, "
+#                     f"idx={idx:06d}, shape={g_shape}, dtype={g_dtype}, path={file_path}"
+#                 )
+#         except Exception as e:
+#             print(f"[Rank {rank}] act grad dump failed for {layer_name}: {e}")
+
+#     return hook
+

 class GPTModel(LanguageModule):
     """GPT Transformer language model.
@@ -640,7 +699,6 @@ def _postprocess(
             hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output
         )

-        # Restore sequence parallel execution to the output layer if necessary.
         if sequence_parallel_override:
             assert (
                 in_inference_mode
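If the commented-out helper above were re-enabled, it would normally be attached with PyTorch's Tensor.register_hook. The snippet below is only a hypothetical usage sketch; the tensor and the "decoder_output" label are placeholders, and the commit itself never registers the hook:

    # Hypothetical usage of get_debug_hook (assumes the block above is uncommented).
    import torch

    x = torch.randn(4, 8, requires_grad=True)
    h = x * 2.0                                        # stand-in for a decoder activation
    h.register_hook(get_debug_hook("decoder_output"))  # placeholder layer label
    h.sum().backward()                                 # backward fires the hook and dumps grad_000001.pt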

megatron/core/parallel_state.py

Lines changed: 11 additions & 0 deletions
@@ -970,6 +970,17 @@ def initialize_model_parallel(
             if rank in ranks:
                 _HIERARCHICAL_CONTEXT_PARALLEL_GROUPS = hierarchical_groups

+    if hybrid_context_parallel:
+        # PyTorch initializes communicator groups lazily, so issue an NCCL call
+        # for each hybrid DP-CP group size to ensure the communicators are created.
+        group_sizes = [2**i for i in range(int(log2(data_parallel_size)))]
+        if group_sizes[-1] * 2 == data_parallel_size:
+            group_sizes.append(data_parallel_size)
+        for group_size in group_sizes:
+            group = get_hybrid_data_context_parallel_groups(group_size=group_size)
+            torch.distributed.barrier(group=group, device_ids=[torch.cuda.current_device()])
+        torch.cuda.synchronize()
+
     # Build the model-parallel groups.
     global _MODEL_PARALLEL_GROUP
     global _MODEL_PARALLEL_GLOBAL_RANKS
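As a quick worked example of the group-size enumeration in the block above (data_parallel_size = 8 is an assumed value, chosen only for illustration):

    from math import log2

    data_parallel_size = 8
    group_sizes = [2**i for i in range(int(log2(data_parallel_size)))]  # [1, 2, 4]
    if group_sizes[-1] * 2 == data_parallel_size:
        group_sizes.append(data_parallel_size)                          # [1, 2, 4, 8]
    print(group_sizes)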

megatron/core/pipeline_parallel/data_schedule.py

Lines changed: 51 additions & 7 deletions
@@ -10,6 +10,13 @@
 import torch

 from megatron.core import parallel_state
+
+# from megatron.core.pipeline_parallel.utils import (
+#     is_pp_first_stage,
+#     is_pp_last_stage,
+#     is_vp_first_stage,
+#     is_vp_last_stage,
+# )
 from megatron.core.process_groups_config import ProcessGroupCollection
 from megatron.core.rerun_state_machine import RerunDataIterator

@@ -293,17 +300,24 @@ def _broadcast(item):
         dp_cp_group = parallel_state.get_data_parallel_group(with_context_parallel=True)
         dp_group = parallel_state.get_data_parallel_group()
         tp_group = parallel_state.get_tensor_model_parallel_group()
+        pp_group = parallel_state.get_pipeline_model_parallel_group()
     else:
         dp_cp_group = pg_collection.dp_cp
         dp_group = pg_collection.dp
         tp_group = pg_collection.tp
+        pp_group = pg_collection.pp
     assert (
         dp_cp_group is not None and dp_group is not None and tp_group is not None
     ), "dp_cp_group, dp_group, tp_group must not be None when using hybrid context parallel"

     total_hdp_gpus = dp_cp_group.size()
     dev = torch.cuda.current_device()

+    # if is_pp_first_stage(pp_group) or is_pp_last_stage(pp_group) and tp_group.rank() == 0:
+    #     # do what data_iterator is doing
+
+    #     # first stage tp-0 broadcast num_micro_batches cu_seqlens to
+
     if data_iterator is None:
         # TP-0 reads from data_iterator, others receive via broadcast.
         sample_id_groups, batch = None, None
@@ -329,6 +343,16 @@ def _broadcast(item):

         groups, sample_id_groups = scheduler.get_groups_and_subsamples(global_id_seqlens, config)

+        # debugmtl: every global sample id must appear exactly once across the packed groups.
+        set_gbs = set()
+        for group in sample_id_groups:
+            for sub in group:
+                set_gbs.update(sub)
+        assert len(set_gbs) == len(
+            global_id_seqlens
+        ), f"set_gbs length: {len(set_gbs)} != global_id_seqlens length: {len(global_id_seqlens)}"
+
         batch = _unpack_batch(batch)
         samples_this_rank_with_id = _reroute_samples_to_hdp_ranks(
             batch,
@@ -384,9 +408,10 @@ def _pack_tensors(tensors):
         new_sample["labels"] = labels
         new_sample["loss_mask"] = loss_mask
         new_sample["position_ids"] = position_ids
-        new_sample["local_cp_size"] = torch.tensor(
-            partner_cp_size, dtype=torch.int32, device=dev
-        )
+        if scheduler_type is PackingScheduler.HYBRID_CP:
+            new_sample["local_cp_size"] = torch.tensor(
+                partner_cp_size, dtype=torch.int32, device=dev
+            )

         # create cu_seqlens_padded
         lengths_padding = np.fromiter(
@@ -415,7 +440,9 @@ def _pack_tensors(tensors):
         new_sample["cu_seqlens"] = cu_seqlens

         new_samples.append(new_sample)
-
+    # #debugmtl
+    # print(f"rank {parallel_state.get_data_parallel_rank(with_context_parallel=True)} "
+    #       f"new_samples length: {len(new_samples)}")
     new_data_iterator = RerunDataIterator(iter(new_samples))

     return (
@@ -460,15 +487,30 @@ def get_groups_and_subsamples(self, sample_id_seqlens, config):
         sum_seqlen = 0
         single_microbatch = []

+        # # debugmtl: pack exactly one sequence per microbatch instead.
+        # num_micro_batches = len(sample_id_seqlens) // self.dp_size
+        # for i in range(num_micro_batches):
+        #     for j in range(self.dp_size):
+        #         packed_id_groups.append([i + j * num_micro_batches])
+
         for i in range(len(sample_id_seqlens)):
             if sum_seqlen + sample_id_seqlens[i][1] <= self.max_seq_len_all_ranks:
                 single_microbatch.append(i)
                 sum_seqlen += sample_id_seqlens[i][1]
             else:
-                groups.append(single_microbatch)
                 packed_id_groups.append(single_microbatch)
                 single_microbatch = [i]
                 sum_seqlen = sample_id_seqlens[i][1]
+        # Flush the last, partially filled microbatch.
+        if len(single_microbatch) > 0:
+            packed_id_groups.append(single_microbatch)
+
+        # debugmtl: the packed groups must cover every sample exactly once.
+        gbs_sum = 0
+        for i in packed_id_groups:
+            gbs_sum += len(i)
+        assert gbs_sum == len(
+            sample_id_seqlens
+        ), f"gbs_sum: {gbs_sum} != sample_id_seqlens length: {len(sample_id_seqlens)}"

         # we want the number of packed sequences to be multiple of dp_size
         # so we move few samples from previous microbatch
@@ -482,7 +524,7 @@ def get_groups_and_subsamples(self, sample_id_seqlens, config):
             assert i > 0, "Not enough samples to move"
             if len(packed_id_groups[i]) > 1:
                 seq_id = packed_id_groups[i].pop()
-                packed_id_groups[i].append(seq_id)
+                packed_id_groups.append([seq_id])
                 num_to_move -= 1
             else:
                 i -= 1
@@ -493,7 +535,9 @@ def get_groups_and_subsamples(self, sample_id_seqlens, config):
             for j in range(self.cp_size * self.dp_size):
                 seq_id = int(i * self.dp_size + j / self.cp_size)
                 sample_id_groups[i].append(packed_id_groups[seq_id])
-
+        # debugmtl
+        # print(f"rank {parallel_state.get_data_parallel_rank(with_context_parallel=True)} "
+        #       f"sample_id_groups: {len(sample_id_groups)}")
         return groups, sample_id_groups

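For reference, the packing loop in get_groups_and_subsamples above is a greedy first-fit pass with the trailing flush added by this commit. The standalone sketch below just re-traces that logic on made-up (id, seqlen) pairs and an assumed max_seq_len_all_ranks of 5000; it is illustrative, not the scheduler itself:

    # Greedy packing + trailing flush, on illustrative values only.
    sample_id_seqlens = [(0, 3000), (1, 2000), (2, 4000), (3, 1000)]
    max_seq_len_all_ranks = 5000

    packed_id_groups, single_microbatch, sum_seqlen = [], [], 0
    for i in range(len(sample_id_seqlens)):
        if sum_seqlen + sample_id_seqlens[i][1] <= max_seq_len_all_ranks:
            single_microbatch.append(i)
            sum_seqlen += sample_id_seqlens[i][1]
        else:
            packed_id_groups.append(single_microbatch)
            single_microbatch = [i]
            sum_seqlen = sample_id_seqlens[i][1]
    if single_microbatch:
        packed_id_groups.append(single_microbatch)  # the flush this commit adds

    # Same invariant the new debugmtl assert checks: full, non-overlapping coverage.
    assert sum(len(g) for g in packed_id_groups) == len(sample_id_seqlens)
    print(packed_id_groups)  # [[0, 1], [2, 3]]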

megatron/core/pipeline_parallel/schedules.py

Lines changed: 2 additions & 1 deletion
@@ -147,7 +147,8 @@ def deallocate_output_tensor(out, deallocate_pipeline_outputs=False):
     if (out is None) or (not deallocate_pipeline_outputs):
         return
     assert isinstance(out, torch.Tensor), "expected Tensor, found %s." % type(out).__name__
-    assert out._base is None, "counter-productive to free a view of another tensor."
+    # debugmtl
+    # assert out._base is None, "counter-productive to free a view of another tensor."
     out.data = torch.empty((1,), device=out.device, dtype=out.dtype)
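For context on the assertion that this change comments out: rebinding .data on a tensor that is a view of another tensor does not release the underlying storage, because the base tensor still holds it, which is why freeing a view is called counter-productive. A small standalone illustration (assumed example, not from the commit):

    import torch

    base = torch.arange(8, dtype=torch.float32)
    view = base[:4]
    view.data = torch.empty((1,))   # rebinds only the view; does not free base's buffer
    print(base)                     # base still holds all 8 elements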

megatron/core/transformer/moe/moe_layer.py

Lines changed: 12 additions & 0 deletions
@@ -185,6 +185,10 @@ def route(self, hidden_states: torch.Tensor):
         producing routing probabilities and a mapping.
         """
         probs, routing_map = self.router(hidden_states)
+        # #debugmtl
+        # true_per_row = routing_map.sum(dim=1)  # tensor of shape [n]
+        # assert torch.all(true_per_row == 8), \
+        #     f"routing_map row true counts not all 8, got: {true_per_row}"
         return probs, routing_map

     @maybe_skip_or_early_return_by_cudagraph("preprocess")
@@ -290,6 +294,14 @@ def forward(self, hidden_states: torch.Tensor):
                 "During training, performance may degrade if MoE and tensor parallelism"
                 "are enabled without also enabling sequence parallelism."
             )
+        # # debugmtl
+        # if torch.isnan(hidden_states).any():
+        #     bad_mask = torch.isnan(hidden_states)
+        #     bad_idx = bad_mask.nonzero(as_tuple=False)[:10]
+        #     raise RuntimeError(
+        #         f"[MoE] hidden_states contains NaN, first indices: {bad_idx.tolist()}, "
+        #         f"shape={tuple(hidden_states.shape)}"
+        #     )

         # MoE forward: route -> dispatch -> compute -> combine
         def custom_forward(hidden_states):

megatron/core/transformer/moe/moe_utils.py

Lines changed: 6 additions & 0 deletions
@@ -618,6 +618,12 @@ def compute_topk(scores, topk, num_groups=None, group_topk=None):
             (rows, top_indices), torch.ones_like(probs, dtype=routing_map.dtype), accumulate=False
         )
         routing_map = routing_map.bool()
+        # debugmtl: with top-k routing, every token row must select exactly 8 experts.
+        true_per_row = routing_map.sum(dim=1)  # tensor of shape [n]
+        assert torch.all(
+            true_per_row == 8
+        ), f"in topk routing_with_score_function row true counts not all 8, " \
+           f"got: {true_per_row}, num_tokens: {num_tokens}, logits shape: {logits.shape}"
     else:
         # TODO Try using element-wise operations instead of scatter?
         routing_probs = torch.zeros_like(logits).scatter(1, top_indices, probs)

megatron/core/transformer/moe/router.py

Lines changed: 13 additions & 0 deletions
@@ -512,6 +512,19 @@ def routing(self, logits: torch.Tensor):
             fused=self.config.moe_router_fusion,
         )

+        # debugmtl
+        # true_per_row = routing_map.sum(dim=1)  # tensor of shape [n]
+        # torch.set_printoptions(threshold=torch.inf)
+        # assert torch.all(true_per_row == 8), \
+        #     f"in class topkrouter routing_map row true counts not all 8, "
+        #     f"got: {true_per_row}, logits is: {logits}, topk is: {self.topk}, "
+        #     f"use_pre_softmax is: {self.config.moe_router_pre_softmax}, "
+        #     f"num_groups is: {self.config.moe_router_num_groups}, "
+        #     f"group_topk is: {self.config.moe_router_group_topk}, "
+        #     f"scaling_factor is: {self.config.moe_router_topk_scaling_factor}, "
+        #     f"score_function is: {self.score_function}, "
+        #     f"expert_bias is: {self.expert_bias}, "
+        #     f"fused is: {self.config.moe_router_fusion}"
+
         # Apply token dropping to probs and routing_map.
         if self.config.moe_expert_capacity_factor is not None:
             probs, routing_map = apply_router_token_dropping(

megatron/core/transformer/moe/token_dispatcher.py

Lines changed: 46 additions & 6 deletions
@@ -664,12 +664,52 @@ def token_dispatch(self, permutated_local_input_tokens, permuted_probs):
         self.tokens_per_expert = self._maybe_dtoh_and_synchronize(
             "before_ep_alltoall", self.tokens_per_expert
         )
-        global_input_tokens = all_to_all(
-            self.ep_group, permutated_local_input_tokens, self.output_splits, self.input_splits
-        )
-        global_probs = all_to_all(
-            self.ep_group, permuted_probs, self.output_splits, self.input_splits
-        )
+        # debugmtl
+        # global_input_tokens = all_to_all(
+        #     self.ep_group, permutated_local_input_tokens,
+        #     self.output_splits, self.input_splits
+        # )
+        # global_probs = all_to_all(
+        #     self.ep_group, permuted_probs, self.output_splits,
+        #     self.input_splits
+        # )
+        try:
+            global_input_tokens = all_to_all(
+                self.ep_group, permutated_local_input_tokens, self.output_splits, self.input_splits
+            )
+            global_probs = all_to_all(
+                self.ep_group, permuted_probs, self.output_splits, self.input_splits
+            )
+        except RuntimeError as e:
+            # Get the rank within the EP group (guard against the group not being initialized yet).
+            try:
+                rank = torch.distributed.get_rank(self.ep_group)
+            except Exception:
+                rank = -1
+
+            print(f"[MoE all_to_all error] rank={rank}, err={e}")
+            print(
+                f"[MoE all_to_all debug] "
+                f"tokens_shape={getattr(permutated_local_input_tokens, 'shape', None)}, "
+                f"probs_shape={getattr(permuted_probs, 'shape', None)}"
+            )
+            print(
+                f"[MoE all_to_all debug] "
+                f"input_splits={self.input_splits}, "
+                f"sum={sum(self.input_splits) if self.input_splits is not None else None}, "
+                f"output_splits={self.output_splits}, "
+                f"sum={sum(self.output_splits) if self.output_splits is not None else None}"
+            )
+            print(
+                f"[MoE all_to_all debug] "
+                f"tokens_per_expert={self.tokens_per_expert}, "
+                f"sum={self.tokens_per_expert.sum() if hasattr(self.tokens_per_expert, 'sum') else None}"
+            )
+            torch.set_printoptions(profile="full")
+            print(f"hidden_states shape: {self.hidden_shape}")
+            print(f"routing_map: {self.routing_map}")
+            raise

         return global_input_tokens, global_probs

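A cheaper complement to the post-mortem logging above would be to validate the split metadata before the collective runs. The helper below is only a hypothetical sketch (check_a2a_splits is not part of the commit; it just mirrors the dispatcher's attribute names):

    # Hypothetical pre-check (not in the commit): fail fast with a clear message
    # if the all-to-all split metadata does not match the local token tensor.
    def check_a2a_splits(tokens, input_splits, output_splits, ep_group):
        if input_splits is not None:
            assert sum(input_splits) == tokens.shape[0], (
                f"sum(input_splits)={sum(input_splits)} != local token count {tokens.shape[0]}"
            )
        if input_splits is not None and output_splits is not None:
            assert len(input_splits) == len(output_splits) == ep_group.size(), (
                "input/output splits must have one entry per EP rank"
            )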
