
Commit a026a9e

Implement sampling for MTP
Signed-off-by: Mike Iovine <[email protected]>
1 parent 6732c76 commit a026a9e

4 files changed: +85 -62 lines changed


tensorrt_llm/_torch/speculative/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -1,6 +1,6 @@
 from .auto_heuristic import suggest_spec_config
 from .eagle3 import Eagle3SpecMetadata
-from .interface import SpecMetadata
+from .interface import SpecMetadata, SpecWorkerBase
 from .mtp import MTPEagleWorker, MTPSpecMetadata, MTPWorker
 from .ngram import NGramDrafter, NGramPoolManager
 from .save_hidden_state import SaveHiddenStatesDrafter
@@ -19,6 +19,7 @@
     "NGramPoolManager",
     "SaveHiddenStatesDrafter",
     "SpecMetadata",
+    "SpecWorkerBase",
     "get_num_extra_kv_tokens",
     "get_num_spec_layers",
     "get_spec_decoder",

tensorrt_llm/_torch/speculative/eagle3.py

Lines changed: 7 additions & 47 deletions
@@ -7,14 +7,12 @@
 from tensorrt_llm.mapping import Mapping
 
 from ..attention_backend import AttentionMetadata
-from ..pyexecutor.guided_decoder import CapturableGuidedDecoder
 from ..pyexecutor.llm_request import LlmRequest
 from ..pyexecutor.resource_manager import BaseResourceManager, SlotManager
 from ..pyexecutor.sampler import TorchSampler
 from ..pyexecutor.scheduler import ScheduledRequests
-from .interface import SpecMetadata, get_force_num_accepted_tokens
+from .interface import SpecMetadata, SpecWorkerBase
 from .mtp import MTPSampler
-from .one_model_sampler import sampling_batch_spec_dec_one_model
 from .spec_tree_manager import SpecTreeManager
 
 if TYPE_CHECKING:
@@ -358,15 +356,16 @@ def __init__(self, args: TorchSampler.Args):
         super().__init__(args, nextn=args.max_draft_len)
 
 
-class Eagle3OneModelWorker(nn.Module):
+class Eagle3OneModelWorker(SpecWorkerBase):
 
     def __init__(self, spec_config: "EagleDecodingConfig", mapping: Mapping):
         super().__init__()
         self.spec_config = spec_config
-        self.max_draft_len = self.spec_config.max_draft_len
         self.mapping = mapping
-        self.guided_decoder: Optional[CapturableGuidedDecoder] = None
-        self.force_num_accepted_tokens = get_force_num_accepted_tokens()
+
+    @property
+    def max_draft_len(self) -> int:
+        return self.spec_config.max_draft_len
 
     # Skip torch.compile for now since current Torch is not compatible with Triton 3.4
     # @torch.compile(options={"max-autotune": True})
@@ -503,40 +502,6 @@ def forward(self, input_ids, position_ids, hidden_states, logits,
             'next_new_tokens': next_new_tokens,
         }
 
-    def _sample_tokens_for_batch(
-        self,
-        logits: torch.Tensor,
-        spec_metadata: Eagle3OneModelSpecMetadata,
-        num_contexts: int,
-        batch_size: int,
-    ) -> torch.Tensor:
-        """
-        Sample tokens from logits using per-request sampling parameters.
-        Supports both greedy and non-greedy sampling.
-
-        Args:
-            logits: [num_tokens, vocab_size] - Logits to sample from
-            spec_metadata: Metadata containing sampling parameters
-            batch_size: Number of requests in the batch
-
-        Returns:
-            sampled_tokens: [num_tokens] - Sampled token ids
-        """
-        if spec_metadata.allow_advanced_sampling:
-            num_gens = batch_size - num_contexts
-            num_tokens = num_contexts + num_gens * (self.max_draft_len + 1)
-
-            temperatures = spec_metadata.temperatures[:num_tokens]
-            top_ks = spec_metadata.top_ks[:num_tokens]
-            top_ps = spec_metadata.top_ps[:num_tokens]
-
-            sampled_tokens = sampling_batch_spec_dec_one_model(
-                logits, temperatures, top_ks, top_ps)
-        else:
-            sampled_tokens = torch.argmax(logits, dim=-1)
-
-        return sampled_tokens
-
     def sample_and_accept_draft_tokens(
         self,
         logits: torch.Tensor,
@@ -587,7 +552,7 @@ def draft_decoder(
         draft_model: nn.Module,
     ):
         '''
-        Sampling draft tokens.
+        Sampling draft tokens with support for non-greedy sampling.
 
         Args:
             logits: torch.Tensor
@@ -658,8 +623,3 @@ def prepare_1st_drafter_inputs(
             "attn_metadata": attn_metadata,
             "spec_metadata": spec_metadata,
         }
-
-    def set_guided_decoder(self,
-                           guided_decoder: CapturableGuidedDecoder) -> bool:
-        self.guided_decoder = guided_decoder
-        return True

tensorrt_llm/_torch/speculative/interface.py

Lines changed: 68 additions & 1 deletion
@@ -1,17 +1,22 @@
 import copy
 import os
+from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
 from enum import IntEnum, auto
-from typing import List, Optional, Type
+from typing import TYPE_CHECKING, List, Optional, Type
 
 import torch
+from torch import nn
 
 from tensorrt_llm.logger import logger
 
 from ..._utils import get_sm_version
 from ..attention_backend.trtllm import AttentionBackend, TrtllmAttention
 from ..pyexecutor.resource_manager import BaseResourceManager
 
+if TYPE_CHECKING:
+    from ..pyexecutor.guided_decoder import CapturableGuidedDecoder
+
 # Environment variable name for forcing the number of accepted tokens in speculative decoding
 FORCE_NUM_ACCEPTED_TOKENS_ENV_VAR = "TLLM_SPEC_DECODE_FORCE_NUM_ACCEPTED_TOKENS"
 
@@ -351,3 +356,65 @@ def populate_sampling_params_for_one_model(
                           dtype=torch.float32,
                           pin_memory=True),
             non_blocking=True)
+
+
+class SpecWorkerBase(nn.Module, ABC):
+    """
+    Base class for speculative decoding workers.
+    Provides common functionality for sampling and token handling.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.guided_decoder: Optional["CapturableGuidedDecoder"] = None
+        self.force_num_accepted_tokens = get_force_num_accepted_tokens()
+
+    @property
+    @abstractmethod
+    def max_draft_len(self) -> int:
+        """
+        Returns the maximum draft length for this worker.
+        Subclasses should override this property.
+        """
+
+    def set_guided_decoder(self,
+                           guided_decoder: "CapturableGuidedDecoder") -> bool:
+        self.guided_decoder = guided_decoder
+        return True
+
+    def _sample_tokens_for_batch(
+        self,
+        logits: torch.Tensor,
+        spec_metadata: SpecMetadata,
+        num_contexts: int,
+        batch_size: int,
+    ) -> torch.Tensor:
+        """
+        Sample tokens from logits using per-request sampling parameters.
+        Supports both greedy and non-greedy sampling.
+
+        Args:
+            logits: [num_tokens, vocab_size] - Logits to sample from
+            spec_metadata: Metadata containing sampling parameters
+            num_contexts: Number of context requests in the batch
+            batch_size: Number of requests in the batch
+
+        Returns:
+            sampled_tokens: [num_tokens] - Sampled token ids
+        """
+        if spec_metadata.allow_advanced_sampling:
+            from .one_model_sampler import sampling_batch_spec_dec_one_model
+
+            num_gens = batch_size - num_contexts
+            num_tokens = num_contexts + num_gens * (self.max_draft_len + 1)
+
+            temperatures = spec_metadata.temperatures[:num_tokens]
+            top_ks = spec_metadata.top_ks[:num_tokens]
+            top_ps = spec_metadata.top_ps[:num_tokens]
+
+            sampled_tokens = sampling_batch_spec_dec_one_model(
+                logits, temperatures, top_ks, top_ps)
+        else:
+            sampled_tokens = torch.argmax(logits, dim=-1)
+
+        return sampled_tokens
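When allow_advanced_sampling is set, _sample_tokens_for_batch delegates to sampling_batch_spec_dec_one_model, whose implementation is not part of this commit. As a rough reference for the intended per-token semantics only, the sketch below applies temperature, top-k, and top-p filtering independently per row; it is an assumption-level illustration, not the actual one_model_sampler kernel, and the top_k <= 0 "disabled" convention and zero-temperature-as-greedy guard are likewise assumptions:

import torch


def sample_per_token(logits: torch.Tensor, temperatures: torch.Tensor,
                     top_ks: torch.Tensor, top_ps: torch.Tensor) -> torch.Tensor:
    """Illustrative sampling: logits is [num_tokens, vocab_size], the rest are [num_tokens]."""
    vocab_size = logits.shape[-1]
    # Treat temperature <= 0 as greedy; the clamp avoids division by zero below.
    greedy = temperatures <= 0
    scaled = logits / temperatures.clamp_min(1e-5).unsqueeze(-1)

    # Top-k: mask logits below each row's k-th largest value (k <= 0 disables the filter).
    effective_k = torch.where(top_ks > 0, top_ks, torch.full_like(top_ks, vocab_size))
    kth_idx = (effective_k.clamp(max=vocab_size) - 1).long().unsqueeze(-1)
    sorted_desc, sorted_idx = torch.sort(scaled, dim=-1, descending=True)
    kth_value = sorted_desc.gather(-1, kth_idx)
    scaled = scaled.masked_fill(scaled < kth_value, float("-inf"))

    # Top-p: drop the low-probability tail once cumulative mass exceeds p (top-1 always kept).
    sorted_desc, sorted_idx = torch.sort(scaled, dim=-1, descending=True)
    sorted_probs = torch.softmax(sorted_desc, dim=-1)
    tail = sorted_probs.cumsum(dim=-1) - sorted_probs > top_ps.unsqueeze(-1)
    sorted_desc = sorted_desc.masked_fill(tail, float("-inf"))
    scaled = torch.full_like(scaled, float("-inf")).scatter(-1, sorted_idx, sorted_desc)

    sampled = torch.multinomial(torch.softmax(scaled, dim=-1), num_samples=1).squeeze(-1)
    return torch.where(greedy, logits.argmax(dim=-1), sampled)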

tensorrt_llm/_torch/speculative/mtp.py

Lines changed: 8 additions & 13 deletions
@@ -3,21 +3,19 @@
 
 import torch
 import torch.nn.functional as F
-from torch import nn
 
 from tensorrt_llm.mapping import Mapping
 
 from ..attention_backend import AttentionMetadata
 from ..distributed.ops import allgather
 from ..model_config import ModelConfig
-from ..pyexecutor.guided_decoder import CapturableGuidedDecoder
 from ..pyexecutor.llm_request import LlmRequest, LlmRequestState
 from ..pyexecutor.resource_manager import BaseResourceManager, SlotManager
 from ..pyexecutor.sampler import (DEFAULT_BEAM_IDX, SampleState,
                                   SampleStateTensors, TorchSampler, add_token,
                                   int_tensor)
 from ..pyexecutor.scheduler import ScheduledRequests
-from .interface import SpecMetadata, get_force_num_accepted_tokens
+from .interface import SpecMetadata, SpecWorkerBase
 
 if TYPE_CHECKING:
     from tensorrt_llm.llmapi.llm_args import MTPDecodingConfig
@@ -349,15 +347,17 @@ def sample_async(
             sampler_event=sampler_event)
 
 
-class MTPWorker(nn.Module):
+class MTPWorker(SpecWorkerBase):
 
     def __init__(self, spec_config: "MTPDecodingConfig", model_config=None):
        super().__init__()
        self.spec_config = spec_config
        self.model_config = model_config
        self.is_thop = False
-        self.guided_decoder: Optional[CapturableGuidedDecoder] = None
-        self.force_num_accepted_tokens = get_force_num_accepted_tokens()
+
+    @property
+    def max_draft_len(self) -> int:
+        return self.spec_config.num_nextn_predict_layers
 
     def forward(
         self,
@@ -889,8 +889,8 @@ def sample_and_accept_draft_tokens(
                 logits, spec_metadata.draft_tokens, target_tokens_cache,
                 mtp_num_modules, batch_size, num_contexts, logits.shape[-1])
         else:
-            # Do greedy sampling for the input logits
-            target_tokens = torch.argmax(logits, dim=-1)
+            target_tokens = self._sample_tokens_for_batch(
+                logits, spec_metadata, num_contexts, batch_size)
 
         # context
         accepted_tokens[:num_contexts, 0] = target_tokens[:num_contexts]
@@ -1173,11 +1173,6 @@ def draft_sampler(
 
         return draft_tokens
 
-    def set_guided_decoder(self,
-                           guided_decoder: CapturableGuidedDecoder) -> bool:
-        self.guided_decoder = guided_decoder
-        return True
-
 
 class MTPEagleWorker(MTPWorker):
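For MTPWorker, the inherited _sample_tokens_for_batch slices the per-request sampling parameters to num_contexts + num_gens * (max_draft_len + 1) rows, where max_draft_len now resolves to num_nextn_predict_layers; with advanced sampling disabled it falls back to the same argmax this hunk removes. A quick arithmetic illustration of that token layout (the concrete numbers are made up for the example):

# Hypothetical batch: 2 context requests, 3 generation requests, 4 MTP modules.
num_contexts = 2
num_gens = 3
max_draft_len = 4  # spec_config.num_nextn_predict_layers for MTPWorker

# Context requests contribute one logit row each; each generation request
# contributes a row for the target token plus one per draft position.
num_tokens = num_contexts + num_gens * (max_draft_len + 1)
assert num_tokens == 17  # temperatures/top_ks/top_ps are sliced to this length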
