3 changes: 2 additions & 1 deletion tensorrt_llm/_torch/speculative/__init__.py
@@ -1,6 +1,6 @@
from .auto_heuristic import suggest_spec_config
from .eagle3 import Eagle3SpecMetadata
from .interface import SpecMetadata
from .interface import SpecMetadata, SpecWorkerBase
from .mtp import MTPEagleWorker, MTPSpecMetadata, MTPWorker
from .ngram import NGramDrafter, NGramPoolManager
from .save_hidden_state import SaveHiddenStatesDrafter
@@ -19,6 +19,7 @@
"NGramPoolManager",
"SaveHiddenStatesDrafter",
"SpecMetadata",
"SpecWorkerBase",
"get_num_extra_kv_tokens",
"get_num_spec_layers",
"get_spec_decoder",
54 changes: 7 additions & 47 deletions tensorrt_llm/_torch/speculative/eagle3.py
@@ -7,14 +7,12 @@
from tensorrt_llm.mapping import Mapping

from ..attention_backend import AttentionMetadata
from ..pyexecutor.guided_decoder import CapturableGuidedDecoder
from ..pyexecutor.llm_request import LlmRequest
from ..pyexecutor.resource_manager import BaseResourceManager, SlotManager
from ..pyexecutor.sampler import TorchSampler
from ..pyexecutor.scheduler import ScheduledRequests
from .interface import SpecMetadata, get_force_num_accepted_tokens
from .interface import SpecMetadata, SpecWorkerBase
from .mtp import MTPSampler
from .one_model_sampler import sampling_batch_spec_dec_one_model
from .spec_tree_manager import SpecTreeManager

if TYPE_CHECKING:
@@ -358,15 +356,16 @@ def __init__(self, args: TorchSampler.Args):
super().__init__(args, nextn=args.max_draft_len)


class Eagle3OneModelWorker(nn.Module):
class Eagle3OneModelWorker(SpecWorkerBase):

def __init__(self, spec_config: "EagleDecodingConfig", mapping: Mapping):
super().__init__()
self.spec_config = spec_config
self.max_draft_len = self.spec_config.max_draft_len
self.mapping = mapping
self.guided_decoder: Optional[CapturableGuidedDecoder] = None
self.force_num_accepted_tokens = get_force_num_accepted_tokens()

@property
def max_draft_len(self) -> int:
return self.spec_config.max_draft_len

# Skip torch.compile for now since current Torch is not compatible with Triton 3.4
# @torch.compile(options={"max-autotune": True})
@@ -503,40 +502,6 @@ def forward(self, input_ids, position_ids, hidden_states, logits,
'next_new_tokens': next_new_tokens,
}

def _sample_tokens_for_batch(
self,
logits: torch.Tensor,
spec_metadata: Eagle3OneModelSpecMetadata,
num_contexts: int,
batch_size: int,
) -> torch.Tensor:
"""
Sample tokens from logits using per-request sampling parameters.
Supports both greedy and non-greedy sampling.

Args:
logits: [num_tokens, vocab_size] - Logits to sample from
spec_metadata: Metadata containing sampling parameters
batch_size: Number of requests in the batch

Returns:
sampled_tokens: [num_tokens] - Sampled token ids
"""
if spec_metadata.allow_advanced_sampling:
num_gens = batch_size - num_contexts
num_tokens = num_contexts + num_gens * (self.max_draft_len + 1)

temperatures = spec_metadata.temperatures[:num_tokens]
top_ks = spec_metadata.top_ks[:num_tokens]
top_ps = spec_metadata.top_ps[:num_tokens]

sampled_tokens = sampling_batch_spec_dec_one_model(
logits, temperatures, top_ks, top_ps)
else:
sampled_tokens = torch.argmax(logits, dim=-1)

return sampled_tokens

def sample_and_accept_draft_tokens(
self,
logits: torch.Tensor,
@@ -587,7 +552,7 @@ def draft_decoder(
draft_model: nn.Module,
):
'''
Sampling draft tokens.
Sampling draft tokens with support for non-greedy sampling.

Args:
logits: torch.Tensor
@@ -658,8 +623,3 @@ def prepare_1st_drafter_inputs(
"attn_metadata": attn_metadata,
"spec_metadata": spec_metadata,
}

def set_guided_decoder(self,
guided_decoder: CapturableGuidedDecoder) -> bool:
self.guided_decoder = guided_decoder
return True
69 changes: 68 additions & 1 deletion tensorrt_llm/_torch/speculative/interface.py
@@ -1,17 +1,22 @@
import copy
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import IntEnum, auto
from typing import List, Optional, Type
from typing import TYPE_CHECKING, List, Optional, Type

import torch
from torch import nn

from tensorrt_llm.logger import logger

from ..._utils import get_sm_version
from ..attention_backend.trtllm import AttentionBackend, TrtllmAttention
from ..pyexecutor.resource_manager import BaseResourceManager

if TYPE_CHECKING:
from ..pyexecutor.guided_decoder import CapturableGuidedDecoder

# Environment variable name for forcing the number of accepted tokens in speculative decoding
FORCE_NUM_ACCEPTED_TOKENS_ENV_VAR = "TLLM_SPEC_DECODE_FORCE_NUM_ACCEPTED_TOKENS"

@@ -351,3 +356,65 @@ def populate_sampling_params_for_one_model(
dtype=torch.float32,
pin_memory=True),
non_blocking=True)


class SpecWorkerBase(nn.Module, ABC):
"""
Base class for speculative decoding workers.
Provides common functionality for sampling and token handling.
"""

def __init__(self):
super().__init__()
self.guided_decoder: Optional["CapturableGuidedDecoder"] = None
self.force_num_accepted_tokens = get_force_num_accepted_tokens()

@property
@abstractmethod
def max_draft_len(self) -> int:
"""
Returns the maximum draft length for this worker.
Subclasses should override this property.
"""

def set_guided_decoder(self,
guided_decoder: "CapturableGuidedDecoder") -> bool:
self.guided_decoder = guided_decoder
return True

def _sample_tokens_for_batch(
self,
logits: torch.Tensor,
spec_metadata: SpecMetadata,
num_contexts: int,
batch_size: int,
) -> torch.Tensor:
"""
Sample tokens from logits using per-request sampling parameters.
Supports both greedy and non-greedy sampling.

Args:
logits: [num_tokens, vocab_size] - Logits to sample from
spec_metadata: Metadata containing sampling parameters
num_contexts: Number of context requests in the batch
batch_size: Number of requests in the batch

Returns:
sampled_tokens: [num_tokens] - Sampled token ids
"""
if spec_metadata.allow_advanced_sampling:
from .one_model_sampler import sampling_batch_spec_dec_one_model

num_gens = batch_size - num_contexts
num_tokens = num_contexts + num_gens * (self.max_draft_len + 1)

temperatures = spec_metadata.temperatures[:num_tokens]
top_ks = spec_metadata.top_ks[:num_tokens]
top_ps = spec_metadata.top_ps[:num_tokens]

sampled_tokens = sampling_batch_spec_dec_one_model(
logits, temperatures, top_ks, top_ps)
else:
sampled_tokens = torch.argmax(logits, dim=-1)

return sampled_tokens
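
For orientation only (this snippet is not part of the diff): a minimal sketch of how a concrete worker might build on the new SpecWorkerBase, assuming a hypothetical ToyOneModelWorker and a spec config object that exposes a max_draft_len field.

# Hypothetical example, not from this PR: a worker built on SpecWorkerBase.
import torch
from tensorrt_llm._torch.speculative import SpecMetadata, SpecWorkerBase

class ToyOneModelWorker(SpecWorkerBase):

    def __init__(self, spec_config):
        super().__init__()  # sets guided_decoder to None and reads the force-accept env override
        self.spec_config = spec_config

    @property
    def max_draft_len(self) -> int:
        # Concrete workers satisfy the abstract property by delegating to their config.
        return self.spec_config.max_draft_len

    def forward(self, logits: torch.Tensor, spec_metadata: SpecMetadata,
                num_contexts: int, batch_size: int) -> torch.Tensor:
        # The shared helper applies per-request temperature/top-k/top-p when
        # spec_metadata.allow_advanced_sampling is set, and falls back to argmax otherwise.
        return self._sample_tokens_for_batch(logits, spec_metadata, num_contexts,
                                             batch_size)
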
21 changes: 8 additions & 13 deletions tensorrt_llm/_torch/speculative/mtp.py
@@ -3,21 +3,19 @@

import torch
import torch.nn.functional as F
from torch import nn

from tensorrt_llm.mapping import Mapping

from ..attention_backend import AttentionMetadata
from ..distributed.ops import allgather
from ..model_config import ModelConfig
from ..pyexecutor.guided_decoder import CapturableGuidedDecoder
from ..pyexecutor.llm_request import LlmRequest, LlmRequestState
from ..pyexecutor.resource_manager import BaseResourceManager, SlotManager
from ..pyexecutor.sampler import (DEFAULT_BEAM_IDX, SampleState,
SampleStateTensors, TorchSampler, add_token,
int_tensor)
from ..pyexecutor.scheduler import ScheduledRequests
from .interface import SpecMetadata, get_force_num_accepted_tokens
from .interface import SpecMetadata, SpecWorkerBase

if TYPE_CHECKING:
from tensorrt_llm.llmapi.llm_args import MTPDecodingConfig
@@ -349,15 +347,17 @@ def sample_async(
sampler_event=sampler_event)


class MTPWorker(nn.Module):
class MTPWorker(SpecWorkerBase):

def __init__(self, spec_config: "MTPDecodingConfig", model_config=None):
super().__init__()
self.spec_config = spec_config
self.model_config = model_config
self.is_thop = False
self.guided_decoder: Optional[CapturableGuidedDecoder] = None
self.force_num_accepted_tokens = get_force_num_accepted_tokens()

@property
def max_draft_len(self) -> int:
return self.spec_config.num_nextn_predict_layers

def forward(
self,
@@ -889,8 +889,8 @@ def sample_and_accept_draft_tokens(
logits, spec_metadata.draft_tokens, target_tokens_cache,
mtp_num_modules, batch_size, num_contexts, logits.shape[-1])
else:
# Do greedy sampling for the input logits
target_tokens = torch.argmax(logits, dim=-1)
target_tokens = self._sample_tokens_for_batch(
logits, spec_metadata, num_contexts, batch_size)

# context
accepted_tokens[:num_contexts, 0] = target_tokens[:num_contexts]
@@ -1173,11 +1173,6 @@ def draft_sampler(

return draft_tokens

def set_guided_decoder(self,
guided_decoder: CapturableGuidedDecoder) -> bool:
self.guided_decoder = guided_decoder
return True


class MTPEagleWorker(MTPWorker):

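
One closing illustration (again not part of the diff): the arithmetic inside _sample_tokens_for_batch implies that each context request contributes one row of logits while each generation request contributes max_draft_len + 1 rows; plugging in illustrative numbers:

# Illustrative numbers only: 2 context requests, 3 generation requests, max_draft_len = 4.
num_contexts, num_gens, max_draft_len = 2, 3, 4
num_tokens = num_contexts + num_gens * (max_draft_len + 1)  # 2 + 3 * 5 = 17
# temperatures, top_ks and top_ps are sliced to the first 17 entries, matching the
# first 17 rows of the flattened logits tensor passed to sampling_batch_spec_dec_one_model.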