@@ -107,6 +107,11 @@ def set_stop_word_finished_ids_callback(self, callback):
107107
108108 def _init_dynamic_sampling_tensors (self ):
109109 """Initialize tensors needed for dynamic sampling."""
110+ import sys
111+ _rank = torch .distributed .get_rank () if torch .distributed .is_initialized () else 0
112+ _dbg_f = open (f"/tmp/tde_debug_rank{ _rank } .log" , "a" )
113+ _dbg_f .write ("[INIT] _init_dynamic_sampling_tensors start\n " ); _dbg_f .flush ()
114+ print (f"[rank{ _rank } ] [INIT] _init_dynamic_sampling_tensors start" , flush = True , file = sys .stderr )
110115 context = self .inference_wrapped_model .inference_context
111116 max_requests = context .max_requests
112117 if context .config .materialize_only_last_token_logits :
@@ -143,6 +148,8 @@ def _init_dynamic_sampling_tensors(self):
143148 self ._torch_sampling_buckets : List [Tuple ] = []
144149
145150 self ._init_mtp_sampling_tensor ()
151+ _dbg_f .write ("[INIT] _init_dynamic_sampling_tensors done\n " ); _dbg_f .flush ()
152+ print (f"[rank{ _rank } ] [INIT] _init_dynamic_sampling_tensors done" , flush = True , file = sys .stderr )
146153
147154 def _init_mtp_sampling_tensor (self ):
148155 """Initialize the MTP sampling tensor after num_speculative_tokens is set."""
@@ -626,11 +633,20 @@ def _dynamic_step_forward_logits(self, input_ids: Tensor, position_ids: Tensor):
626633 else context .padded_active_token_count
627634 )
628635
636+ import os , sys
637+ _rank = torch .distributed .get_rank () if torch .distributed .is_initialized () else 0
638+ _dbg_f = open (f"/tmp/tde_debug_rank{ _rank } .log" , "a" )
def _dbg(msg):
    """Emit one "[FWD]" trace message to the debug log file and to stderr.

    Args:
        msg: Text to record; it is interpolated into both output lines.
    """
    # Append the entry to the log handle opened by the enclosing function,
    # then flush right away so it does not linger in the write buffer.
    _dbg_f.write(f"[FWD] { msg } \n ")
    _dbg_f.flush()
    # Mirror the same message on stderr, prefixed with the rank value.
    print(f"[rank{ _rank } ] [FWD] { msg } ", file=sys.stderr, flush=True)
