diff --git a/tensorrt_llm/_torch/speculative/__init__.py b/tensorrt_llm/_torch/speculative/__init__.py
index 31ed71f76f3..a99a88b514a 100644
--- a/tensorrt_llm/_torch/speculative/__init__.py
+++ b/tensorrt_llm/_torch/speculative/__init__.py
@@ -1,6 +1,6 @@
 from .auto_heuristic import suggest_spec_config
 from .eagle3 import Eagle3SpecMetadata
-from .interface import SpecMetadata
+from .interface import SpecMetadata, SpecWorkerBase
 from .mtp import MTPEagleWorker, MTPSpecMetadata, MTPWorker
 from .ngram import NGramDrafter, NGramPoolManager
 from .save_hidden_state import SaveHiddenStatesDrafter
@@ -19,6 +19,7 @@
     "NGramPoolManager",
     "SaveHiddenStatesDrafter",
     "SpecMetadata",
+    "SpecWorkerBase",
     "get_num_extra_kv_tokens",
     "get_num_spec_layers",
     "get_spec_decoder",
diff --git a/tensorrt_llm/_torch/speculative/eagle3.py b/tensorrt_llm/_torch/speculative/eagle3.py
index 58b5a82e98a..c3da1c70174 100644
--- a/tensorrt_llm/_torch/speculative/eagle3.py
+++ b/tensorrt_llm/_torch/speculative/eagle3.py
@@ -7,14 +7,12 @@
 from tensorrt_llm.mapping import Mapping
 
 from ..attention_backend import AttentionMetadata
-from ..pyexecutor.guided_decoder import CapturableGuidedDecoder
 from ..pyexecutor.llm_request import LlmRequest
 from ..pyexecutor.resource_manager import BaseResourceManager, SlotManager
 from ..pyexecutor.sampler import TorchSampler
 from ..pyexecutor.scheduler import ScheduledRequests
-from .interface import SpecMetadata, get_force_num_accepted_tokens
+from .interface import SpecMetadata, SpecWorkerBase
 from .mtp import MTPSampler
-from .one_model_sampler import sampling_batch_spec_dec_one_model
 from .spec_tree_manager import SpecTreeManager
 
 if TYPE_CHECKING:
@@ -358,15 +356,16 @@ def __init__(self, args: TorchSampler.Args):
         super().__init__(args, nextn=args.max_draft_len)
 
 
-class Eagle3OneModelWorker(nn.Module):
+class Eagle3OneModelWorker(SpecWorkerBase):
 
     def __init__(self, spec_config: "EagleDecodingConfig", mapping: Mapping):
         super().__init__()
         self.spec_config = spec_config
-        self.max_draft_len = self.spec_config.max_draft_len
         self.mapping = mapping
-        self.guided_decoder: Optional[CapturableGuidedDecoder] = None
-        self.force_num_accepted_tokens = get_force_num_accepted_tokens()
+
+    @property
+    def max_draft_len(self) -> int:
+        return self.spec_config.max_draft_len
 
     # Skip torch.compile for now since current Torch is not compatible with Triton 3.4
     # @torch.compile(options={"max-autotune": True})
@@ -503,40 +502,6 @@ def forward(self, input_ids, position_ids, hidden_states, logits,
             'next_new_tokens': next_new_tokens,
         }
 
-    def _sample_tokens_for_batch(
-        self,
-        logits: torch.Tensor,
-        spec_metadata: Eagle3OneModelSpecMetadata,
-        num_contexts: int,
-        batch_size: int,
-    ) -> torch.Tensor:
-        """
-        Sample tokens from logits using per-request sampling parameters.
-        Supports both greedy and non-greedy sampling.
-
-        Args:
-            logits: [num_tokens, vocab_size] - Logits to sample from
-            spec_metadata: Metadata containing sampling parameters
-            batch_size: Number of requests in the batch
-
-        Returns:
-            sampled_tokens: [num_tokens] - Sampled token ids
-        """
-        if spec_metadata.allow_advanced_sampling:
-            num_gens = batch_size - num_contexts
-            num_tokens = num_contexts + num_gens * (self.max_draft_len + 1)
-
-            temperatures = spec_metadata.temperatures[:num_tokens]
-            top_ks = spec_metadata.top_ks[:num_tokens]
-            top_ps = spec_metadata.top_ps[:num_tokens]
-
-            sampled_tokens = sampling_batch_spec_dec_one_model(
-                logits, temperatures, top_ks, top_ps)
-        else:
-            sampled_tokens = torch.argmax(logits, dim=-1)
-
-        return sampled_tokens
-
     def sample_and_accept_draft_tokens(
         self,
         logits: torch.Tensor,
@@ -587,7 +552,7 @@ def draft_decoder(
         draft_model: nn.Module,
     ):
         '''
-        Sampling draft tokens.
+        Sampling draft tokens with support for non-greedy sampling.
 
         Args:
             logits: torch.Tensor
@@ -658,8 +623,3 @@ def prepare_1st_drafter_inputs(
             "attn_metadata": attn_metadata,
             "spec_metadata": spec_metadata,
         }
-
-    def set_guided_decoder(self,
-                           guided_decoder: CapturableGuidedDecoder) -> bool:
-        self.guided_decoder = guided_decoder
-        return True
diff --git a/tensorrt_llm/_torch/speculative/interface.py b/tensorrt_llm/_torch/speculative/interface.py
index 9bf262b3cbc..59a5e0129cf 100644
--- a/tensorrt_llm/_torch/speculative/interface.py
+++ b/tensorrt_llm/_torch/speculative/interface.py
@@ -1,10 +1,12 @@
 import copy
 import os
+from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
 from enum import IntEnum, auto
-from typing import List, Optional, Type
+from typing import TYPE_CHECKING, List, Optional, Type
 
 import torch
+from torch import nn
 
 from tensorrt_llm.logger import logger
 
@@ -12,6 +14,9 @@
 from ..attention_backend.trtllm import AttentionBackend, TrtllmAttention
 from ..pyexecutor.resource_manager import BaseResourceManager
 
+if TYPE_CHECKING:
+    from ..pyexecutor.guided_decoder import CapturableGuidedDecoder
+
 # Environment variable name for forcing the number of accepted tokens in speculative decoding
 FORCE_NUM_ACCEPTED_TOKENS_ENV_VAR = "TLLM_SPEC_DECODE_FORCE_NUM_ACCEPTED_TOKENS"
 
@@ -351,3 +356,65 @@ def populate_sampling_params_for_one_model(
                                                     dtype=torch.float32,
                                                     pin_memory=True),
                                        non_blocking=True)
+
+
+class SpecWorkerBase(nn.Module, ABC):
+    """
+    Base class for speculative decoding workers.
+    Provides common functionality for sampling and token handling.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.guided_decoder: Optional["CapturableGuidedDecoder"] = None
+        self.force_num_accepted_tokens = get_force_num_accepted_tokens()
+
+    @property
+    @abstractmethod
+    def max_draft_len(self) -> int:
+        """
+        Returns the maximum draft length for this worker.
+        Subclasses must implement this property.
+        """
+
+    def set_guided_decoder(self,
+                           guided_decoder: "CapturableGuidedDecoder") -> bool:
+        self.guided_decoder = guided_decoder
+        return True
+
+    def _sample_tokens_for_batch(
+        self,
+        logits: torch.Tensor,
+        spec_metadata: SpecMetadata,
+        num_contexts: int,
+        batch_size: int,
+    ) -> torch.Tensor:
+        """
+        Sample tokens from logits using per-request sampling parameters.
+        Supports both greedy and non-greedy sampling.
+
+        Args:
+            logits: [num_tokens, vocab_size] - Logits to sample from
+            spec_metadata: Metadata containing sampling parameters
+            num_contexts: Number of context requests in the batch
+            batch_size: Number of requests in the batch
+
+        Returns:
+            sampled_tokens: [num_tokens] - Sampled token ids
+        """
+        if spec_metadata.allow_advanced_sampling:
+            from .one_model_sampler import sampling_batch_spec_dec_one_model
+
+            num_gens = batch_size - num_contexts
+            num_tokens = num_contexts + num_gens * (self.max_draft_len + 1)
+
+            temperatures = spec_metadata.temperatures[:num_tokens]
+            top_ks = spec_metadata.top_ks[:num_tokens]
+            top_ps = spec_metadata.top_ps[:num_tokens]
+
+            sampled_tokens = sampling_batch_spec_dec_one_model(
+                logits, temperatures, top_ks, top_ps)
+        else:
+            sampled_tokens = torch.argmax(logits, dim=-1)
+
+        return sampled_tokens
diff --git a/tensorrt_llm/_torch/speculative/mtp.py b/tensorrt_llm/_torch/speculative/mtp.py
index 29917d820a4..8d037275a04 100644
--- a/tensorrt_llm/_torch/speculative/mtp.py
+++ b/tensorrt_llm/_torch/speculative/mtp.py
@@ -3,21 +3,19 @@
 
 import torch
 import torch.nn.functional as F
-from torch import nn
 
 from tensorrt_llm.mapping import Mapping
 
 from ..attention_backend import AttentionMetadata
 from ..distributed.ops import allgather
 from ..model_config import ModelConfig
-from ..pyexecutor.guided_decoder import CapturableGuidedDecoder
 from ..pyexecutor.llm_request import LlmRequest, LlmRequestState
 from ..pyexecutor.resource_manager import BaseResourceManager, SlotManager
 from ..pyexecutor.sampler import (DEFAULT_BEAM_IDX, SampleState,
                                   SampleStateTensors, TorchSampler, add_token,
                                   int_tensor)
 from ..pyexecutor.scheduler import ScheduledRequests
-from .interface import SpecMetadata, get_force_num_accepted_tokens
+from .interface import SpecMetadata, SpecWorkerBase
 
 if TYPE_CHECKING:
     from tensorrt_llm.llmapi.llm_args import MTPDecodingConfig
@@ -349,15 +347,17 @@ def sample_async(
                            sampler_event=sampler_event)
 
 
-class MTPWorker(nn.Module):
+class MTPWorker(SpecWorkerBase):
 
     def __init__(self, spec_config: "MTPDecodingConfig", model_config=None):
        super().__init__()
         self.spec_config = spec_config
         self.model_config = model_config
         self.is_thop = False
-        self.guided_decoder: Optional[CapturableGuidedDecoder] = None
-        self.force_num_accepted_tokens = get_force_num_accepted_tokens()
+
+    @property
+    def max_draft_len(self) -> int:
+        return self.spec_config.num_nextn_predict_layers
 
     def forward(
         self,
@@ -889,8 +889,8 @@ def sample_and_accept_draft_tokens(
                 logits, spec_metadata.draft_tokens, target_tokens_cache,
                 mtp_num_modules, batch_size, num_contexts, logits.shape[-1])
         else:
-            # Do greedy sampling for the input logits
-            target_tokens = self._sample_tokens_for_batch(
+            target_tokens = self._sample_tokens_for_batch(
+                logits, spec_metadata, num_contexts, batch_size)
 
         # context
         accepted_tokens[:num_contexts, 0] = target_tokens[:num_contexts]
@@ -1173,11 +1173,6 @@ def draft_sampler(
 
         return draft_tokens
 
-    def set_guided_decoder(self,
-                           guided_decoder: CapturableGuidedDecoder) -> bool:
-        self.guided_decoder = guided_decoder
-        return True
-
 
 class MTPEagleWorker(MTPWorker):
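Illustrative usage (not part of the patch): a minimal sketch of how a new speculative-decoding worker could build on the SpecWorkerBase class introduced in interface.py above. It relies only on what this diff adds (the abstract max_draft_len property, the guided_decoder and force_num_accepted_tokens state set up in SpecWorkerBase.__init__, the set_guided_decoder hook, and the shared _sample_tokens_for_batch helper). MyOneModelWorker and MyDecodingConfig are hypothetical names used only for this example.

# Hypothetical example, not part of the patch: a worker built on SpecWorkerBase.
import torch

from tensorrt_llm._torch.speculative import SpecMetadata, SpecWorkerBase


class MyOneModelWorker(SpecWorkerBase):

    def __init__(self, spec_config: "MyDecodingConfig"):
        # SpecWorkerBase.__init__ initializes guided_decoder to None and
        # force_num_accepted_tokens from get_force_num_accepted_tokens().
        super().__init__()
        self.spec_config = spec_config

    @property
    def max_draft_len(self) -> int:
        # Required: declared as an abstract property on SpecWorkerBase.
        return self.spec_config.max_draft_len

    def sample_and_accept_draft_tokens(self, logits: torch.Tensor,
                                       spec_metadata: SpecMetadata,
                                       num_contexts: int,
                                       batch_size: int) -> torch.Tensor:
        # Shared base-class helper: greedy argmax, or per-request
        # temperature/top-k/top-p sampling when allow_advanced_sampling is set.
        return self._sample_tokens_for_batch(logits, spec_metadata,
                                             num_contexts, batch_size)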