@@ -968,6 +968,23 @@ def get_spec_tree_manager(
     def _use_beam_search(self) -> bool:
         return self.max_beam_width > 1
 
+    def _can_use_fast_greedy_path(self, requests: list[LlmRequest]) -> bool:
+        """
+        Check if we can use the fast argmax path for greedy sampling.
+        """
+
+        # Check that all requests use greedy sampling and don't require features
+        # that the fast path skips
+        for req in requests:
+            # vocab_size doesn't affect the greediness check
+            if _request_strategy(req, vocab_size=2**31) != GREEDY:
+                return False
+
+            # The fast path skips logprobs handling
+            if req.py_return_log_probs:
+                return False
+        return True
+
     @staticmethod
     def _meet_max_token_stop_criteria(
         request: LlmRequest, max_seq_len: int, beam_idx: int = DEFAULT_BEAM_IDX
@@ -1882,6 +1899,34 @@ def _apply_d2t(tokens: torch.Tensor, model_outputs) -> None:
         d2t = model_outputs["d2t"][tokens]
         tokens += d2t
 
+    @staticmethod
+    @nvtx_range("fast_greedy_sample_kernel")
+    def _fast_greedy_sample_kernel(
+        logits_cuda: torch.Tensor,
+        new_tokens_cuda: torch.Tensor,
+        batch_dest_indices: torch.Tensor,
+        max_beam_width: int,
+        d2t: torch.Tensor | None,
+    ) -> None:
+        """Applies fast greedy sampling to the logits.
+
+        Performs argmax, applies d2t translation if present, and scatters
+        tokens into the output buffer. All operations are in-place.
+        """
+        # Simple argmax for greedy sampling
+        next_tokens = torch.argmax(logits_cuda, dim=-1).to(dtype=new_tokens_cuda.dtype)
+
+        # Apply draft-to-target token translation if present (for Eagle3)
+        if d2t is not None:
+            next_tokens += d2t[next_tokens]
+
+        # Scatter tokens into the output buffer
+        batch_dest_indices_expanded = batch_dest_indices.unsqueeze(1).expand(-1, max_beam_width)
+        next_tokens_expanded = next_tokens.unsqueeze(1).expand(-1, max_beam_width)
+        new_tokens_cuda.view(-1, *new_tokens_cuda.shape[2:]).scatter_(
+            0, batch_dest_indices_expanded, next_tokens_expanded
+        )
+
     @staticmethod
     def _apply_embedding_bias(
         logits: torch.Tensor,
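For reference, a minimal standalone sketch of the argmax / d2t / scatter pattern that _fast_greedy_sample_kernel applies, run here on small CPU tensors; the shapes, destination indices, and the all-zero d2t offset table are illustrative assumptions and not part of this change:

import torch

steps, slots, beams, vocab = 1, 4, 1, 8
logits = torch.randn(3, vocab)                       # logits for 3 active requests
new_tokens = torch.zeros(steps, slots, beams, dtype=torch.int64)
dest = torch.tensor([2, 0, 3])                       # flattened (step, slot) destinations
d2t = torch.zeros(vocab, dtype=torch.int64)          # zero offsets: no-op translation

tokens = torch.argmax(logits, dim=-1).to(new_tokens.dtype)
tokens += d2t[tokens]                                # draft-to-target offset lookup
new_tokens.view(-1, beams).scatter_(
    0, dest.unsqueeze(1).expand(-1, beams), tokens.unsqueeze(1).expand(-1, beams)
)
# new_tokens[0, 2, 0], new_tokens[0, 0, 0] and new_tokens[0, 3, 0] now hold the argmax tokens.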
@@ -2372,6 +2417,7 @@ def _request_indices_with_stop_words(self, requests: list[LlmRequest]) -> torch.
             if (r.py_stop_words_list is not None and len(r.py_stop_words_list[0]) > 0)
         ]
 
+    @nvtx_range("_write_finish_reasons")
     def _write_finish_reasons(
         self,
         requests: list[LlmRequest],
@@ -2637,6 +2683,36 @@ def _process_requests(
             sampling_requests_metadata.req_num_beams,
         )
 
+        # Fast path for greedy sampling
+        if self._can_use_fast_greedy_path(requests):
+            # Compute destination indices on CPU (same pattern as _unbatch_sampling_results)
+            batch_destination_indexer = _UnpackedStepIndexer(
+                seq_slots=seq_slots,
+                num_steps=sampling_requests_metadata.req_num_generated_tokens,
+                steps_dim_size=new_tokens_cuda.size(0),
+                slots_dim_size=new_tokens_cuda.size(1),
+                dim_order=_UnpackedStepIndexer.DimOrder.STEP_MAJOR,
+                index_dtype=torch.int64,
+            )
+            batch_dest_indices_cuda = batch_destination_indexer[:].to(
+                new_tokens_cuda.device, non_blocking=True
+            )
+
+            # Get d2t tensor if present
+            d2t = model_outputs.get("d2t", None)
+
+            # Run compiled kernel for argmax, d2t application, and scatter
+            self._fast_greedy_sample_kernel(
+                logits_cuda,
+                new_tokens_cuda,
+                batch_dest_indices_cuda,
+                self.max_beam_width,
+                d2t,
+            )
+
+            new_tokens_host = self._copy_to_host(new_tokens_cuda)
+            return new_tokens_host
+
         # Indexer for accessing tokens in 'logits_cuda', corresponding to the
         # requests in 'requests'.
         steps_dim_size = new_tokens_cuda.size(0)
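For intuition, a small sketch of what the destination indices produced by the indexer in STEP_MAJOR order are assumed to look like: the flattened offset of generated token `step` for a request in sequence slot `slot`, into the (steps, slots, beams) new_tokens buffer collapsed over its first two dimensions. The slot values and step counts below are made up for illustration, and the exact semantics of _UnpackedStepIndexer are an assumption inferred from how the indices feed the scatter above:

import torch

slots_dim_size = 8                     # new_tokens_cuda.size(1)
seq_slots = [5, 2]                     # sequence slots of two active requests
num_steps = [1, 2]                     # tokens generated per request this iteration

dest = torch.cat([
    torch.arange(n, dtype=torch.int64) * slots_dim_size + slot
    for slot, n in zip(seq_slots, num_steps)
])
# dest == tensor([5, 2, 10]); these rows of new_tokens_cuda.view(-1, beams) receive
# the argmax tokens in the scatter performed by _fast_greedy_sample_kernel.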