Commit dabe1c4

Implement sampling for MTP
Signed-off-by: Mike Iovine <[email protected]>
1 parent 9ba1426 commit dabe1c4

5 files changed: +93 −65 lines changed


tensorrt_llm/_torch/speculative/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -1,6 +1,6 @@
 from .auto_heuristic import suggest_spec_config
 from .eagle3 import Eagle3SpecMetadata
-from .interface import SpecMetadata
+from .interface import SpecMetadata, SpecWorkerBase
 from .mtp import MTPEagleWorker, MTPSpecMetadata, MTPWorker
 from .ngram import NGramDrafter, NGramPoolManager
 from .save_hidden_state import SaveHiddenStatesDrafter
@@ -19,6 +19,7 @@
     "NGramPoolManager",
     "SaveHiddenStatesDrafter",
     "SpecMetadata",
+    "SpecWorkerBase",
     "get_num_extra_kv_tokens",
     "get_num_spec_layers",
     "get_spec_decoder",

tensorrt_llm/_torch/speculative/eagle3.py

Lines changed: 7 additions & 47 deletions
@@ -7,14 +7,12 @@
 from tensorrt_llm.mapping import Mapping
 
 from ..attention_backend import AttentionMetadata
-from ..pyexecutor.guided_decoder import CapturableGuidedDecoder
 from ..pyexecutor.llm_request import LlmRequest
 from ..pyexecutor.resource_manager import BaseResourceManager, SlotManager
 from ..pyexecutor.sampler import TorchSampler
 from ..pyexecutor.scheduler import ScheduledRequests
-from .interface import SpecMetadata, get_force_num_accepted_tokens
+from .interface import SpecMetadata, SpecWorkerBase
 from .mtp import MTPSampler
-from .one_model_sampler import sampling_batch_spec_dec_one_model
 from .spec_tree_manager import SpecTreeManager
 
 if TYPE_CHECKING:
@@ -358,15 +356,16 @@ def __init__(self, args: TorchSampler.Args):
         super().__init__(args, nextn=args.max_draft_len)
 
 
-class Eagle3OneModelWorker(nn.Module):
+class Eagle3OneModelWorker(SpecWorkerBase):
 
     def __init__(self, spec_config: "EagleDecodingConfig", mapping: Mapping):
         super().__init__()
         self.spec_config = spec_config
-        self.max_draft_len = self.spec_config.max_draft_len
         self.mapping = mapping
-        self.guided_decoder: Optional[CapturableGuidedDecoder] = None
-        self.force_num_accepted_tokens = get_force_num_accepted_tokens()
+
+    @property
+    def max_draft_len(self) -> int:
+        return self.spec_config.max_draft_len
 
     # Skip torch.compile for now since current Torch is not compatible with Triton 3.4
     # @torch.compile(options={"max-autotune": True})
@@ -494,40 +493,6 @@ def forward(self, input_ids, position_ids, hidden_states, logits,
             'next_new_tokens': next_new_tokens,
         }
 
-    def _sample_tokens_for_batch(
-        self,
-        logits: torch.Tensor,
-        spec_metadata: Eagle3OneModelSpecMetadata,
-        num_contexts: int,
-        batch_size: int,
-    ) -> torch.Tensor:
-        """
-        Sample tokens from logits using per-request sampling parameters.
-        Supports both greedy and non-greedy sampling.
-
-        Args:
-            logits: [num_tokens, vocab_size] - Logits to sample from
-            spec_metadata: Metadata containing sampling parameters
-            batch_size: Number of requests in the batch
-
-        Returns:
-            sampled_tokens: [num_tokens] - Sampled token ids
-        """
-        if spec_metadata.allow_advanced_sampling:
-            num_gens = batch_size - num_contexts
-            num_tokens = num_contexts + num_gens * (self.max_draft_len + 1)
-
-            temperatures = spec_metadata.temperatures[:num_tokens]
-            top_ks = spec_metadata.top_ks[:num_tokens]
-            top_ps = spec_metadata.top_ps[:num_tokens]
-
-            sampled_tokens = sampling_batch_spec_dec_one_model(
-                logits, temperatures, top_ks, top_ps)
-        else:
-            sampled_tokens = torch.argmax(logits, dim=-1)
-
-        return sampled_tokens
-
     def sample_and_accept_draft_tokens(
         self,
         logits: torch.Tensor,
@@ -578,7 +543,7 @@ def draft_decoder(
         draft_model: nn.Module,
     ):
         '''
-        Sampling draft tokens.
+        Sampling draft tokens with support for non-greedy sampling.
 
         Args:
             logits: torch.Tensor
@@ -649,8 +614,3 @@ def prepare_1st_drafter_inputs(
             "attn_metadata": attn_metadata,
             "spec_metadata": spec_metadata,
         }
-
-    def set_guided_decoder(self,
-                           guided_decoder: CapturableGuidedDecoder) -> bool:
-        self.guided_decoder = guided_decoder
-        return True

tensorrt_llm/_torch/speculative/interface.py

Lines changed: 68 additions & 1 deletion
@@ -1,17 +1,22 @@
 import copy
 import os
+from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
 from enum import IntEnum, auto
-from typing import List, Optional, Type
+from typing import TYPE_CHECKING, List, Optional, Type
 
 import torch
+from torch import nn
 
 from tensorrt_llm.logger import logger
 
 from ..._utils import get_sm_version
 from ..attention_backend.trtllm import AttentionBackend, TrtllmAttention
 from ..pyexecutor.resource_manager import BaseResourceManager
 
+if TYPE_CHECKING:
+    from ..pyexecutor.guided_decoder import CapturableGuidedDecoder
+
 # Environment variable name for forcing the number of accepted tokens in speculative decoding
 FORCE_NUM_ACCEPTED_TOKENS_ENV_VAR = "TLLM_SPEC_DECODE_FORCE_NUM_ACCEPTED_TOKENS"
 
@@ -351,3 +356,65 @@ def populate_sampling_params_for_one_model(
                 dtype=torch.float32,
                 pin_memory=True),
             non_blocking=True)
+
+
+class SpecWorkerBase(nn.Module, ABC):
+    """
+    Base class for speculative decoding workers.
+    Provides common functionality for sampling and token handling.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.guided_decoder: Optional["CapturableGuidedDecoder"] = None
+        self.force_num_accepted_tokens = get_force_num_accepted_tokens()
+
+    @property
+    @abstractmethod
+    def max_draft_len(self) -> int:
+        """
+        Returns the maximum draft length for this worker.
+        Subclasses should override this property.
+        """
+
+    def set_guided_decoder(self,
+                           guided_decoder: "CapturableGuidedDecoder") -> bool:
+        self.guided_decoder = guided_decoder
+        return True
+
+    def _sample_tokens_for_batch(
+        self,
+        logits: torch.Tensor,
+        spec_metadata: SpecMetadata,
+        num_contexts: int,
+        batch_size: int,
+    ) -> torch.Tensor:
+        """
+        Sample tokens from logits using per-request sampling parameters.
+        Supports both greedy and non-greedy sampling.
+
+        Args:
+            logits: [num_tokens, vocab_size] - Logits to sample from
+            spec_metadata: Metadata containing sampling parameters
+            num_contexts: Number of context requests in the batch
+            batch_size: Number of requests in the batch
+
+        Returns:
+            sampled_tokens: [num_tokens] - Sampled token ids
+        """
+        if spec_metadata.allow_advanced_sampling:
+            from .one_model_sampler import sampling_batch_spec_dec_one_model
+
+            num_gens = batch_size - num_contexts
+            num_tokens = num_contexts + num_gens * (self.max_draft_len + 1)
+
+            temperatures = spec_metadata.temperatures[:num_tokens]
+            top_ks = spec_metadata.top_ks[:num_tokens]
+            top_ps = spec_metadata.top_ps[:num_tokens]
+
+            sampled_tokens = sampling_batch_spec_dec_one_model(
+                logits, temperatures, top_ks, top_ps)
+        else:
+            sampled_tokens = torch.argmax(logits, dim=-1)
+
+        return sampled_tokens
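
The non-greedy branch above calls sampling_batch_spec_dec_one_model from .one_model_sampler, which is not touched by this commit. As a rough sketch only (the function name, masking rules, and greedy fallback below are assumptions about what such a batched sampler does, not the actual one_model_sampler code), per-token temperature/top-k/top-p sampling over a [num_tokens, vocab_size] logits tensor can be written as:

import torch


def sample_with_per_token_params(logits: torch.Tensor,
                                 temperatures: torch.Tensor,
                                 top_ks: torch.Tensor,
                                 top_ps: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for sampling_batch_spec_dec_one_model."""
    vocab_size = logits.shape[-1]
    # Rows with temperature <= 0 fall back to greedy argmax at the end.
    greedy_mask = temperatures <= 0
    scaled = logits / temperatures.clamp(min=1e-5).unsqueeze(-1)

    sorted_logits, sorted_idx = scaled.sort(dim=-1, descending=True)

    # Top-k: keep only the k largest logits per row (k <= 0 means keep all).
    effective_k = torch.where(top_ks > 0, top_ks,
                              torch.full_like(top_ks, vocab_size))
    ranks = torch.arange(vocab_size, device=logits.device).expand_as(sorted_logits)
    topk_mask = ranks >= effective_k.unsqueeze(-1)

    # Top-p: drop tokens once the cumulative probability before them exceeds p.
    sorted_probs = torch.softmax(sorted_logits, dim=-1)
    cum_probs = sorted_probs.cumsum(dim=-1)
    topp_mask = (cum_probs - sorted_probs) > top_ps.unsqueeze(-1)

    # Scatter the filtered logits back to their original vocabulary positions.
    sorted_logits = sorted_logits.masked_fill(topk_mask | topp_mask, float("-inf"))
    masked = torch.full_like(scaled, float("-inf")).scatter(-1, sorted_idx,
                                                            sorted_logits)

    sampled = torch.multinomial(torch.softmax(masked, dim=-1),
                                num_samples=1).squeeze(-1)
    return torch.where(greedy_mask, logits.argmax(dim=-1), sampled)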

tensorrt_llm/_torch/speculative/mtp.py

Lines changed: 8 additions & 13 deletions
@@ -3,21 +3,19 @@
 
 import torch
 import torch.nn.functional as F
-from torch import nn
 
 from tensorrt_llm.mapping import Mapping
 
 from ..attention_backend import AttentionMetadata
 from ..distributed.ops import allgather
 from ..model_config import ModelConfig
-from ..pyexecutor.guided_decoder import CapturableGuidedDecoder
 from ..pyexecutor.llm_request import LlmRequest, LlmRequestState
 from ..pyexecutor.resource_manager import BaseResourceManager, SlotManager
 from ..pyexecutor.sampler import (DEFAULT_BEAM_IDX, SampleState,
                                   SampleStateTensors, TorchSampler, add_token,
                                   int_tensor)
 from ..pyexecutor.scheduler import ScheduledRequests
-from .interface import SpecMetadata, get_force_num_accepted_tokens
+from .interface import SpecMetadata, SpecWorkerBase
 
 if TYPE_CHECKING:
     from tensorrt_llm.llmapi.llm_args import MTPDecodingConfig
@@ -349,15 +347,17 @@ def sample_async(
             sampler_event=sampler_event)
 
 
-class MTPWorker(nn.Module):
+class MTPWorker(SpecWorkerBase):
 
     def __init__(self, spec_config: "MTPDecodingConfig", model_config=None):
         super().__init__()
         self.spec_config = spec_config
         self.model_config = model_config
         self.is_thop = False
-        self.guided_decoder: Optional[CapturableGuidedDecoder] = None
-        self.force_num_accepted_tokens = get_force_num_accepted_tokens()
+
+    @property
+    def max_draft_len(self) -> int:
+        return self.spec_config.num_nextn_predict_layers
 
     def forward(
         self,
@@ -889,8 +889,8 @@ def sample_and_accept_draft_tokens(
                 logits, spec_metadata.draft_tokens, target_tokens_cache,
                 mtp_num_modules, batch_size, num_contexts, logits.shape[-1])
         else:
-            # Do greedy sampling for the input logits
-            target_tokens = torch.argmax(logits, dim=-1)
+            target_tokens = self._sample_tokens_for_batch(
+                logits, spec_metadata, num_contexts, batch_size)
 
         # context
         accepted_tokens[:num_contexts, 0] = target_tokens[:num_contexts]
@@ -1173,11 +1173,6 @@ def draft_sampler(
 
         return draft_tokens
 
-    def set_guided_decoder(self,
-                           guided_decoder: CapturableGuidedDecoder) -> bool:
-        self.guided_decoder = guided_decoder
-        return True
-
 
 class MTPEagleWorker(MTPWorker):
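
With this refactor, MTPWorker and Eagle3OneModelWorker both inherit the guided-decoder plumbing and the sampling helper from SpecWorkerBase and only have to supply max_draft_len. A minimal toy subclass (ToyWorker and its sample method are hypothetical, shown only to illustrate the contract) looks like:

from tensorrt_llm._torch.speculative import SpecWorkerBase


class ToyWorker(SpecWorkerBase):
    """Hypothetical example; not part of this commit."""

    def __init__(self, draft_len: int):
        super().__init__()  # sets guided_decoder and force_num_accepted_tokens
        self._draft_len = draft_len

    @property
    def max_draft_len(self) -> int:
        # The one abstract member; MTPWorker maps it to num_nextn_predict_layers,
        # Eagle3OneModelWorker to spec_config.max_draft_len.
        return self._draft_len

    def sample(self, logits, spec_metadata, num_contexts, batch_size):
        # Reuses the shared helper: per-request temperature/top-k/top-p sampling
        # when spec_metadata.allow_advanced_sampling is set, greedy argmax otherwise.
        return self._sample_tokens_for_batch(logits, spec_metadata,
                                             num_contexts, batch_size)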

tests/integration/defs/accuracy/test_llm_api_pytorch.py

Lines changed: 8 additions & 3 deletions
@@ -1339,9 +1339,14 @@ def test_bfloat16(self, mtp_nextn, attention_dp, cuda_graph,
             cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
             torch_compile_config=torch_compile_config,
         )
-        mtp_config = None
+
         if mtp_nextn > 0:
-            mtp_config = MTPDecodingConfig(num_nextn_predict_layers=mtp_nextn)
+            mtp_config = MTPDecodingConfig(num_nextn_predict_layers=mtp_nextn,
+                                           allow_advanced_sampling=True)
+            sampling_params = SamplingParams(temperature=0.5)
+        else:
+            sampling_params = mtp_config = None
+
         with LLM(self.MODEL_PATH,
                  kv_cache_config=kv_cache_config,
                  enable_chunked_prefill=enable_chunked_prefill,
@@ -1350,7 +1355,7 @@ def test_bfloat16(self, mtp_nextn, attention_dp, cuda_graph,
                  enable_attention_dp=attention_dp,
                  speculative_config=mtp_config) as llm:
             task = GSM8K(self.MODEL_NAME)
-            task.evaluate(llm)
+            task.evaluate(llm, sampling_params=sampling_params)
 
     @pytest.mark.skip_less_device_memory(60000)
     def test_bfloat16_2_model_mtp(self):
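
The test change also sketches the user-facing path: enable allow_advanced_sampling on the MTP config and pass non-greedy SamplingParams at evaluation time. A usage sketch along the same lines (the model path and prompt are placeholders, and the generate() call and output access follow the standard LLM API rather than anything added in this commit):

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi.llm_args import MTPDecodingConfig

# MTP speculative decoding with per-request (non-greedy) sampling enabled.
mtp_config = MTPDecodingConfig(num_nextn_predict_layers=2,
                               allow_advanced_sampling=True)

with LLM("<model path>", speculative_config=mtp_config) as llm:
    outputs = llm.generate(["The capital of France is"],
                           SamplingParams(temperature=0.5, top_p=0.9))
    print(outputs[0].outputs[0].text)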
