
Commit e3d70bd

Speculative One Model: FlashInfer sampling
Signed-off-by: Izzy Putterman <iputterman@nvidia.com>
1 parent 46f035b · commit e3d70bd

3 files changed: +25 −3 lines

requirements.txt

Lines changed: 2 additions & 2 deletions
@@ -52,7 +52,7 @@ ordered-set
 peft
 patchelf
 einops
-flashinfer-python>=0.3.0,<0.4.0
+flashinfer-python==0.6.0rc2
 opencv-python-headless
 xgrammar==0.1.25
 llguidance==0.7.29
@@ -73,7 +73,7 @@ nvidia-cutlass-dsl==4.3.4; python_version >= "3.10"
 plotly
 numexpr<2.14.0 # WAR for attempted use of nonexistent numpy.typing
 partial_json_parser
-apache-tvm-ffi==0.1.4 # used for reduce nvidia-cutlass-dsl host overhead
+apache-tvm-ffi==0.1.7 # used for reduce nvidia-cutlass-dsl host overhead
 torch-c-dlpack-ext==0.1.3 # used for reduce nvidia-cutlass-dsl host overhead, optional package for improved torch tensor calling perf
 mistral-common==1.8.6
 torchao>=0.14.1
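
The FlashInfer pin jumps from the 0.3.x range to a 0.6.0 release candidate because the new sampling path below gates on flashinfer.__version__ >= "0.6.0". A minimal sketch of that availability check, not part of the commit itself — IS_FLASHINFER_AVAILABLE and the string comparison come from the interface.py diff below; the try/except form here is only illustrative:

try:
    import flashinfer
    # Mirrors the gate in interface.py: a plain string comparison, so the
    # pinned pre-release "0.6.0rc2" still satisfies >= "0.6.0".
    use_flashinfer = flashinfer.__version__ >= "0.6.0"
except ImportError:
    use_flashinfer = False
print("FlashInfer spec-dec sampling enabled:", use_flashinfer)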

tensorrt_llm/_torch/speculative/interface.py

Lines changed: 17 additions & 1 deletion
@@ -12,11 +12,15 @@
 
 from ..._utils import get_sm_version
 from ..attention_backend.trtllm import AttentionBackend, TrtllmAttention
+from ..flashinfer_utils import IS_FLASHINFER_AVAILABLE
 from ..pyexecutor.resource_manager import BaseResourceManager
 
 if TYPE_CHECKING:
     from ..pyexecutor.guided_decoder import CapturableGuidedDecoder
 
+if IS_FLASHINFER_AVAILABLE:
+    import flashinfer
+
 # Environment variable name for forcing the number of accepted tokens in speculative decoding
 FORCE_NUM_ACCEPTED_TOKENS_ENV_VAR = "TLLM_SPEC_DECODE_FORCE_NUM_ACCEPTED_TOKENS"
 
@@ -368,6 +372,9 @@ def __init__(self):
         super().__init__()
         self.guided_decoder: Optional["CapturableGuidedDecoder"] = None
         self.force_num_accepted_tokens = get_force_num_accepted_tokens()
+        self.use_flashinfer = IS_FLASHINFER_AVAILABLE and flashinfer.__version__ >= "0.6.0"
+        self.seed = 0
+        self.offset = 0
 
     @property
     @abstractmethod
@@ -412,8 +419,17 @@ def _sample_tokens_for_batch(
             top_ks = spec_metadata.top_ks[:num_tokens]
             top_ps = spec_metadata.top_ps[:num_tokens]
 
+            if self.use_flashinfer:
+                self.seed += 1
+
             sampled_tokens = sampling_batch_spec_dec_one_model(
-                logits, temperatures, top_ks, top_ps)
+                logits,
+                temperatures,
+                top_ks,
+                top_ps,
+                use_flashinfer=self.use_flashinfer,
+                seed=self.seed,
+                offset=self.offset)
         else:
             sampled_tokens = torch.argmax(logits, dim=-1)
 
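The drafter enables the FlashInfer path only when flashinfer is importable and reports a version of at least 0.6.0, and it bumps a per-instance seed before every FlashInfer-backed sampling call while leaving offset at 0. A standalone sketch of that bookkeeping, assuming TensorRT-LLM at this commit is installed; StubDrafter is a hypothetical stand-in for the real speculative-decoding class and is not part of the diff:

from tensorrt_llm._torch.speculative.one_model_sampler import (
    sampling_batch_spec_dec_one_model)


class StubDrafter:
    """Hypothetical stand-in mirroring the seed/offset bookkeeping above."""

    def __init__(self, use_flashinfer: bool):
        self.use_flashinfer = use_flashinfer
        self.seed = 0    # incremented before each FlashInfer-backed call
        self.offset = 0  # this commit never advances the offset

    def sample(self, logits, temperatures, top_ks, top_ps):
        if self.use_flashinfer:
            self.seed += 1  # fresh seed per call, as interface.py does above
        return sampling_batch_spec_dec_one_model(
            logits, temperatures, top_ks, top_ps,
            use_flashinfer=self.use_flashinfer,
            seed=self.seed, offset=self.offset)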

tensorrt_llm/_torch/speculative/one_model_sampler.py

Lines changed: 6 additions & 0 deletions
@@ -1,6 +1,7 @@
 from typing import Optional
 
 import torch
+from flashinfer.sampling import top_k_top_p_sampling_from_logits
 
 
 def forward_native(
@@ -78,6 +79,9 @@ def sampling_batch_spec_dec_one_model(
     temperatures: torch.Tensor,
     top_k: torch.Tensor,
     top_p: torch.Tensor,
+    use_flashinfer: bool = False,
+    seed: Optional[int] = None,
+    offset: Optional[int] = None,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """
     CUDA-graph compatible sampling. Supports mixed sampling params.
@@ -87,5 +91,7 @@ def sampling_batch_spec_dec_one_model(
     sampling is opt-in for now.
     """
     logits = apply_temperature(logits, temperatures)
+    if use_flashinfer:
+        return top_k_top_p_sampling_from_logits(logits, top_k, top_p, seed=seed, offset=offset)
     random_sampled = forward_native(logits, top_k, top_p)
     return random_sampled
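
The new keyword arguments keep FlashInfer opt-in: with use_flashinfer left at its False default the function still routes through the existing forward_native path, and seed/offset are only consulted on the FlashInfer branch. A hedged usage sketch, assuming TensorRT-LLM at this commit plus a CUDA device (and flashinfer >= 0.6.0 for the opt-in call); the shapes and sampling parameters below are made up for illustration:

import torch
from tensorrt_llm._torch.speculative.one_model_sampler import (
    sampling_batch_spec_dec_one_model)

num_tokens, vocab_size = 4, 32000  # illustrative sizes only
logits = torch.randn(num_tokens, vocab_size, device="cuda")
temperatures = torch.full((num_tokens,), 0.8, device="cuda")
top_ks = torch.full((num_tokens,), 20, dtype=torch.int32, device="cuda")
top_ps = torch.full((num_tokens,), 0.95, device="cuda")

# Default path: native top-k/top-p sampling, unchanged by this commit.
native_tokens = sampling_batch_spec_dec_one_model(logits, temperatures, top_ks, top_ps)

# Opt-in path: FlashInfer kernel with an explicit seed/offset, as interface.py wires it up.
fi_tokens = sampling_batch_spec_dec_one_model(
    logits, temperatures, top_ks, top_ps,
    use_flashinfer=True, seed=1, offset=0)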
