@@ -109,6 +109,10 @@ def _init_dynamic_sampling_tensors(self):
109109 """Initialize tensors needed for dynamic sampling."""
110110 context = self .inference_wrapped_model .inference_context
111111 max_requests = context .max_requests
112+ if context .materialize_only_last_token_logits :
113+ max_logits = max_requests
114+ else :
115+ max_logits = context .max_tokens
112116
113117 # Callback to get request IDs that should be marked as finished due to stop words
114118 self ._get_stop_word_finished_ids_callback = None
@@ -117,6 +121,15 @@ def _init_dynamic_sampling_tensors(self):
117121 logits_dtype = self .inference_wrapped_model .config .params_dtype
118122
119123 self ._sampling_backend = "torch"
124+ self ._enable_cuda_graph = False
125+
126+ # Initialize bookkeeping tensors.
127+ if self ._enable_cuda_graph :
128+ self ._all_logits_cuda = torch .empty (
129+ (1 , max_logits , self .vocab_size ), dtype = logits_dtype , device = device
130+ )
131+ else :
132+ self ._all_logits_cuda = None
120133 self ._sampled_tokens_cuda = torch .empty (max_requests , dtype = torch .int64 , device = device )
121134 # Speculative tokens tensor will be allocated later when num_speculative_tokens is set by the engine
122135 self ._accepted_tokens_per_request = None
@@ -596,7 +609,7 @@ def _dynamic_step_context_init(
596609 else :
597610 return context .current_input_and_position_ids ()
598611
599- def _dynamic_step_forward_logits (self , input_ids : Tensor , position_ids : Tensor ) -> Tensor :
612+ def _dynamic_step_forward_logits (self , input_ids : Tensor , position_ids : Tensor ):
600613 """Forward step the model to get logits for dynamic batching.
601614
602615 This also handles logits-broadcasting for pipeline parallelism.
@@ -607,6 +620,11 @@ def _dynamic_step_forward_logits(self, input_ids: Tensor, position_ids: Tensor)
607620 """
608621 context = self .inference_wrapped_model .inference_context
609622 active_request_count = context .total_request_count - context .paused_request_count
623+ logits_seq_len = (
624+ active_request_count
625+ if context .materialize_only_last_token_logits
626+ else context .padded_active_token_count
627+ )
610628
611629 with torch .inference_mode ():
612630 logits = self .inference_wrapped_model .run_one_forward_step (
@@ -619,6 +637,12 @@ def _dynamic_step_forward_logits(self, input_ids: Tensor, position_ids: Tensor)
619637 # will be computed serially after verification to ensure they are
620638 # conditioned on verified tokens only.
621639
640+ assert logits_seq_len == (
641+ active_request_count
642+ if context .materialize_only_last_token_logits
643+ else input_ids .shape [1 ]
644+ )
645+
622646 if self .model_is_pipeline_parallel :
623647 if context .config .materialize_only_last_token_logits :
624648 logits_seq_len = active_request_count
@@ -636,7 +660,11 @@ def _dynamic_step_forward_logits(self, input_ids: Tensor, position_ids: Tensor)
636660 pp_group = self .pp_group ,
637661 )
638662
639- return logits
663+ # Copy logits to contiguous buffer.
664+ if self ._enable_cuda_graph :
665+ self ._all_logits_cuda [:, :logits_seq_len , :].copy_ (logits )
666+ else :
667+ self ._all_logits_cuda = logits
640668
641669 def _dynamic_step_sample_bookkeeping (self ):
642670 """Perform bookkeeping necessary to sample logits for dynamic batching."""
@@ -1053,7 +1081,7 @@ def _verify_speculative_tokens(
10531081
10541082 return last_one_indices , accepted_tokens_mask , input_tokens_required
10551083
1056- def _dynamic_step_sample_logits_and_verify_tokens (self , logits : Tensor , input_ids : Tensor ):
1084+ def _dynamic_step_sample_logits_and_verify_tokens (self , input_ids : Tensor ):
10571085 """
10581086 Sample tokens from logits for dynamic batching with speculative tokens and verify the tokens.
10591087 """
@@ -1069,6 +1097,7 @@ def _dynamic_step_sample_logits_and_verify_tokens(self, logits: Tensor, input_id
10691097 num_decode_requests = active_request_count - num_prefill_requests
10701098
10711099 # Get the logit indices for tokens that need sampling.
1100+ logits = self ._all_logits_cuda
10721101 required_logit_indices = self ._get_required_logit_indices (
10731102 request_in_prefill_status_tensor ,
10741103 request_query_lengths ,
@@ -1132,24 +1161,22 @@ def _dynamic_step_sample_logits_and_verify_tokens(self, logits: Tensor, input_id
11321161 dim = 1
11331162 )
11341163
1135- def _dynamic_step_sample_logits (self , logits : Tensor ):
1136- """Sample tokens from logits for dynamic batching.
1137-
1138- Args:
1139- logits (Tensor): The logits from the forward pass.
1140- """
1164+ def _dynamic_step_sample_logits (self ):
1165+ """Sample tokens from logits for dynamic batching."""
11411166 # TODO(ksanthanam): Evaluate whether it makes more sense to sample on 1 rank
11421167 # and then broadcast the sampled tokens rather than broadcasting the raw logits.
11431168
11441169 # Last token logits.
11451170 context = self .inference_wrapped_model .inference_context
1171+ active_request_count = context .total_request_count - context .paused_request_count
1172+
11461173 if context .config .materialize_only_last_token_logits :
11471174 # When materialize_only_last_token_logits is true, last_token_logits is
11481175 # already called in the forward pass of GPT.
1149- required_token_logits = logits . squeeze (0 )
1176+ required_token_logits = self . _all_logits_cuda . squeeze (0 )[: active_request_count , :]
11501177 else :
11511178 # todo : Should do verification here and get approrpiate las token logits
1152- required_token_logits = context .last_token_logits (logits )
1179+ required_token_logits = context .last_token_logits (self . _all_logits_cuda )
11531180
11541181 if self ._sampling_backend == "torch" :
11551182 # Concatenate the outputs once to prevent repeated small writes.
@@ -1247,19 +1274,24 @@ def _router_record_bookkeeping(self) -> Optional[Dict[int, Tensor]]:
12471274
12481275 return routing_indices_per_request
12491276
1250- def _dynamic_step_calculate_log_probs (self , logits : Tensor ) -> Optional [Tensor ]:
1277+ def _dynamic_step_calculate_log_probs (self ) -> Optional [Tensor ]:
12511278 """Calculate log probs from logits."""
12521279 context = self .inference_wrapped_model .inference_context
12531280 active_request_count = context .total_request_count - context .paused_request_count
1281+ logits_seq_len = (
1282+ active_request_count
1283+ if context .materialize_only_last_token_logits
1284+ else context .padded_active_token_count
1285+ )
12541286
12551287 return context .calculate_log_probs (
1256- logits ,
1288+ self . _all_logits_cuda [:, : logits_seq_len , :] ,
12571289 self ._sampled_tokens_cuda [:active_request_count ],
12581290 only_last_token_logits = context .config .materialize_only_last_token_logits ,
12591291 )
12601292
12611293 def _dynamic_step_calculate_log_probs_speculative (
1262- self , logits : Tensor
1294+ self ,
12631295 ) -> Tuple [List [List [float ]], Tensor ]:
12641296 """Calculate log probs from logits for speculative decoding.
12651297
@@ -1271,9 +1303,6 @@ def _dynamic_step_calculate_log_probs_speculative(
12711303 - log_prob(accepted_token[j]) comes from logits at position j
12721304 - log_prob(newly_sampled_token) comes from logits at position accepted_count
12731305
1274- Args:
1275- logits (Tensor): The main model logits [1, seq_len, vocab_size].
1276-
12771306 Returns:
12781307 Tuple of (log_probs_list, log_probs_tensor):
12791308 log_probs_list: List of lists, one per active request, containing
@@ -1291,7 +1320,7 @@ def _dynamic_step_calculate_log_probs_speculative(
12911320 num_prefill_requests = request_in_prefill_status_tensor .sum ().item ()
12921321 num_decode_requests = active_request_count - num_prefill_requests
12931322
1294- logits_squeezed = logits .squeeze (0 ).float ()
1323+ logits_squeezed = self . _all_logits_cuda .squeeze (0 ).float ()
12951324 log_probs_tensor = F .log_softmax (logits_squeezed [: context .active_token_count ], dim = - 1 )
12961325
12971326 log_probs_list_decode = []
@@ -1449,12 +1478,11 @@ def _dynamic_step_calculate_top_n_logprobs_speculative(
14491478 return top_n_results if top_n_results else None
14501479
14511480 def _dynamic_step_calculate_top_n_logprobs (
1452- self , logits : Tensor , log_probs_tensor : Optional [Tensor ] = None
1481+ self , log_probs_tensor : Optional [Tensor ] = None
14531482 ) -> Optional [Dict [int , List [Tuple [Tensor , Tensor ]]]]:
14541483 """Calculate top-n log probs from logits for dynamic batching.
14551484
14561485 Args:
1457- logits (Tensor): The logits to compute top-n log probs from.
14581486 log_probs_tensor (Optional[Tensor]): Pre-computed log probabilities tensor.
14591487 If provided, avoids recomputing log_softmax. Should be the tensor
14601488 returned by calculate_log_probs.
@@ -1743,7 +1771,7 @@ async def async_generate_output_tokens_dynamic_batch(
17431771
17441772 # Forward pass produces only base logits. When speculative decoding is
17451773 # active, MTP logits are computed serially after verification.
1746- logits = self ._dynamic_step_forward_logits (input_ids , position_ids )
1774+ self ._dynamic_step_forward_logits (input_ids , position_ids )
17471775
17481776 # Commit Mamba intermediate states before update_requests, which
17491777 # may swap request indices. The Python lists tracking EOS block IDs
@@ -1769,7 +1797,7 @@ async def async_generate_output_tokens_dynamic_batch(
17691797
17701798 if self .num_speculative_tokens > 0 :
17711799 # Phase 1: Verify speculative tokens using base logits only.
1772- self ._dynamic_step_sample_logits_and_verify_tokens (logits , input_ids )
1800+ self ._dynamic_step_sample_logits_and_verify_tokens (input_ids )
17731801 # Phase 2: Rewind KV cache for rejected tokens.
17741802 self ._rewind_kv_cache ()
17751803
@@ -1781,25 +1809,21 @@ async def async_generate_output_tokens_dynamic_batch(
17811809 # Phase 3: Compute MTP serially with correct (verified) inputs.
17821810 self ._compute_serial_mtp_and_sample ()
17831811 else :
1784- self ._dynamic_step_sample_logits (logits )
1812+ self ._dynamic_step_sample_logits ()
17851813
17861814 log_probs = None
17871815 top_n_logprobs = None
17881816 if return_log_probs or return_top_n_logprobs :
17891817 if self .num_speculative_tokens > 0 :
1790- log_probs , log_probs_tensor = self ._dynamic_step_calculate_log_probs_speculative (
1791- logits
1792- )
1818+ log_probs , log_probs_tensor = self ._dynamic_step_calculate_log_probs_speculative ()
17931819 if return_top_n_logprobs :
17941820 top_n_logprobs = self ._dynamic_step_calculate_top_n_logprobs_speculative (
17951821 log_probs_tensor
17961822 )
17971823 else :
1798- log_probs , log_probs_tensor = self ._dynamic_step_calculate_log_probs (logits )
1824+ log_probs , log_probs_tensor = self ._dynamic_step_calculate_log_probs ()
17991825 if return_top_n_logprobs :
1800- top_n_logprobs = self ._dynamic_step_calculate_top_n_logprobs (
1801- logits , log_probs_tensor
1802- )
1826+ top_n_logprobs = self ._dynamic_step_calculate_top_n_logprobs (log_probs_tensor )
18031827
18041828 if skip_bookkeeping :
18051829 request_bookkeeping = {}