Commit 281a63f

Implement sampling for MTP
1 parent 9ba1426 · commit 281a63f

File tree: 4 files changed, +86 −59 lines

  tensorrt_llm/_torch/speculative/__init__.py
  tensorrt_llm/_torch/speculative/eagle3.py
  tensorrt_llm/_torch/speculative/interface.py
  tensorrt_llm/_torch/speculative/mtp.py

tensorrt_llm/_torch/speculative/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -1,6 +1,6 @@
 from .auto_heuristic import suggest_spec_config
 from .eagle3 import Eagle3SpecMetadata
-from .interface import SpecMetadata
+from .interface import SpecMetadata, SpecWorkerBase
 from .mtp import MTPEagleWorker, MTPSpecMetadata, MTPWorker
 from .ngram import NGramDrafter, NGramPoolManager
 from .save_hidden_state import SaveHiddenStatesDrafter
@@ -19,6 +19,7 @@
     "NGramPoolManager",
     "SaveHiddenStatesDrafter",
     "SpecMetadata",
+    "SpecWorkerBase",
     "get_num_extra_kv_tokens",
     "get_num_spec_layers",
     "get_spec_decoder",

tensorrt_llm/_torch/speculative/eagle3.py

Lines changed: 7 additions & 46 deletions
@@ -12,9 +12,8 @@
 from ..pyexecutor.resource_manager import BaseResourceManager, SlotManager
 from ..pyexecutor.sampler import TorchSampler
 from ..pyexecutor.scheduler import ScheduledRequests
-from .interface import SpecMetadata, get_force_num_accepted_tokens
+from .interface import SpecMetadata, SpecWorkerBase
 from .mtp import MTPSampler
-from .one_model_sampler import sampling_batch_spec_dec_one_model
 from .spec_tree_manager import SpecTreeManager

 if TYPE_CHECKING:
@@ -358,15 +357,16 @@ def __init__(self, args: TorchSampler.Args):
         super().__init__(args, nextn=args.max_draft_len)


-class Eagle3OneModelWorker(nn.Module):
+class Eagle3OneModelWorker(SpecWorkerBase):

     def __init__(self, spec_config: "EagleDecodingConfig", mapping: Mapping):
         super().__init__()
         self.spec_config = spec_config
-        self.max_draft_len = self.spec_config.max_draft_len
         self.mapping = mapping
-        self.guided_decoder: Optional[CapturableGuidedDecoder] = None
-        self.force_num_accepted_tokens = get_force_num_accepted_tokens()
+
+    @property
+    def max_draft_len(self) -> int:
+        return self.spec_config.max_draft_len

     # Skip torch.compile for now since current Torch is not compatible with Triton 3.4
     # @torch.compile(options={"max-autotune": True})
@@ -494,40 +494,6 @@ def forward(self, input_ids, position_ids, hidden_states, logits,
             'next_new_tokens': next_new_tokens,
         }

-    def _sample_tokens_for_batch(
-        self,
-        logits: torch.Tensor,
-        spec_metadata: Eagle3OneModelSpecMetadata,
-        num_contexts: int,
-        batch_size: int,
-    ) -> torch.Tensor:
-        """
-        Sample tokens from logits using per-request sampling parameters.
-        Supports both greedy and non-greedy sampling.
-
-        Args:
-            logits: [num_tokens, vocab_size] - Logits to sample from
-            spec_metadata: Metadata containing sampling parameters
-            batch_size: Number of requests in the batch
-
-        Returns:
-            sampled_tokens: [num_tokens] - Sampled token ids
-        """
-        if spec_metadata.allow_advanced_sampling:
-            num_gens = batch_size - num_contexts
-            num_tokens = num_contexts + num_gens * (self.max_draft_len + 1)
-
-            temperatures = spec_metadata.temperatures[:num_tokens]
-            top_ks = spec_metadata.top_ks[:num_tokens]
-            top_ps = spec_metadata.top_ps[:num_tokens]
-
-            sampled_tokens = sampling_batch_spec_dec_one_model(
-                logits, temperatures, top_ks, top_ps)
-        else:
-            sampled_tokens = torch.argmax(logits, dim=-1)
-
-        return sampled_tokens
-
     def sample_and_accept_draft_tokens(
         self,
         logits: torch.Tensor,
@@ -578,7 +544,7 @@ def draft_decoder(
         draft_model: nn.Module,
     ):
         '''
-        Sampling draft tokens.
+        Sampling draft tokens with support for non-greedy sampling.

         Args:
             logits: torch.Tensor
@@ -649,8 +615,3 @@ def prepare_1st_drafter_inputs(
             "attn_metadata": attn_metadata,
             "spec_metadata": spec_metadata,
         }
-
-    def set_guided_decoder(self,
-                           guided_decoder: CapturableGuidedDecoder) -> bool:
-        self.guided_decoder = guided_decoder
-        return True
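
The per-request sampling logic removed from this file (and re-homed in SpecWorkerBase below) assumes a specific logits layout for the one-model worker: each context request contributes one row, while each generation request contributes max_draft_len + 1 rows (the target position plus the draft positions). A small worked illustration of the num_tokens formula, using hypothetical values that are not from the diff:

# Hypothetical batch composition, only to illustrate the num_tokens
# formula used by _sample_tokens_for_batch.
num_contexts = 2
num_gens = 3
max_draft_len = 4

num_tokens = num_contexts + num_gens * (max_draft_len + 1)
# 2 context rows + 3 * 5 generation rows = 17 rows in the
# [num_tokens, vocab_size] logits tensor that gets sampled.
assert num_tokens == 17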

tensorrt_llm/_torch/speculative/interface.py

Lines changed: 69 additions & 1 deletion
@@ -2,16 +2,21 @@
 import os
 from dataclasses import dataclass, field
 from enum import IntEnum, auto
-from typing import List, Optional, Type
+from typing import TYPE_CHECKING, List, Optional, Type
+from abc import ABC, abstractmethod

 import torch
+from torch import nn

 from tensorrt_llm.logger import logger

 from ..._utils import get_sm_version
 from ..attention_backend.trtllm import AttentionBackend, TrtllmAttention
 from ..pyexecutor.resource_manager import BaseResourceManager

+if TYPE_CHECKING:
+    from ..pyexecutor.guided_decoder import CapturableGuidedDecoder
+
 # Environment variable name for forcing the number of accepted tokens in speculative decoding
 FORCE_NUM_ACCEPTED_TOKENS_ENV_VAR = "TLLM_SPEC_DECODE_FORCE_NUM_ACCEPTED_TOKENS"

@@ -351,3 +356,66 @@ def populate_sampling_params_for_one_model(
                              dtype=torch.float32,
                              pin_memory=True),
             non_blocking=True)
+
+
+class SpecWorkerBase(nn.Module, ABC):
+    """
+    Base class for speculative decoding workers.
+    Provides common functionality for sampling and token handling.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.guided_decoder: Optional["CapturableGuidedDecoder"] = None
+        self.force_num_accepted_tokens = get_force_num_accepted_tokens()
+
+    @property
+    @abstractmethod
+    def max_draft_len(self) -> int:
+        """
+        Returns the maximum draft length for this worker.
+        Subclasses should override this property.
+        """
+        pass
+
+    def set_guided_decoder(
+            self, guided_decoder: "CapturableGuidedDecoder") -> bool:
+        self.guided_decoder = guided_decoder
+        return True
+
+    def _sample_tokens_for_batch(
+        self,
+        logits: torch.Tensor,
+        spec_metadata: SpecMetadata,
+        num_contexts: int,
+        batch_size: int,
+    ) -> torch.Tensor:
+        """
+        Sample tokens from logits using per-request sampling parameters.
+        Supports both greedy and non-greedy sampling.
+
+        Args:
+            logits: [num_tokens, vocab_size] - Logits to sample from
+            spec_metadata: Metadata containing sampling parameters
+            num_contexts: Number of context requests in the batch
+            batch_size: Number of requests in the batch
+
+        Returns:
+            sampled_tokens: [num_tokens] - Sampled token ids
+        """
+        if spec_metadata.allow_advanced_sampling:
+            from .one_model_sampler import sampling_batch_spec_dec_one_model
+
+            num_gens = batch_size - num_contexts
+            num_tokens = num_contexts + num_gens * (self.max_draft_len + 1)
+
+            temperatures = spec_metadata.temperatures[:num_tokens]
+            top_ks = spec_metadata.top_ks[:num_tokens]
+            top_ps = spec_metadata.top_ps[:num_tokens]
+
+            sampled_tokens = sampling_batch_spec_dec_one_model(
+                logits, temperatures, top_ks, top_ps)
+        else:
+            sampled_tokens = torch.argmax(logits, dim=-1)
+
+        return sampled_tokens
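
SpecWorkerBase centralizes the state and helpers that the workers previously duplicated: the guided-decoder slot, the forced-acceptance override, set_guided_decoder, and _sample_tokens_for_batch. A minimal sketch of the subclass contract, assuming only what the diff shows (the import path follows the __init__.py export above); the class and field names in the sketch are hypothetical:

from tensorrt_llm._torch.speculative import SpecWorkerBase  # exported above

class ToyWorker(SpecWorkerBase):
    """Hypothetical worker used only to illustrate the contract."""

    def __init__(self, draft_len: int):
        super().__init__()  # initializes guided_decoder and force_num_accepted_tokens
        self._draft_len = draft_len  # hypothetical config field

    @property
    def max_draft_len(self) -> int:
        # The only abstract member: report the draft length from config.
        return self._draft_len

worker = ToyWorker(draft_len=3)
print(worker.max_draft_len)  # 3; sampling and guided-decoder wiring come from the base class.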

tensorrt_llm/_torch/speculative/mtp.py

Lines changed: 8 additions & 11 deletions
@@ -17,7 +17,7 @@
                                    SampleStateTensors, TorchSampler, add_token,
                                    int_tensor)
 from ..pyexecutor.scheduler import ScheduledRequests
-from .interface import SpecMetadata, get_force_num_accepted_tokens
+from .interface import SpecMetadata, SpecWorkerBase

 if TYPE_CHECKING:
     from tensorrt_llm.llmapi.llm_args import MTPDecodingConfig
@@ -349,15 +349,17 @@ def sample_async(
             sampler_event=sampler_event)


-class MTPWorker(nn.Module):
+class MTPWorker(SpecWorkerBase):

     def __init__(self, spec_config: "MTPDecodingConfig", model_config=None):
         super().__init__()
         self.spec_config = spec_config
         self.model_config = model_config
         self.is_thop = False
-        self.guided_decoder: Optional[CapturableGuidedDecoder] = None
-        self.force_num_accepted_tokens = get_force_num_accepted_tokens()
+
+    @property
+    def max_draft_len(self) -> int:
+        return self.spec_config.num_nextn_predict_layers

     def forward(
         self,
@@ -889,8 +891,8 @@ def sample_and_accept_draft_tokens(
                 logits, spec_metadata.draft_tokens, target_tokens_cache,
                 mtp_num_modules, batch_size, num_contexts, logits.shape[-1])
         else:
-            # Do greedy sampling for the input logits
-            target_tokens = torch.argmax(logits, dim=-1)
+            target_tokens = self._sample_tokens_for_batch(
+                logits, spec_metadata, num_contexts, batch_size)

         # context
         accepted_tokens[:num_contexts, 0] = target_tokens[:num_contexts]
@@ -1173,11 +1175,6 @@ def draft_sampler(

         return draft_tokens

-    def set_guided_decoder(self,
-                           guided_decoder: CapturableGuidedDecoder) -> bool:
-        self.guided_decoder = guided_decoder
-        return True
-


 class MTPEagleWorker(MTPWorker):
