
Commit f2aee0d

Authored by ziyixiong-nv
[TRTLLM-9854][feat] Optimize the host overhead of _sample_async (#9935)
Signed-off-by: ziyixiong-nv <219238287+ziyixiong-nv@users.noreply.github.com>
1 parent 25db9e7 commit f2aee0d

File tree: 2 files changed, +84 −1 lines changed

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 76 additions & 0 deletions
@@ -968,6 +968,23 @@ def get_spec_tree_manager(
     def _use_beam_search(self) -> bool:
         return self.max_beam_width > 1
 
+    def _can_use_fast_greedy_path(self, requests: list[LlmRequest]) -> bool:
+        """
+        Check if we can use the fast argmax path for greedy sampling.
+        """
+
+        # Check if all requests use greedy sampling and don't require features
+        # that the fast path skips.
+        for req in requests:
+            # vocab_size doesn't affect the greediness check
+            if _request_strategy(req, vocab_size=2**31) != GREEDY:
+                return False
+
+            # Fast path skips logprobs handling
+            if req.py_return_log_probs:
+                return False
+        return True
+
     @staticmethod
     def _meet_max_token_stop_criteria(
         request: LlmRequest, max_seq_len: int, beam_idx: int = DEFAULT_BEAM_IDX
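
For orientation, here is a minimal sketch of the gating idea above. The names FakeRequest and is_greedy are hypothetical stand-ins for the real LlmRequest and _request_strategy: the fast path is only admitted when every request in the batch is greedy and none of them asks for logprobs.

from dataclasses import dataclass

@dataclass
class FakeRequest:          # hypothetical stand-in for LlmRequest
    temperature: float
    top_k: int
    py_return_log_probs: bool

def is_greedy(req: FakeRequest) -> bool:
    # Toy greediness rule; the real check delegates to _request_strategy().
    return req.top_k == 1 or req.temperature == 0.0

def can_use_fast_greedy_path(requests: list[FakeRequest]) -> bool:
    # Every request must be greedy and must not need logprobs handling.
    return all(is_greedy(r) and not r.py_return_log_probs for r in requests)

batch = [FakeRequest(0.0, 1, False), FakeRequest(0.0, 1, False)]
print(can_use_fast_greedy_path(batch))   # True: plain argmax suffices for the whole batch
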
@@ -1882,6 +1899,34 @@ def _apply_d2t(tokens: torch.Tensor, model_outputs) -> None:
         d2t = model_outputs["d2t"][tokens]
         tokens += d2t
 
+    @staticmethod
+    @nvtx_range("fast_greedy_sample_kernel")
+    def _fast_greedy_sample_kernel(
+        logits_cuda: torch.Tensor,
+        new_tokens_cuda: torch.Tensor,
+        batch_dest_indices: torch.Tensor,
+        max_beam_width: int,
+        d2t: torch.Tensor | None,
+    ) -> None:
+        """Applies fast greedy sampling to the logits.
+
+        Performs argmax, applies d2t translation if present, and scatters
+        tokens into the output buffer. All operations are in-place.
+        """
+        # Simple argmax for greedy sampling
+        next_tokens = torch.argmax(logits_cuda, dim=-1).to(dtype=new_tokens_cuda.dtype)
+
+        # Apply draft-to-target token translation if present (for Eagle3)
+        if d2t is not None:
+            next_tokens += d2t[next_tokens]
+
+        # Scatter tokens into output buffer
+        batch_dest_indices_expanded = batch_dest_indices.unsqueeze(1).expand(-1, max_beam_width)
+        next_tokens_expanded = next_tokens.unsqueeze(1).expand(-1, max_beam_width)
+        new_tokens_cuda.view(-1, *new_tokens_cuda.shape[2:]).scatter_(
+            0, batch_dest_indices_expanded, next_tokens_expanded
+        )
+
     @staticmethod
     def _apply_embedding_bias(
         logits: torch.Tensor,
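
To make the scatter layout above easier to follow, here is a small self-contained sketch (not part of the commit) that reproduces the argmax-and-scatter pattern on CPU with made-up shapes: the output buffer is laid out as (steps, slots, beams), and each sampled token is written at a flat step-major destination index.

import torch

max_beam_width = 1
steps, slots = 2, 3          # new_tokens buffer is (steps, slots, beams)
vocab_size = 5

new_tokens = torch.zeros(steps, slots, max_beam_width, dtype=torch.int32)
logits = torch.randn(4, vocab_size)                   # one logits row per sampled token
dest = torch.tensor([0, 1, 3, 4], dtype=torch.int64)  # flat index: step * slots + slot

# Greedy sampling is a plain argmax; no softmax, top-k, or top-p is needed.
next_tokens = torch.argmax(logits, dim=-1).to(dtype=new_tokens.dtype)

# Broadcast over the beam dimension and scatter into the flattened (steps * slots, beams) view.
dest_expanded = dest.unsqueeze(1).expand(-1, max_beam_width)
src_expanded = next_tokens.unsqueeze(1).expand(-1, max_beam_width)
new_tokens.view(-1, *new_tokens.shape[2:]).scatter_(0, dest_expanded, src_expanded)
print(new_tokens.squeeze(-1))  # tokens placed at (step, slot) = (0,0), (0,1), (1,0), (1,1)

The single in-place scatter_ replaces per-request unbatching work on the host for this case, which appears to be where the host-overhead saving targeted by this commit comes from.
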
@@ -2372,6 +2417,7 @@ def _request_indices_with_stop_words(self, requests: list[LlmRequest]) -> torch.
             if (r.py_stop_words_list is not None and len(r.py_stop_words_list[0]) > 0)
         ]
 
+    @nvtx_range("_write_finish_reasons")
     def _write_finish_reasons(
         self,
         requests: list[LlmRequest],
@@ -2637,6 +2683,36 @@ def _process_requests(
             sampling_requests_metadata.req_num_beams,
         )
 
+        # Fast path for greedy sampling
+        if self._can_use_fast_greedy_path(requests):
+            # Compute destination indices on CPU (same pattern as _unbatch_sampling_results)
+            batch_destination_indexer = _UnpackedStepIndexer(
+                seq_slots=seq_slots,
+                num_steps=sampling_requests_metadata.req_num_generated_tokens,
+                steps_dim_size=new_tokens_cuda.size(0),
+                slots_dim_size=new_tokens_cuda.size(1),
+                dim_order=_UnpackedStepIndexer.DimOrder.STEP_MAJOR,
+                index_dtype=torch.int64,
+            )
+            batch_dest_indices_cuda = batch_destination_indexer[:].to(
+                new_tokens_cuda.device, non_blocking=True
+            )
+
+            # Get d2t tensor if present
+            d2t = model_outputs.get("d2t", None)
+
+            # Run compiled kernel for argmax, d2t application, and scatter
+            self._fast_greedy_sample_kernel(
+                logits_cuda,
+                new_tokens_cuda,
+                batch_dest_indices_cuda,
+                self.max_beam_width,
+                d2t,
+            )
+
+            new_tokens_host = self._copy_to_host(new_tokens_cuda)
+            return new_tokens_host
+
         # Indexer for accessing tokens in 'logits_cuda', corresponding to the
         # requests in 'requests'.
         steps_dim_size = new_tokens_cuda.size(0)
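
The fast path also keeps the host side cheap by staging the destination indices on the CPU and moving them with a non-blocking copy. A hedged, standalone sketch of that call pattern (names invented, not taken from the commit):

import torch

if torch.cuda.is_available():
    # Pinned (page-locked) host memory lets the host-to-device copy run asynchronously.
    dest_cpu = torch.arange(8, dtype=torch.int64).pin_memory()
    dest_gpu = dest_cpu.to("cuda", non_blocking=True)   # enqueued on the stream; returns immediately
    # ... host code can keep preparing the next step while the copy is in flight ...
    torch.cuda.synchronize()                            # only needed to make this toy example deterministic
    print(dest_gpu)
else:
    print("CUDA not available; the sketch only demonstrates the call pattern.")
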

tests/unittest/_torch/sampler/test_torch_sampler.py

Lines changed: 8 additions & 1 deletion
@@ -1565,7 +1565,14 @@ def _sample_async(
             num_context_logits_prefix_sum,
             resource_manager,
         )
-        assert flashinfer_keys_seen
+
+        # Fast greedy path bypasses flashinfer sampling, so flashinfer_keys_seen
+        # will be empty when all requests are greedy
+        all_greedy = all(
+            _request_strategy(req, vocab_size=2**31) == GREEDY
+            for req in scheduled_requests.all_requests()
+        )
+        assert flashinfer_keys_seen or all_greedy
         return res
 
     patch_ctx.setattr(sampler, "sample_async", _sample_async)
