Commit f62aa30

Implement sampling for MTP
Signed-off-by: Mike Iovine <[email protected]>
1 parent 9ba1426 commit f62aa30

6 files changed, +102 -76 lines changed

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 5 additions & 5 deletions
@@ -48,8 +48,8 @@
     get_spec_metadata,
     update_spec_config_from_model_config)
 from ..speculative.drafting_loops import BaseDraftingLoopWrapper
-from ..speculative.eagle3 import (Eagle3OneModelSpecMetadata,
-                                  Eagle3ResourceManager, Eagle3SpecMetadata)
+from ..speculative.eagle3 import (Eagle3ResourceManager,
+                                  Eagle3SpecMetadata)
 from ..speculative.mtp import SampleStateTensorsMTP
 from ..speculative.utils import SpecDecodingTensor
 from ..utils import (get_model_extra_attrs,
@@ -2115,9 +2115,9 @@ def previous_seq_slots_device():
                     num_accepted_draft_tokens)]
             if isinstance(spec_metadata, Eagle3SpecMetadata):
                 spec_metadata.request_accepted_path = request_accepted_path
-            if isinstance(spec_metadata, Eagle3OneModelSpecMetadata):
-                spec_metadata.populate_sampling_params_for_one_model(
-                    scheduled_requests.all_requests())
+            # No-op for non 1-model
+            spec_metadata.populate_sampling_params_for_one_model(
+                scheduled_requests.all_requests())
             spec_metadata.prepare()
             inputs['spec_metadata'] = spec_metadata

tensorrt_llm/_torch/speculative/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -1,6 +1,6 @@
 from .auto_heuristic import suggest_spec_config
 from .eagle3 import Eagle3SpecMetadata
-from .interface import SpecMetadata
+from .interface import SpecMetadata, SpecWorkerBase
 from .mtp import MTPEagleWorker, MTPSpecMetadata, MTPWorker
 from .ngram import NGramDrafter, NGramPoolManager
 from .save_hidden_state import SaveHiddenStatesDrafter
@@ -19,6 +19,7 @@
     "NGramPoolManager",
     "SaveHiddenStatesDrafter",
     "SpecMetadata",
+    "SpecWorkerBase",
     "get_num_extra_kv_tokens",
     "get_num_spec_layers",
     "get_spec_decoder",

tensorrt_llm/_torch/speculative/eagle3.py

Lines changed: 7 additions & 47 deletions
@@ -7,14 +7,12 @@
 from tensorrt_llm.mapping import Mapping
 
 from ..attention_backend import AttentionMetadata
-from ..pyexecutor.guided_decoder import CapturableGuidedDecoder
 from ..pyexecutor.llm_request import LlmRequest
 from ..pyexecutor.resource_manager import BaseResourceManager, SlotManager
 from ..pyexecutor.sampler import TorchSampler
 from ..pyexecutor.scheduler import ScheduledRequests
-from .interface import SpecMetadata, get_force_num_accepted_tokens
+from .interface import SpecMetadata, SpecWorkerBase
 from .mtp import MTPSampler
-from .one_model_sampler import sampling_batch_spec_dec_one_model
 from .spec_tree_manager import SpecTreeManager
 
 if TYPE_CHECKING:
@@ -358,15 +356,16 @@ def __init__(self, args: TorchSampler.Args):
         super().__init__(args, nextn=args.max_draft_len)
 
 
-class Eagle3OneModelWorker(nn.Module):
+class Eagle3OneModelWorker(SpecWorkerBase):
 
     def __init__(self, spec_config: "EagleDecodingConfig", mapping: Mapping):
         super().__init__()
         self.spec_config = spec_config
-        self.max_draft_len = self.spec_config.max_draft_len
         self.mapping = mapping
-        self.guided_decoder: Optional[CapturableGuidedDecoder] = None
-        self.force_num_accepted_tokens = get_force_num_accepted_tokens()
+
+    @property
+    def max_draft_len(self) -> int:
+        return self.spec_config.max_draft_len
 
     # Skip torch.compile for now since current Torch is not compatible with Triton 3.4
     # @torch.compile(options={"max-autotune": True})
@@ -494,40 +493,6 @@ def forward(self, input_ids, position_ids, hidden_states, logits,
             'next_new_tokens': next_new_tokens,
         }
 
-    def _sample_tokens_for_batch(
-        self,
-        logits: torch.Tensor,
-        spec_metadata: Eagle3OneModelSpecMetadata,
-        num_contexts: int,
-        batch_size: int,
-    ) -> torch.Tensor:
-        """
-        Sample tokens from logits using per-request sampling parameters.
-        Supports both greedy and non-greedy sampling.
-
-        Args:
-            logits: [num_tokens, vocab_size] - Logits to sample from
-            spec_metadata: Metadata containing sampling parameters
-            batch_size: Number of requests in the batch
-
-        Returns:
-            sampled_tokens: [num_tokens] - Sampled token ids
-        """
-        if spec_metadata.allow_advanced_sampling:
-            num_gens = batch_size - num_contexts
-            num_tokens = num_contexts + num_gens * (self.max_draft_len + 1)
-
-            temperatures = spec_metadata.temperatures[:num_tokens]
-            top_ks = spec_metadata.top_ks[:num_tokens]
-            top_ps = spec_metadata.top_ps[:num_tokens]
-
-            sampled_tokens = sampling_batch_spec_dec_one_model(
-                logits, temperatures, top_ks, top_ps)
-        else:
-            sampled_tokens = torch.argmax(logits, dim=-1)
-
-        return sampled_tokens
-
     def sample_and_accept_draft_tokens(
         self,
         logits: torch.Tensor,
@@ -578,7 +543,7 @@ def draft_decoder(
         draft_model: nn.Module,
     ):
         '''
-        Sampling draft tokens.
+        Sampling draft tokens with support for non-greedy sampling.
 
         Args:
             logits: torch.Tensor
@@ -649,8 +614,3 @@ def prepare_1st_drafter_inputs(
             "attn_metadata": attn_metadata,
             "spec_metadata": spec_metadata,
         }
-
-    def set_guided_decoder(self,
-                           guided_decoder: CapturableGuidedDecoder) -> bool:
-        self.guided_decoder = guided_decoder
-        return True
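
max_draft_len changes from a value cached on self in __init__ to a read-only property, so it always tracks spec_config and also satisfies the abstract max_draft_len declared on SpecWorkerBase. A small illustration of the difference, with throwaway class names:

# Throwaway illustration: cached attribute vs. derived property.
class CachedDraftLen:
    def __init__(self, cfg):
        self.max_draft_len = cfg.max_draft_len  # frozen at construction time


class DerivedDraftLen:
    def __init__(self, cfg):
        self.cfg = cfg

    @property
    def max_draft_len(self) -> int:
        return self.cfg.max_draft_len  # always reflects the current config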

tensorrt_llm/_torch/speculative/interface.py

Lines changed: 68 additions & 1 deletion
@@ -1,17 +1,22 @@
 import copy
 import os
+from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
 from enum import IntEnum, auto
-from typing import List, Optional, Type
+from typing import TYPE_CHECKING, List, Optional, Type
 
 import torch
+from torch import nn
 
 from tensorrt_llm.logger import logger
 
 from ..._utils import get_sm_version
 from ..attention_backend.trtllm import AttentionBackend, TrtllmAttention
 from ..pyexecutor.resource_manager import BaseResourceManager
 
+if TYPE_CHECKING:
+    from ..pyexecutor.guided_decoder import CapturableGuidedDecoder
+
 # Environment variable name for forcing the number of accepted tokens in speculative decoding
 FORCE_NUM_ACCEPTED_TOKENS_ENV_VAR = "TLLM_SPEC_DECODE_FORCE_NUM_ACCEPTED_TOKENS"
 
@@ -351,3 +356,65 @@ def populate_sampling_params_for_one_model(
                              dtype=torch.float32,
                              pin_memory=True),
                 non_blocking=True)
+
+
+class SpecWorkerBase(nn.Module, ABC):
+    """
+    Base class for speculative decoding workers.
+    Provides common functionality for sampling and token handling.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.guided_decoder: Optional["CapturableGuidedDecoder"] = None
+        self.force_num_accepted_tokens = get_force_num_accepted_tokens()
+
+    @property
+    @abstractmethod
+    def max_draft_len(self) -> int:
+        """
+        Returns the maximum draft length for this worker.
+        Subclasses should override this property.
+        """
+
+    def set_guided_decoder(self,
+                           guided_decoder: "CapturableGuidedDecoder") -> bool:
+        self.guided_decoder = guided_decoder
+        return True
+
+    def _sample_tokens_for_batch(
+        self,
+        logits: torch.Tensor,
+        spec_metadata: SpecMetadata,
+        num_contexts: int,
+        batch_size: int,
+    ) -> torch.Tensor:
+        """
+        Sample tokens from logits using per-request sampling parameters.
+        Supports both greedy and non-greedy sampling.
+
+        Args:
+            logits: [num_tokens, vocab_size] - Logits to sample from
+            spec_metadata: Metadata containing sampling parameters
+            num_contexts: Number of context requests in the batch
+            batch_size: Number of requests in the batch
+
+        Returns:
+            sampled_tokens: [num_tokens] - Sampled token ids
+        """
+        if spec_metadata.allow_advanced_sampling:
+            from .one_model_sampler import sampling_batch_spec_dec_one_model
+
+            num_gens = batch_size - num_contexts
+            num_tokens = num_contexts + num_gens * (self.max_draft_len + 1)
+
+            temperatures = spec_metadata.temperatures[:num_tokens]
+            top_ks = spec_metadata.top_ks[:num_tokens]
+            top_ps = spec_metadata.top_ps[:num_tokens]
+
+            sampled_tokens = sampling_batch_spec_dec_one_model(
+                logits, temperatures, top_ks, top_ps)
+        else:
+            sampled_tokens = torch.argmax(logits, dim=-1)
+
+        return sampled_tokens
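
A worker picks up guided-decoder handling, the forced-acceptance override, and this sampling helper just by subclassing SpecWorkerBase and implementing max_draft_len. A minimal, hypothetical subclass (not part of this commit) might look like:

# Hypothetical subclass for illustration only.
class ToySpecWorker(SpecWorkerBase):

    def __init__(self, draft_len: int):
        super().__init__()  # sets guided_decoder=None, force_num_accepted_tokens
        self._draft_len = draft_len

    @property
    def max_draft_len(self) -> int:
        return self._draft_len

    def pick_target_tokens(self, logits, spec_metadata, num_contexts, batch_size):
        # Greedy argmax unless spec_metadata.allow_advanced_sampling is set.
        return self._sample_tokens_for_batch(
            logits, spec_metadata, num_contexts, batch_size)

Note that _sample_tokens_for_batch expects one logit row per context request plus max_draft_len + 1 rows per generation request, which is where num_tokens = num_contexts + num_gens * (max_draft_len + 1) comes from.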

tensorrt_llm/_torch/speculative/mtp.py

Lines changed: 8 additions & 13 deletions
@@ -3,21 +3,19 @@
 
 import torch
 import torch.nn.functional as F
-from torch import nn
 
 from tensorrt_llm.mapping import Mapping
 
 from ..attention_backend import AttentionMetadata
 from ..distributed.ops import allgather
 from ..model_config import ModelConfig
-from ..pyexecutor.guided_decoder import CapturableGuidedDecoder
 from ..pyexecutor.llm_request import LlmRequest, LlmRequestState
 from ..pyexecutor.resource_manager import BaseResourceManager, SlotManager
 from ..pyexecutor.sampler import (DEFAULT_BEAM_IDX, SampleState,
                                   SampleStateTensors, TorchSampler, add_token,
                                   int_tensor)
 from ..pyexecutor.scheduler import ScheduledRequests
-from .interface import SpecMetadata, get_force_num_accepted_tokens
+from .interface import SpecMetadata, SpecWorkerBase
 
 if TYPE_CHECKING:
     from tensorrt_llm.llmapi.llm_args import MTPDecodingConfig
@@ -349,15 +347,17 @@ def sample_async(
             sampler_event=sampler_event)
 
 
-class MTPWorker(nn.Module):
+class MTPWorker(SpecWorkerBase):
 
     def __init__(self, spec_config: "MTPDecodingConfig", model_config=None):
         super().__init__()
         self.spec_config = spec_config
         self.model_config = model_config
         self.is_thop = False
-        self.guided_decoder: Optional[CapturableGuidedDecoder] = None
-        self.force_num_accepted_tokens = get_force_num_accepted_tokens()
+
+    @property
+    def max_draft_len(self) -> int:
+        return self.spec_config.num_nextn_predict_layers
 
     def forward(
         self,
@@ -889,8 +889,8 @@ def sample_and_accept_draft_tokens(
                 logits, spec_metadata.draft_tokens, target_tokens_cache,
                 mtp_num_modules, batch_size, num_contexts, logits.shape[-1])
         else:
-            # Do greedy sampling for the input logits
-            target_tokens = torch.argmax(logits, dim=-1)
+            target_tokens = self._sample_tokens_for_batch(
+                logits, spec_metadata, num_contexts, batch_size)
 
         # context
         accepted_tokens[:num_contexts, 0] = target_tokens[:num_contexts]
@@ -1173,11 +1173,6 @@ def draft_sampler(
 
         return draft_tokens
 
-    def set_guided_decoder(self,
-                           guided_decoder: CapturableGuidedDecoder) -> bool:
-        self.guided_decoder = guided_decoder
-        return True
-
 
 class MTPEagleWorker(MTPWorker):
 
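
For MTPWorker the draft length comes from num_nextn_predict_layers, so the logits sliced by the inherited _sample_tokens_for_batch contain one row per context request plus num_nextn_predict_layers + 1 rows per generation request. A quick worked example with made-up sizes:

# Made-up sizes, only to show how many logit rows the inherited helper uses.
num_contexts = 1    # context (prefill) requests: one sampled token each
num_gens = 3        # generation requests
max_draft_len = 2   # MTP: spec_config.num_nextn_predict_layers

num_tokens = num_contexts + num_gens * (max_draft_len + 1)
assert num_tokens == 10  # 1 + 3 * 3 rows of logits are sampled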

tests/integration/defs/accuracy/test_llm_api_pytorch.py

Lines changed: 12 additions & 9 deletions
@@ -1318,14 +1318,12 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
 
     @pytest.mark.skip_less_device_memory(60000)
     # Chunked Prefill for MLA can only be enabled on SM100
-    @parametrize_with_ids("enable_chunked_prefill", [False, True])
-    @parametrize_with_ids("torch_compile", [False, True])
+    @parametrize_with_ids("enable_chunked_prefill", [False])
+    @parametrize_with_ids("torch_compile", [False])
     @parametrize_with_ids("attention_dp,cuda_graph,overlap_scheduler",
-                          [(False, False, False), (True, False, False),
-                           (False, True, False), (False, False, True),
-                           (False, True, True), (True, True, True)])
+                          [(False, False, False)])
     # Only Hopper and Blackwell MLA kernel supports MTP
-    @parametrize_with_ids("mtp_nextn", [0, 2])
+    @parametrize_with_ids("mtp_nextn", [2])
     def test_bfloat16(self, mtp_nextn, attention_dp, cuda_graph,
                       overlap_scheduler, torch_compile, enable_chunked_prefill):
         kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.75)
@@ -1339,9 +1337,14 @@ def test_bfloat16(self, mtp_nextn, attention_dp, cuda_graph,
             cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
             torch_compile_config=torch_compile_config,
         )
-        mtp_config = None
+
         if mtp_nextn > 0:
-            mtp_config = MTPDecodingConfig(num_nextn_predict_layers=mtp_nextn)
+            mtp_config = MTPDecodingConfig(num_nextn_predict_layers=mtp_nextn,
+                                           allow_advanced_sampling=True)
+            sampling_params = SamplingParams(temperature=0.5)
+        else:
+            sampling_params = mtp_config = None
+
         with LLM(self.MODEL_PATH,
                  kv_cache_config=kv_cache_config,
                  enable_chunked_prefill=enable_chunked_prefill,
@@ -1350,7 +1353,7 @@ def test_bfloat16(self, mtp_nextn, attention_dp, cuda_graph,
                  enable_attention_dp=attention_dp,
                  speculative_config=mtp_config) as llm:
             task = GSM8K(self.MODEL_NAME)
-            task.evaluate(llm)
+            task.evaluate(llm, sampling_params=sampling_params)
 
     @pytest.mark.skip_less_device_memory(60000)
     def test_bfloat16_2_model_mtp():
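
The test now drives the new path end to end: MTPDecodingConfig gets allow_advanced_sampling=True and GSM8K is evaluated with a non-greedy SamplingParams(temperature=0.5). Outside the harness, roughly the same setup might look like the sketch below (model path and prompt are placeholders; the MTPDecodingConfig import path is taken from the TYPE_CHECKING import in mtp.py):

# Rough usage sketch, not taken from the test file verbatim.
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi.llm_args import MTPDecodingConfig

mtp_config = MTPDecodingConfig(
    num_nextn_predict_layers=2, allow_advanced_sampling=True)

with LLM("path/to/DeepSeek-V3-Lite", speculative_config=mtp_config) as llm:
    outputs = llm.generate(
        ["placeholder prompt"], SamplingParams(temperature=0.5))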
