Skip to content

Commit b34746e — "debug merge" (1 parent: 37c081a)

File tree: 3 files changed, +89 −0 lines

megatron/core/inference/contexts/dynamic_context.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1659,11 +1659,22 @@ def initialize_attention_state(
16591659
self.padded_active_request_count = self.padded_batch_dimensions.req_count
16601660
self.padding_slice = slice(self.active_token_count, self.padded_active_token_count)
16611661

1662+
import os, sys
1663+
_rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
1664+
_dbg_f = open(f"/tmp/tde_debug_rank{_rank}.log", "a")
1665+
def _dbg(msg):
1666+
_dbg_f.write(f"[ATTN] {msg}\n"); _dbg_f.flush()
1667+
print(f"[rank{_rank}] [ATTN] {msg}", flush=True, file=sys.stderr)
1668+
1669+
_dbg(f"build_active_slices start (padded_req={self.padded_active_request_count}, padded_tok={self.padded_active_token_count}, paused={self.paused_request_count}, total={self.total_request_count})")
16621670
self.build_active_slices(self.padded_active_request_count)
1671+
_dbg("build_active_slices done")
16631672
self.pad_active_slices()
1673+
_dbg("pad_active_slices done")
16641674

16651675
batch_size = self.total_request_count - self.paused_request_count
16661676
assert self.active_attn_metadata is not None
1677+
_dbg(f"mha_metadata.update start (batch_size={batch_size})")
16671678
self.active_attn_metadata["mha_metadata"].update(
16681679
request_query_lengths=self.active_request_query_lengths[:batch_size],
16691680
request_kv_length_offsets=self.active_request_kv_length_offsets[:batch_size],
@@ -1672,6 +1683,7 @@ def initialize_attention_state(
16721683
padded_batch_dimensions=self.padded_batch_dimensions,
16731684
num_speculative_tokens=self.num_speculative_tokens,
16741685
)
1686+
_dbg("mha_metadata.update done")
16751687

16761688
if self.is_hybrid_model:
16771689
active_slice = slice(self.paused_request_count, self.total_request_count)

megatron/core/inference/engines/dynamic_engine.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,16 @@ def create_cuda_graphs(self, reset_context: bool = True):
334334
reset_context (bool): Whether to reset the context after building cuda graphs.
335335
"""
336336

337+
import sys
338+
_rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
339+
_dbg_f = open(f"/tmp/tde_debug_rank{_rank}.log", "a")
340+
def _dbg(msg):
341+
_dbg_f.write(f"[CG] {msg}\n"); _dbg_f.flush()
342+
print(f"[rank{_rank}] [CG] {msg}", flush=True, file=sys.stderr)
343+
344+
_dbg(f"create_cuda_graphs start (impl={self.cuda_graph_impl})")
337345
if self.cuda_graph_impl != "local":
346+
_dbg("skipping (not local)")
338347
return
339348

340349
if (
@@ -393,12 +402,15 @@ def create_cuda_graphs(self, reset_context: bool = True):
393402
)
394403

395404
tbar = enumerate(context.cuda_graph_batch_dimensions_list)
405+
_dbg(f"warmup loop start ({len(context.cuda_graph_batch_dimensions_list)} graphs)")
396406
if HAVE_TQDM:
397407
tbar = tqdm(tbar, total=len(context.cuda_graph_batch_dimensions_list))
398408
for tbar_idx, cuda_graph_batch_dimension in tbar:
409+
_dbg(f"warmup iter {tbar_idx}: context_init start ({cuda_graph_batch_dimension})")
399410
input_ids, position_ids = self.controller._dynamic_step_context_init(
400411
construct_graph_dimensions=cuda_graph_batch_dimension
401412
)
413+
_dbg(f"warmup iter {tbar_idx}: context_init done")
402414
# Progress.
403415
tbar_str = f"cuda graph warmup - {cuda_graph_batch_dimension}"
404416
if HAVE_TQDM:
@@ -1630,12 +1642,23 @@ async def async_forward(self) -> Tuple[Dict, Dict, float]:
16301642
step_time (float): How long this step took.
16311643
"""
16321644

1645+
import sys
1646+
_rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
1647+
_dbg_f = open(f"/tmp/tde_debug_rank{_rank}.log", "a")
1648+
def _dbg(msg):
1649+
_dbg_f.write(f"[ENGINE] {msg}\n"); _dbg_f.flush()
1650+
print(f"[rank{_rank}] [ENGINE] {msg}", flush=True, file=sys.stderr)
1651+
1652+
_dbg("async_forward enter")
1653+
16331654
# If suspended, no stepping.
16341655
if self.state in (EngineState.SUSPENDED, EngineState.SUSPENDING):
16351656
raise EngineSuspendedError(self.context.step_count)
16361657

16371658
# schedule requests
1659+
_dbg("schedule_waiting_requests start")
16381660
self.schedule_waiting_requests()
1661+
_dbg(f"schedule_waiting_requests done (total={self.context.total_request_count}, paused={self.context.paused_request_count}, tokens={self.context.active_token_count})")
16391662

16401663
# Saving pre-step state, for printing output below.
16411664
is_decode_only = self.context.is_decode_only()
@@ -1654,7 +1677,9 @@ async def async_forward(self) -> Tuple[Dict, Dict, float]:
16541677
self.is_decode_only = is_decode_only
16551678

16561679
self.step_start_event.record()
1680+
_dbg("async_generate_output_tokens_dynamic_batch start")
16571681
result = await self.controller.async_generate_output_tokens_dynamic_batch()
1682+
_dbg("async_generate_output_tokens_dynamic_batch done")
16581683
self.step_end_event.record()
16591684
self.step_end_event.synchronize()
16601685
step_time = self.step_start_event.elapsed_time(self.step_end_event) / 1e3
@@ -2283,17 +2308,30 @@ async def run_engine_with_coordinator(
22832308
self._loop = get_asyncio_loop(loop)
22842309
self.use_coordinator = True
22852310

2311+
import sys
2312+
_rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
2313+
_dbg_f = open(f"/tmp/tde_debug_rank{_rank}.log", "a")
2314+
_iter = 0
2315+
def _dbg(msg):
2316+
_dbg_f.write(f"[COORD iter={_iter}] {msg}\n"); _dbg_f.flush()
2317+
print(f"[rank{_rank}] [COORD iter={_iter}] {msg}", flush=True, file=sys.stderr)
2318+
22862319
try:
22872320
while True:
2321+
_iter += 1
2322+
_dbg(f"loop top (state={self.state})")
22882323
self.schedule_requests()
2324+
_dbg(f"schedule done (active={self.context.get_active_request_count()}, waiting={len(self.waiting_request_ids)})")
22892325

22902326
if self.state in (EngineState.RUNNING, EngineState.PAUSING):
22912327
local_pending = self.context.get_active_request_count() + len(
22922328
self.waiting_request_ids
22932329
)
2330+
_dbg(f"ep_consensus start (local_pending={local_pending})")
22942331
global_work, all_pausing = await self._ep_establish_consensus(
22952332
local_pending, signal_consensus=(self.state == EngineState.PAUSING)
22962333
)
2334+
_dbg(f"ep_consensus done (global_work={global_work}, all_pausing={all_pausing})")
22972335

22982336
if all_pausing:
22992337
# All EP peers are PAUSING: pause immediately.
@@ -2303,15 +2341,19 @@ async def run_engine_with_coordinator(
23032341
elif global_work > 0:
23042342
# At least one EP peer has work: all must participate.
23052343
if local_pending > 0:
2344+
_dbg("async_step start")
23062345
await self.async_step()
2346+
_dbg("async_step done")
23072347
else:
23082348
# Dummy forward to participate in the EP collective.
2349+
_dbg("dummy_forward start")
23092350
self.step_start_event.record()
23102351
self.controller.dummy_forward()
23112352
self.step_end_event.record()
23122353
self.step_end_event.synchronize()
23132354
self.context.step_count += 1
23142355
self.context.prefix_cache_lru_clock += 1
2356+
_dbg("dummy_forward done")
23152357
else:
23162358
# No work, but not all pausing: idle.
23172359
await asyncio.sleep(0.02)

megatron/core/inference/text_generation_controllers/text_generation_controller.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,11 @@ def set_stop_word_finished_ids_callback(self, callback):
107107

108108
def _init_dynamic_sampling_tensors(self):
109109
"""Initialize tensors needed for dynamic sampling."""
110+
import sys
111+
_rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
112+
_dbg_f = open(f"/tmp/tde_debug_rank{_rank}.log", "a")
113+
_dbg_f.write("[INIT] _init_dynamic_sampling_tensors start\n"); _dbg_f.flush()
114+
print(f"[rank{_rank}] [INIT] _init_dynamic_sampling_tensors start", flush=True, file=sys.stderr)
110115
context = self.inference_wrapped_model.inference_context
111116
max_requests = context.max_requests
112117
if context.config.materialize_only_last_token_logits:
@@ -143,6 +148,8 @@ def _init_dynamic_sampling_tensors(self):
143148
self._torch_sampling_buckets: List[Tuple] = []
144149

145150
self._init_mtp_sampling_tensor()
151+
_dbg_f.write("[INIT] _init_dynamic_sampling_tensors done\n"); _dbg_f.flush()
152+
print(f"[rank{_rank}] [INIT] _init_dynamic_sampling_tensors done", flush=True, file=sys.stderr)
146153

147154
def _init_mtp_sampling_tensor(self):
148155
"""Initialize the MTP sampling tensor after num_speculative_tokens is set."""
@@ -626,11 +633,20 @@ def _dynamic_step_forward_logits(self, input_ids: Tensor, position_ids: Tensor):
626633
else context.padded_active_token_count
627634
)
628635

636+
import os, sys
637+
_rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
638+
_dbg_f = open(f"/tmp/tde_debug_rank{_rank}.log", "a")
639+
def _dbg(msg):
640+
_dbg_f.write(f"[FWD] {msg}\n"); _dbg_f.flush()
641+
print(f"[rank{_rank}] [FWD] {msg}", flush=True, file=sys.stderr)
642+
643+
_dbg(f"run_one_forward_step start (logits_seq_len={logits_seq_len})")
629644
with torch.inference_mode():
630645
logits = self.inference_wrapped_model.run_one_forward_step(
631646
{"tokens": input_ids, "position_ids": position_ids, "attention_mask": None}
632647
)
633648
# logits shape: [1, seq_len, vocab_size]
649+
_dbg(f"run_one_forward_step done (logits={'None' if logits is None else tuple(logits.shape)})")
634650

635651
# Note: When speculative decoding is active (num_speculative_tokens > 0),
636652
# the model skips MTP computation during the forward pass. MTP logits
@@ -653,12 +669,14 @@ def _dynamic_step_forward_logits(self, input_ids: Tensor, position_ids: Tensor):
653669
if is_pipeline_last_stage(self.pp_group):
654670
assert logits is not None and torch.Size(logits_shape) == logits.shape
655671

672+
_dbg("broadcast_from_last_pipeline_stage start")
656673
logits = broadcast_from_last_pipeline_stage(
657674
logits_shape,
658675
dtype=self.model_config.params_dtype,
659676
tensor=logits,
660677
pp_group=self.pp_group,
661678
)
679+
_dbg("broadcast_from_last_pipeline_stage done")
662680

663681
# Copy logits to contiguous buffer.
664682
if self._enable_cuda_graph:
@@ -1754,11 +1772,20 @@ async def async_generate_output_tokens_dynamic_batch(
17541772
context = self.inference_wrapped_model.inference_context
17551773
active_request_count = context.total_request_count - context.paused_request_count
17561774

1775+
import os, sys
1776+
_rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
1777+
_dbg_f = open(f"/tmp/tde_debug_rank{_rank}.log", "a")
1778+
def _dbg(msg):
1779+
_dbg_f.write(f"[STEP] {msg}\n"); _dbg_f.flush()
1780+
print(f"[rank{_rank}] [STEP] {msg}", flush=True, file=sys.stderr)
1781+
17571782
# No tokens and no active requests?
17581783
if context.active_token_count == 0 and active_request_count == 0:
17591784
return None
17601785

1786+
_dbg(f"context_init start (tokens={context.active_token_count}, reqs={active_request_count})")
17611787
input_ids, position_ids = self._dynamic_step_context_init()
1788+
_dbg(f"context_init done (input_ids.shape={tuple(input_ids.shape)})")
17621789

17631790
cuda_graph_request_count = (
17641791
context.padded_active_request_count if context.is_decode_only() else None
@@ -1771,7 +1798,9 @@ async def async_generate_output_tokens_dynamic_batch(
17711798

17721799
# Forward pass produces only base logits. When speculative decoding is
17731800
# active, MTP logits are computed serially after verification.
1801+
_dbg("forward_logits start")
17741802
self._dynamic_step_forward_logits(input_ids, position_ids)
1803+
_dbg("forward_logits done")
17751804

17761805
# Commit Mamba intermediate states before update_requests, which
17771806
# may swap request indices. The Python lists tracking EOS block IDs
@@ -1790,10 +1819,13 @@ async def async_generate_output_tokens_dynamic_batch(
17901819
# asynchronous.
17911820
# Todo [Siddharth]: Can we condition the sleep on a cuda event?
17921821
# NOTE [TDE]: This will be moved once CPU and GPU methods are separated.
1822+
_dbg("yield start")
17931823
await asyncio.sleep(0)
1824+
_dbg("yield done")
17941825
return_log_probs, return_top_n_logprobs = self._dynamic_step_log_probs_bookkeeping()
17951826

17961827
self._dynamic_step_sample_bookkeeping()
1828+
_dbg("sample_logits start")
17971829

17981830
if self.num_speculative_tokens > 0:
17991831
# Phase 1: Verify speculative tokens using base logits only.
@@ -1810,6 +1842,7 @@ async def async_generate_output_tokens_dynamic_batch(
18101842
self._compute_serial_mtp_and_sample()
18111843
else:
18121844
self._dynamic_step_sample_logits()
1845+
_dbg("sample_logits done")
18131846

18141847
log_probs = None
18151848
top_n_logprobs = None
@@ -1825,10 +1858,12 @@ async def async_generate_output_tokens_dynamic_batch(
18251858
if return_top_n_logprobs:
18261859
top_n_logprobs = self._dynamic_step_calculate_top_n_logprobs(log_probs_tensor)
18271860

1861+
_dbg("bookkeeping start")
18281862
if skip_bookkeeping:
18291863
request_bookkeeping = {}
18301864
else:
18311865
request_bookkeeping = self._dynamic_step_context_bookkeeping()
1866+
_dbg("bookkeeping done")
18321867

18331868
ret = {
18341869
# Clone needed: _sampled_tokens_cuda is a reused buffer overwritten each step.

Comments (0)