requirements.txt (4 changes: 2 additions & 2 deletions)
@@ -52,7 +52,7 @@ ordered-set
 peft
 patchelf
 einops
-flashinfer-python>=0.3.0,<0.4.0
+flashinfer-python==0.6.0rc2
 opencv-python-headless
 xgrammar==0.1.25
 llguidance==0.7.29
@@ -73,7 +73,7 @@ nvidia-cutlass-dsl==4.3.4; python_version >= "3.10"
 plotly
 numexpr<2.14.0 # WAR for attempted use of nonexistent numpy.typing
 partial_json_parser
-apache-tvm-ffi==0.1.4 # used for reduce nvidia-cutlass-dsl host overhead
+apache-tvm-ffi==0.1.7 # used for reduce nvidia-cutlass-dsl host overhead
 torch-c-dlpack-ext==0.1.3 # used for reduce nvidia-cutlass-dsl host overhead, optional package for improved torch tensor calling perf
 mistral-common==1.8.6
 torchao>=0.14.1
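
Since the code changes below only opt into flashinfer behavior for this newer release, it can be useful to confirm at runtime that the installed wheel actually satisfies the tightened pin. The following is a minimal sketch, not part of the PR, assuming the packaging library is available; the helper name flashinfer_matches_pin is invented for illustration.

from importlib.metadata import PackageNotFoundError, version

from packaging.version import Version

# The exact pin introduced above; a release candidate such as 0.6.0rc2 parses
# fine under packaging's PEP 440 rules.
PINNED_FLASHINFER = Version("0.6.0rc2")


def flashinfer_matches_pin() -> bool:
    """Return True if the installed flashinfer-python wheel is at least the pin."""
    try:
        installed = Version(version("flashinfer-python"))
    except PackageNotFoundError:
        return False
    return installed >= PINNED_FLASHINFER


if __name__ == "__main__":
    print("flashinfer pin satisfied:", flashinfer_matches_pin())
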
tensorrt_llm/_torch/speculative/interface.py (18 changes: 17 additions & 1 deletion)
@@ -12,11 +12,15 @@

 from ..._utils import get_sm_version
 from ..attention_backend.trtllm import AttentionBackend, TrtllmAttention
+from ..flashinfer_utils import IS_FLASHINFER_AVAILABLE
 from ..pyexecutor.resource_manager import BaseResourceManager

 if TYPE_CHECKING:
     from ..pyexecutor.guided_decoder import CapturableGuidedDecoder

+if IS_FLASHINFER_AVAILABLE:
+    import flashinfer
+
 # Environment variable name for forcing the number of accepted tokens in speculative decoding
 FORCE_NUM_ACCEPTED_TOKENS_ENV_VAR = "TLLM_SPEC_DECODE_FORCE_NUM_ACCEPTED_TOKENS"
@@ -368,6 +372,9 @@ def __init__(self):
         super().__init__()
         self.guided_decoder: Optional["CapturableGuidedDecoder"] = None
         self.force_num_accepted_tokens = get_force_num_accepted_tokens()
+        self.use_flashinfer = IS_FLASHINFER_AVAILABLE and flashinfer.__version__ >= "0.6.0"
+        self.seed = 0
+        self.offset = 0

     @property
     @abstractmethod
@@ -412,8 +419,17 @@ def _sample_tokens_for_batch(
             top_ks = spec_metadata.top_ks[:num_tokens]
             top_ps = spec_metadata.top_ps[:num_tokens]

+            if self.use_flashinfer:
+                self.seed += 1
+
             sampled_tokens = sampling_batch_spec_dec_one_model(
-                logits, temperatures, top_ks, top_ps)
+                logits,
+                temperatures,
+                top_ks,
+                top_ps,
+                use_flashinfer=self.use_flashinfer,
+                seed=self.seed,
+                offset=self.offset)
         else:
             sampled_tokens = torch.argmax(logits, dim=-1)
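
With this change the drafter carries flashinfer RNG state: seed and offset start at zero, and seed is advanced once per sampling call when the flashinfer path is active, so each call hands the kernel a fresh seed while offset stays fixed. Below is a minimal standalone sketch of that bookkeeping pattern only; StubDrafter and draft_step are invented names for illustration, not TRT-LLM APIs.

import torch


class StubDrafter:
    """Toy stand-in for the drafter-side RNG bookkeeping added above."""

    def __init__(self, use_flashinfer: bool):
        self.use_flashinfer = use_flashinfer
        self.seed = 0    # bumped once per sampling call on the flashinfer path
        self.offset = 0  # left at zero, mirroring the diff

    def draft_step(self, logits: torch.Tensor) -> torch.Tensor:
        if self.use_flashinfer:
            # Each call gets a distinct seed so successive sampling calls do not
            # reuse the same random stream; the real code then forwards
            # (seed, offset) to sampling_batch_spec_dec_one_model.
            self.seed += 1
        # Greedy fallback so this sketch runs without flashinfer installed.
        return torch.argmax(logits, dim=-1)


drafter = StubDrafter(use_flashinfer=False)
tokens = drafter.draft_step(torch.randn(4, 32000))
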
tensorrt_llm/_torch/speculative/one_model_sampler.py (6 changes: 6 additions & 0 deletions)
@@ -1,6 +1,7 @@
 from typing import Optional

 import torch
+from flashinfer.sampling import top_k_top_p_sampling_from_logits


 def forward_native(
@@ -78,6 +79,9 @@ def sampling_batch_spec_dec_one_model(
     temperatures: torch.Tensor,
     top_k: torch.Tensor,
     top_p: torch.Tensor,
+    use_flashinfer: bool = False,
+    seed: Optional[int] = None,
+    offset: Optional[int] = None,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """
     CUDA-graph compatible sampling. Supports mixed sampling params.
@@ -87,5 +91,7 @@
     sampling is opt-in for now.
     """
     logits = apply_temperature(logits, temperatures)
+    if use_flashinfer:
+        return top_k_top_p_sampling_from_logits(logits, top_k, top_p, seed=seed, offset=offset)
     random_sampled = forward_native(logits, top_k, top_p)
     return random_sampled
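
For orientation, a hedged usage sketch of the extended entry point follows. The tensor shapes, dtypes, and parameter values are assumptions for illustration; running it requires one_model_sampler to import cleanly, which with this diff means the pinned flashinfer build must be installed because its module-level import is unconditional, and flashinfer sampling itself needs a CUDA device.

import torch

from tensorrt_llm._torch.speculative.one_model_sampler import \
    sampling_batch_spec_dec_one_model

# Assumed layout: one row of logits per draft token, one sampling parameter per row.
num_tokens, vocab_size = 4, 32000
logits = torch.randn(num_tokens, vocab_size, device="cuda")
temperatures = torch.full((num_tokens,), 0.8, device="cuda")
top_k = torch.full((num_tokens,), 20, dtype=torch.int32, device="cuda")
top_p = torch.full((num_tokens,), 0.95, device="cuda")

# use_flashinfer=False keeps the pre-existing native path; use_flashinfer=True
# routes through flashinfer's top_k_top_p_sampling_from_logits with the given
# seed/offset, as wired up in this diff.
sampled = sampling_batch_spec_dec_one_model(
    logits, temperatures, top_k, top_p,
    use_flashinfer=False, seed=1, offset=0)
print(sampled)
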