diff --git a/requirements.txt b/requirements.txt
index a21b8ca2819..f977b413288 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -52,7 +52,7 @@ ordered-set
 peft
 patchelf
 einops
-flashinfer-python>=0.3.0,<0.4.0
+flashinfer-python==0.6.0rc2
 opencv-python-headless
 xgrammar==0.1.25
 llguidance==0.7.29
@@ -73,7 +73,7 @@ nvidia-cutlass-dsl==4.3.4; python_version >= "3.10"
 plotly
 numexpr<2.14.0 # WAR for attempted use of nonexistent numpy.typing
 partial_json_parser
-apache-tvm-ffi==0.1.4 # used for reduce nvidia-cutlass-dsl host overhead
+apache-tvm-ffi==0.1.7 # used for reduce nvidia-cutlass-dsl host overhead
 torch-c-dlpack-ext==0.1.3 # used for reduce nvidia-cutlass-dsl host overhead, optional package for improved torch tensor calling perf
 mistral-common==1.8.6
 torchao>=0.14.1
diff --git a/tensorrt_llm/_torch/speculative/interface.py b/tensorrt_llm/_torch/speculative/interface.py
index 59a5e0129cf..eaa503ff7e8 100644
--- a/tensorrt_llm/_torch/speculative/interface.py
+++ b/tensorrt_llm/_torch/speculative/interface.py
@@ -12,11 +12,15 @@
 from ..._utils import get_sm_version
 from ..attention_backend.trtllm import AttentionBackend, TrtllmAttention
+from ..flashinfer_utils import IS_FLASHINFER_AVAILABLE
 from ..pyexecutor.resource_manager import BaseResourceManager
 
 if TYPE_CHECKING:
     from ..pyexecutor.guided_decoder import CapturableGuidedDecoder
 
+if IS_FLASHINFER_AVAILABLE:
+    import flashinfer
+
 # Environment variable name for forcing the number of accepted tokens in speculative decoding
 FORCE_NUM_ACCEPTED_TOKENS_ENV_VAR = "TLLM_SPEC_DECODE_FORCE_NUM_ACCEPTED_TOKENS"
@@ -368,6 +372,9 @@ def __init__(self):
         super().__init__()
         self.guided_decoder: Optional["CapturableGuidedDecoder"] = None
         self.force_num_accepted_tokens = get_force_num_accepted_tokens()
+        self.use_flashinfer = IS_FLASHINFER_AVAILABLE and flashinfer.__version__ >= "0.6.0"
+        self.seed = 0
+        self.offset = 0
 
     @property
     @abstractmethod
@@ -412,8 +419,17 @@ def _sample_tokens_for_batch(
             top_ks = spec_metadata.top_ks[:num_tokens]
             top_ps = spec_metadata.top_ps[:num_tokens]
 
+            if self.use_flashinfer:
+                self.seed += 1
+
             sampled_tokens = sampling_batch_spec_dec_one_model(
-                logits, temperatures, top_ks, top_ps)
+                logits,
+                temperatures,
+                top_ks,
+                top_ps,
+                use_flashinfer=self.use_flashinfer,
+                seed=self.seed,
+                offset=self.offset)
         else:
             sampled_tokens = torch.argmax(logits, dim=-1)
diff --git a/tensorrt_llm/_torch/speculative/one_model_sampler.py b/tensorrt_llm/_torch/speculative/one_model_sampler.py
index ca48c03f28e..7d49aa85dd1 100644
--- a/tensorrt_llm/_torch/speculative/one_model_sampler.py
+++ b/tensorrt_llm/_torch/speculative/one_model_sampler.py
@@ -1,6 +1,7 @@
 from typing import Optional
 
 import torch
+from flashinfer.sampling import top_k_top_p_sampling_from_logits
 
 
 def forward_native(
@@ -78,6 +79,9 @@ def sampling_batch_spec_dec_one_model(
     temperatures: torch.Tensor,
     top_k: torch.Tensor,
     top_p: torch.Tensor,
+    use_flashinfer: bool = False,
+    seed: Optional[int] = None,
+    offset: Optional[int] = None,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """
     CUDA-graph compatible sampling. Supports mixed sampling params.
@@ -87,5 +91,7 @@
     sampling is opt-in for now.
     """
     logits = apply_temperature(logits, temperatures)
+    if use_flashinfer:
+        return top_k_top_p_sampling_from_logits(logits, top_k, top_p, seed=seed, offset=offset)
     random_sampled = forward_native(logits, top_k, top_p)
     return random_sampled
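
As context, the snippet below is a minimal usage sketch, not part of the patch, showing how the new flashinfer-backed path in `sampling_batch_spec_dec_one_model` can be exercised directly. It assumes flashinfer>=0.6.0 and a CUDA device are available and that the module path matches this branch; tensor shapes, dtypes, and sampling parameters are illustrative only.

```python
# Minimal sketch (not part of the diff): exercising the new flashinfer path.
# Assumes flashinfer>=0.6.0 is installed and a CUDA device is available;
# shapes, dtypes, and sampling parameters are illustrative only.
import torch

from tensorrt_llm._torch.speculative.one_model_sampler import \
    sampling_batch_spec_dec_one_model

num_tokens, vocab_size = 8, 32000
logits = torch.randn(num_tokens, vocab_size, device="cuda")
temperatures = torch.full((num_tokens,), 0.8, device="cuda")
top_ks = torch.full((num_tokens,), 20, dtype=torch.int32, device="cuda")
top_ps = torch.full((num_tokens,), 0.95, device="cuda")

seed, offset = 0, 0
for _ in range(3):
    # interface.py increments self.seed before every flashinfer-backed call,
    # so successive sampling calls draw from different random streams.
    seed += 1
    sampled_tokens = sampling_batch_spec_dec_one_model(
        logits,
        temperatures,
        top_ks,
        top_ps,
        use_flashinfer=True,
        seed=seed,
        offset=offset)
    print(sampled_tokens.shape)  # one sampled token id per row of logits
```

Leaving `use_flashinfer` at its default of `False` keeps the existing `forward_native` path, so the flashinfer kernel stays opt-in as noted in the docstring.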