642+
643+ _dbg (f"run_one_forward_step start (logits_seq_len={ logits_seq_len } )" )
629644 with torch .inference_mode ():
630645 logits = self .inference_wrapped_model .run_one_forward_step (
631646 {"tokens" : input_ids , "position_ids" : position_ids , "attention_mask" : None }
632647 )
633648 # logits shape: [1, seq_len, vocab_size]
649+ _dbg (f"run_one_forward_step done (logits={ 'None' if logits is None else tuple (logits .shape )} )" )
634650
635651 # Note: When speculative decoding is active (num_speculative_tokens > 0),
636652 # the model skips MTP computation during the forward pass. MTP logits
@@ -653,12 +669,14 @@ def _dynamic_step_forward_logits(self, input_ids: Tensor, position_ids: Tensor):
653669 if is_pipeline_last_stage (self .pp_group ):
654670 assert logits is not None and torch .Size (logits_shape ) == logits .shape
655671
672+ _dbg ("broadcast_from_last_pipeline_stage start" )
656673 logits = broadcast_from_last_pipeline_stage (
657674 logits_shape ,
658675 dtype = self .model_config .params_dtype ,
659676 tensor = logits ,
660677 pp_group = self .pp_group ,
661678 )
679+ _dbg ("broadcast_from_last_pipeline_stage done" )
662680
663681 # Copy logits to contiguous buffer.
664682 if self ._enable_cuda_graph :
@@ -1754,11 +1772,20 @@ async def async_generate_output_tokens_dynamic_batch(
17541772 context = self .inference_wrapped_model .inference_context
17551773 active_request_count = context .total_request_count - context .paused_request_count
17561774
1775+ import os , sys
1776+ _rank = torch .distributed .get_rank () if torch .distributed .is_initialized () else 0
1777+ _dbg_f = open (f"/tmp/tde_debug_rank{ _rank } .log" , "a" )
def _dbg(msg):
    """Emit one "[STEP]" trace message to the debug log file and to stderr.

    Args:
        msg: Text to record; it is interpolated into both output lines.
    """
    # Append the entry to the log handle opened by the enclosing function,
    # then flush right away so it does not linger in the write buffer.
    _dbg_f.write(f"[STEP] { msg } \n ")
    _dbg_f.flush()
    # Mirror the same message on stderr, prefixed with the rank value.
    print(f"[rank{ _rank } ] [STEP] { msg } ", file=sys.stderr, flush=True)
1781+
17571782 # No tokens and no active requests?
17581783 if context .active_token_count == 0 and active_request_count == 0 :
17591784 return None
17601785
1786+ _dbg (f"context_init start (tokens={ context .active_token_count } , reqs={ active_request_count } )" )
17611787 input_ids , position_ids = self ._dynamic_step_context_init ()
1788+ _dbg (f"context_init done (input_ids.shape={ tuple (input_ids .shape )} )" )
17621789
17631790 cuda_graph_request_count = (
17641791 context .padded_active_request_count if context .is_decode_only () else None
@@ -1771,7 +1798,9 @@ async def async_generate_output_tokens_dynamic_batch(
17711798
17721799 # Forward pass produces only base logits. When speculative decoding is
17731800 # active, MTP logits are computed serially after verification.
1801+ _dbg ("forward_logits start" )
17741802 self ._dynamic_step_forward_logits (input_ids , position_ids )
1803+ _dbg ("forward_logits done" )
17751804
17761805 # Commit Mamba intermediate states before update_requests, which
17771806 # may swap request indices. The Python lists tracking EOS block IDs
@@ -1790,10 +1819,13 @@ async def async_generate_output_tokens_dynamic_batch(
17901819 # asynchronous.
17911820 # Todo [Siddharth]: Can we condition the sleep on a cuda event?
17921821 # NOTE [TDE]: This will be moved once CPU and GPU methods are separated.
1822+ _dbg ("yield start" )
17931823 await asyncio .sleep (0 )
1824+ _dbg ("yield done" )
17941825 return_log_probs , return_top_n_logprobs = self ._dynamic_step_log_probs_bookkeeping ()
17951826
17961827 self ._dynamic_step_sample_bookkeeping ()
1828+ _dbg ("sample_logits start" )
17971829
17981830 if self .num_speculative_tokens > 0 :
17991831 # Phase 1: Verify speculative tokens using base logits only.
@@ -1810,6 +1842,7 @@ async def async_generate_output_tokens_dynamic_batch(
18101842 self ._compute_serial_mtp_and_sample ()
18111843 else :
18121844 self ._dynamic_step_sample_logits ()
1845+ _dbg ("sample_logits done" )
18131846
18141847 log_probs = None
18151848 top_n_logprobs = None
@@ -1825,10 +1858,12 @@ async def async_generate_output_tokens_dynamic_batch(
18251858 if return_top_n_logprobs :
18261859 top_n_logprobs = self ._dynamic_step_calculate_top_n_logprobs (log_probs_tensor )
18271860
1861+ _dbg ("bookkeeping start" )
18281862 if skip_bookkeeping :
18291863 request_bookkeeping = {}
18301864 else :
18311865 request_bookkeeping = self ._dynamic_step_context_bookkeeping ()
1866+ _dbg ("bookkeeping done" )
18321867
18331868 ret = {
18341869 # Clone needed: _sampled_tokens_cuda is a reused buffer overwritten each step.
0 commit comments