Commit 9913b28

Speculative One Model: FlashInfer sampling
Signed-off-by: Izzy Putterman <[email protected]>
1 parent 74832a1

2 files changed: +18 -1 lines changed

tensorrt_llm/_torch/speculative/eagle3.py

Lines changed: 12 additions & 1 deletion

@@ -367,6 +367,9 @@ def __init__(self, spec_config: "EagleDecodingConfig", mapping: Mapping):
         self.mapping = mapping
         self.guided_decoder: Optional[CapturableGuidedDecoder] = None
         self.force_num_accepted_tokens = get_force_num_accepted_tokens()
+        self.use_flashinfer = False
+        self.seed = 0
+        self.offset = 0

         # Skip torch.compile for now since current Torch is not compatible with Triton 3.4
         # @torch.compile(options={"max-autotune": True})
@@ -529,9 +532,17 @@ def _sample_tokens_for_batch(
             temperatures = spec_metadata.temperatures[:num_tokens]
             top_ks = spec_metadata.top_ks[:num_tokens]
             top_ps = spec_metadata.top_ps[:num_tokens]
+            if self.use_flashinfer:
+                self.seed += 1

             sampled_tokens = sampling_batch_spec_dec_one_model(
-                logits, temperatures, top_ks, top_ps)
+                logits,
+                temperatures,
+                top_ks,
+                top_ps,
+                use_flashinfer=self.use_flashinfer,
+                seed=self.seed,
+                offset=self.offset)
         else:
             sampled_tokens = torch.argmax(logits, dim=-1)
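For orientation, here is a standalone sketch of the bookkeeping this hunk adds to the drafter. It is a minimal sketch: the _SamplerState class and advance method are invented for illustration, and only the use_flashinfer/seed/offset attributes and the per-batch seed bump appear in the commit itself.

    # Hypothetical standalone model of the drafter state added above;
    # not the real class from eagle3.py.
    class _SamplerState:
        def __init__(self):
            self.use_flashinfer = False  # FlashInfer sampling stays opt-in
            self.seed = 0                # advanced once per sampling call
            self.offset = 0              # left at its initial value here

        def advance(self) -> int:
            # Mirrors the guard in _sample_tokens_for_batch: each batch
            # gets a fresh seed, so successive speculative steps do not
            # replay the same random stream.
            if self.use_flashinfer:
                self.seed += 1
            return self.seed

Because use_flashinfer defaults to False, existing callers keep the native sampling path unchanged.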

tensorrt_llm/_torch/speculative/one_model_sampler.py

Lines changed: 6 additions & 0 deletions

@@ -1,6 +1,7 @@
 from typing import Optional

 import torch
+from flashinfer.sampling import top_k_top_p_sampling_from_logits


 def forward_native(
@@ -78,6 +79,9 @@ def sampling_batch_spec_dec_one_model(
     temperatures: torch.Tensor,
     top_k: torch.Tensor,
     top_p: torch.Tensor,
+    use_flashinfer: bool = False,
+    seed: Optional[int] = None,
+    offset: Optional[int] = None,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """
     CUDA-graph compatible sampling. Supports mixed sampling params.
@@ -87,5 +91,7 @@
     sampling is opt-in for now.
     """
     logits = apply_temperature(logits, temperatures)
+    if use_flashinfer:
+        return top_k_top_p_sampling_from_logits(logits, top_k, top_p, seed=seed, offset=offset)
     random_sampled = forward_native(logits, top_k, top_p)
     return random_sampled
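A minimal call sketch, assuming a batch of 4 tokens over a 32,000-entry vocabulary. The tensor shapes, dtypes, and values are assumptions for illustration; the module path, function name, and keyword arguments come from the diff.

    import torch

    from tensorrt_llm._torch.speculative.one_model_sampler import (
        sampling_batch_spec_dec_one_model)

    # Assumed shapes: one logits row per token to sample, plus per-token
    # sampling parameters, matching the slicing in _sample_tokens_for_batch.
    logits = torch.randn(4, 32000, device="cuda")
    temperatures = torch.full((4,), 0.8, device="cuda")
    top_ks = torch.full((4,), 50, dtype=torch.int32, device="cuda")
    top_ps = torch.full((4,), 0.9, device="cuda")

    # Default path: the existing forward_native top-k/top-p sampler.
    tokens = sampling_batch_spec_dec_one_model(logits, temperatures, top_ks,
                                               top_ps)

    # Opt-in path: routes through flashinfer.sampling's
    # top_k_top_p_sampling_from_logits with caller-controlled seed/offset,
    # which eagle3.py threads in from its per-batch counter.
    tokens = sampling_batch_spec_dec_one_model(
        logits, temperatures, top_ks, top_ps,
        use_flashinfer=True, seed=1, offset=0)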
