Commit 69e0391

Implement sampling for MTP
Signed-off-by: Mike Iovine <[email protected]>
1 parent 9ba1426 commit 69e0391

8 files changed: +103 -80 lines changed

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 4 additions & 5 deletions
@@ -48,8 +48,7 @@
                      get_spec_metadata,
                      update_spec_config_from_model_config)
 from ..speculative.drafting_loops import BaseDraftingLoopWrapper
-from ..speculative.eagle3 import (Eagle3OneModelSpecMetadata,
-                                  Eagle3ResourceManager, Eagle3SpecMetadata)
+from ..speculative.eagle3 import Eagle3ResourceManager, Eagle3SpecMetadata
 from ..speculative.mtp import SampleStateTensorsMTP
 from ..speculative.utils import SpecDecodingTensor
 from ..utils import (get_model_extra_attrs,
@@ -2115,9 +2114,9 @@ def previous_seq_slots_device():
                 num_accepted_draft_tokens)]
             if isinstance(spec_metadata, Eagle3SpecMetadata):
                 spec_metadata.request_accepted_path = request_accepted_path
-            if isinstance(spec_metadata, Eagle3OneModelSpecMetadata):
-                spec_metadata.populate_sampling_params_for_one_model(
-                    scheduled_requests.all_requests())
+            # No-op for non 1-model
+            spec_metadata.populate_sampling_params_for_one_model(
+                scheduled_requests.all_requests())
             spec_metadata.prepare()
             inputs['spec_metadata'] = spec_metadata

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

Lines changed: 6 additions & 10 deletions
@@ -281,16 +281,12 @@ def create_py_executor(
         )
         llm_args.disable_overlap_scheduler = True

-    if spec_config is not None and spec_config.spec_dec_mode.use_one_engine():
-        if not spec_config.allow_advanced_sampling:
-            logger.warning(
-                f"Falling back to greedy decoding for {spec_config.decoding_type}. If you "
-                "want to use non-greedy sampling, please set allow_advanced_sampling=True."
-            )
-        elif spec_config.spec_dec_mode.is_mtp_one_model():
-            logger.warning(
-                "Advanced sampling is not supported for MTP yet - this will be added soon."
-            )
+    if spec_config is not None and spec_config.spec_dec_mode.use_one_engine(
+    ) and not spec_config.allow_advanced_sampling:
+        logger.warning(
+            f"Falling back to greedy decoding for {spec_config.decoding_type}. If you "
+            "want to use non-greedy sampling, please set allow_advanced_sampling=True."
+        )

     if mm_encoder_only:
         llm_args.mm_encoder_only = True
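
The warning above points users at allow_advanced_sampling=True. As a usage note, here is a minimal sketch of opting into non-greedy sampling with MTP, mirroring the test change at the bottom of this commit; the speculative_config keyword, model path, and sampling values are illustrative assumptions rather than part of the diff.

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi.llm_args import MTPDecodingConfig

# allow_advanced_sampling=True enables per-request temperature/top-k/top-p
# sampling in the one-model MTP path instead of falling back to greedy decoding.
mtp_config = MTPDecodingConfig(num_nextn_predict_layers=1,
                               allow_advanced_sampling=True)

llm = LLM("DeepSeek-V3-Lite/bf16",  # illustrative model path
          speculative_config=mtp_config)
outputs = llm.generate(["Hello"], SamplingParams(temperature=0.8, top_p=0.95))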

tensorrt_llm/_torch/speculative/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -1,6 +1,6 @@
 from .auto_heuristic import suggest_spec_config
 from .eagle3 import Eagle3SpecMetadata
-from .interface import SpecMetadata
+from .interface import SpecMetadata, SpecWorkerBase
 from .mtp import MTPEagleWorker, MTPSpecMetadata, MTPWorker
 from .ngram import NGramDrafter, NGramPoolManager
 from .save_hidden_state import SaveHiddenStatesDrafter
@@ -19,6 +19,7 @@
     "NGramPoolManager",
     "SaveHiddenStatesDrafter",
     "SpecMetadata",
+    "SpecWorkerBase",
     "get_num_extra_kv_tokens",
     "get_num_spec_layers",
     "get_spec_decoder",

tensorrt_llm/_torch/speculative/eagle3.py

Lines changed: 7 additions & 47 deletions
@@ -7,14 +7,12 @@
 from tensorrt_llm.mapping import Mapping

 from ..attention_backend import AttentionMetadata
-from ..pyexecutor.guided_decoder import CapturableGuidedDecoder
 from ..pyexecutor.llm_request import LlmRequest
 from ..pyexecutor.resource_manager import BaseResourceManager, SlotManager
 from ..pyexecutor.sampler import TorchSampler
 from ..pyexecutor.scheduler import ScheduledRequests
-from .interface import SpecMetadata, get_force_num_accepted_tokens
+from .interface import SpecMetadata, SpecWorkerBase
 from .mtp import MTPSampler
-from .one_model_sampler import sampling_batch_spec_dec_one_model
 from .spec_tree_manager import SpecTreeManager

 if TYPE_CHECKING:
@@ -358,15 +356,16 @@ def __init__(self, args: TorchSampler.Args):
         super().__init__(args, nextn=args.max_draft_len)


-class Eagle3OneModelWorker(nn.Module):
+class Eagle3OneModelWorker(SpecWorkerBase):

     def __init__(self, spec_config: "EagleDecodingConfig", mapping: Mapping):
         super().__init__()
         self.spec_config = spec_config
-        self.max_draft_len = self.spec_config.max_draft_len
         self.mapping = mapping
-        self.guided_decoder: Optional[CapturableGuidedDecoder] = None
-        self.force_num_accepted_tokens = get_force_num_accepted_tokens()
+
+    @property
+    def max_draft_len(self) -> int:
+        return self.spec_config.max_draft_len

     # Skip torch.compile for now since current Torch is not compatible with Triton 3.4
     # @torch.compile(options={"max-autotune": True})
@@ -494,40 +493,6 @@ def forward(self, input_ids, position_ids, hidden_states, logits,
             'next_new_tokens': next_new_tokens,
         }

-    def _sample_tokens_for_batch(
-        self,
-        logits: torch.Tensor,
-        spec_metadata: Eagle3OneModelSpecMetadata,
-        num_contexts: int,
-        batch_size: int,
-    ) -> torch.Tensor:
-        """
-        Sample tokens from logits using per-request sampling parameters.
-        Supports both greedy and non-greedy sampling.
-
-        Args:
-            logits: [num_tokens, vocab_size] - Logits to sample from
-            spec_metadata: Metadata containing sampling parameters
-            batch_size: Number of requests in the batch
-
-        Returns:
-            sampled_tokens: [num_tokens] - Sampled token ids
-        """
-        if spec_metadata.allow_advanced_sampling:
-            num_gens = batch_size - num_contexts
-            num_tokens = num_contexts + num_gens * (self.max_draft_len + 1)
-
-            temperatures = spec_metadata.temperatures[:num_tokens]
-            top_ks = spec_metadata.top_ks[:num_tokens]
-            top_ps = spec_metadata.top_ps[:num_tokens]
-
-            sampled_tokens = sampling_batch_spec_dec_one_model(
-                logits, temperatures, top_ks, top_ps)
-        else:
-            sampled_tokens = torch.argmax(logits, dim=-1)
-
-        return sampled_tokens
-
     def sample_and_accept_draft_tokens(
         self,
         logits: torch.Tensor,
@@ -578,7 +543,7 @@ def draft_decoder(
         draft_model: nn.Module,
     ):
         '''
-        Sampling draft tokens.
+        Sampling draft tokens with support for non-greedy sampling.

         Args:
             logits: torch.Tensor
@@ -649,8 +614,3 @@ def prepare_1st_drafter_inputs(
             "attn_metadata": attn_metadata,
             "spec_metadata": spec_metadata,
         }
-
-    def set_guided_decoder(self,
-                           guided_decoder: CapturableGuidedDecoder) -> bool:
-        self.guided_decoder = guided_decoder
-        return True

tensorrt_llm/_torch/speculative/interface.py

Lines changed: 68 additions & 1 deletion
@@ -1,17 +1,22 @@
 import copy
 import os
+from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
 from enum import IntEnum, auto
-from typing import List, Optional, Type
+from typing import TYPE_CHECKING, List, Optional, Type

 import torch
+from torch import nn

 from tensorrt_llm.logger import logger

 from ..._utils import get_sm_version
 from ..attention_backend.trtllm import AttentionBackend, TrtllmAttention
 from ..pyexecutor.resource_manager import BaseResourceManager

+if TYPE_CHECKING:
+    from ..pyexecutor.guided_decoder import CapturableGuidedDecoder
+
 # Environment variable name for forcing the number of accepted tokens in speculative decoding
 FORCE_NUM_ACCEPTED_TOKENS_ENV_VAR = "TLLM_SPEC_DECODE_FORCE_NUM_ACCEPTED_TOKENS"

@@ -351,3 +356,65 @@ def populate_sampling_params_for_one_model(
                 dtype=torch.float32,
                 pin_memory=True),
             non_blocking=True)
+
+
+class SpecWorkerBase(nn.Module, ABC):
+    """
+    Base class for speculative decoding workers.
+    Provides common functionality for sampling and token handling.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.guided_decoder: Optional["CapturableGuidedDecoder"] = None
+        self.force_num_accepted_tokens = get_force_num_accepted_tokens()
+
+    @property
+    @abstractmethod
+    def max_draft_len(self) -> int:
+        """
+        Returns the maximum draft length for this worker.
+        Subclasses should override this property.
+        """
+
+    def set_guided_decoder(self,
+                           guided_decoder: "CapturableGuidedDecoder") -> bool:
+        self.guided_decoder = guided_decoder
+        return True
+
+    def _sample_tokens_for_batch(
+        self,
+        logits: torch.Tensor,
+        spec_metadata: SpecMetadata,
+        num_contexts: int,
+        batch_size: int,
+    ) -> torch.Tensor:
+        """
+        Sample tokens from logits using per-request sampling parameters.
+        Supports both greedy and non-greedy sampling.
+
+        Args:
+            logits: [num_tokens, vocab_size] - Logits to sample from
+            spec_metadata: Metadata containing sampling parameters
+            num_contexts: Number of context requests in the batch
+            batch_size: Number of requests in the batch
+
+        Returns:
+            sampled_tokens: [num_tokens] - Sampled token ids
+        """
+        if spec_metadata.allow_advanced_sampling:
+            from .one_model_sampler import sampling_batch_spec_dec_one_model

+            num_gens = batch_size - num_contexts
+            num_tokens = num_contexts + num_gens * (self.max_draft_len + 1)
+
+            temperatures = spec_metadata.temperatures[:num_tokens]
+            top_ks = spec_metadata.top_ks[:num_tokens]
+            top_ps = spec_metadata.top_ps[:num_tokens]
+
+            sampled_tokens = sampling_batch_spec_dec_one_model(
+                logits, temperatures, top_ks, top_ps)
+        else:
+            sampled_tokens = torch.argmax(logits, dim=-1)
+
+        return sampled_tokens
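
To make the new abstraction concrete, here is a minimal hypothetical subclass (not part of this commit) showing the contract SpecWorkerBase imposes: provide the abstract max_draft_len property, and reuse _sample_tokens_for_batch, which picks greedy argmax or per-token temperature/top-k/top-p sampling depending on spec_metadata.allow_advanced_sampling.

import torch

# Illustrative only; the real workers in this commit are MTPWorker and
# Eagle3OneModelWorker. SpecWorkerBase and SpecMetadata come from this module.
class ToySpecWorker(SpecWorkerBase):

    def __init__(self, draft_len: int):
        super().__init__()
        self._draft_len = draft_len

    @property
    def max_draft_len(self) -> int:
        # Required override of the abstract property declared on SpecWorkerBase.
        return self._draft_len

    def sample(self, logits: torch.Tensor, spec_metadata: SpecMetadata,
               num_contexts: int, batch_size: int) -> torch.Tensor:
        # Greedy argmax when allow_advanced_sampling is False; otherwise
        # per-token sampling via one_model_sampler inside the base helper.
        return self._sample_tokens_for_batch(logits, spec_metadata,
                                             num_contexts, batch_size)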

tensorrt_llm/_torch/speculative/mtp.py

Lines changed: 8 additions & 13 deletions
@@ -3,21 +3,19 @@

 import torch
 import torch.nn.functional as F
-from torch import nn

 from tensorrt_llm.mapping import Mapping

 from ..attention_backend import AttentionMetadata
 from ..distributed.ops import allgather
 from ..model_config import ModelConfig
-from ..pyexecutor.guided_decoder import CapturableGuidedDecoder
 from ..pyexecutor.llm_request import LlmRequest, LlmRequestState
 from ..pyexecutor.resource_manager import BaseResourceManager, SlotManager
 from ..pyexecutor.sampler import (DEFAULT_BEAM_IDX, SampleState,
                                   SampleStateTensors, TorchSampler, add_token,
                                   int_tensor)
 from ..pyexecutor.scheduler import ScheduledRequests
-from .interface import SpecMetadata, get_force_num_accepted_tokens
+from .interface import SpecMetadata, SpecWorkerBase

 if TYPE_CHECKING:
     from tensorrt_llm.llmapi.llm_args import MTPDecodingConfig
@@ -349,15 +347,17 @@ def sample_async(
             sampler_event=sampler_event)


-class MTPWorker(nn.Module):
+class MTPWorker(SpecWorkerBase):

     def __init__(self, spec_config: "MTPDecodingConfig", model_config=None):
         super().__init__()
         self.spec_config = spec_config
         self.model_config = model_config
         self.is_thop = False
-        self.guided_decoder: Optional[CapturableGuidedDecoder] = None
-        self.force_num_accepted_tokens = get_force_num_accepted_tokens()
+
+    @property
+    def max_draft_len(self) -> int:
+        return self.spec_config.num_nextn_predict_layers

     def forward(
         self,
@@ -889,8 +889,8 @@ def sample_and_accept_draft_tokens(
                 logits, spec_metadata.draft_tokens, target_tokens_cache,
                 mtp_num_modules, batch_size, num_contexts, logits.shape[-1])
         else:
-            # Do greedy sampling for the input logits
-            target_tokens = torch.argmax(logits, dim=-1)
+            target_tokens = self._sample_tokens_for_batch(
+                logits, spec_metadata, num_contexts, batch_size)

         # context
         accepted_tokens[:num_contexts, 0] = target_tokens[:num_contexts]
@@ -1173,11 +1173,6 @@ def draft_sampler(

         return draft_tokens

-    def set_guided_decoder(self,
-                           guided_decoder: CapturableGuidedDecoder) -> bool:
-        self.guided_decoder = guided_decoder
-        return True
-

 class MTPEagleWorker(MTPWorker):
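
One detail worth spelling out from sample_and_accept_draft_tokens above: when advanced sampling is enabled, the base helper slices the per-token sampling parameters to num_contexts + num_gens * (max_draft_len + 1) entries, i.e. one token per context request and max_draft_len + 1 tokens per generation request. A small worked example with illustrative numbers:

# Illustrative values only: 2 context requests, 3 generation requests, and an
# MTP config with num_nextn_predict_layers == 3 (so max_draft_len == 3).
num_contexts = 2
batch_size = 5
max_draft_len = 3

num_gens = batch_size - num_contexts                         # 3
num_tokens = num_contexts + num_gens * (max_draft_len + 1)   # 2 + 3 * 4 = 14

# temperatures / top_ks / top_ps are sliced to the first num_tokens entries so
# each logits row is paired with its owning request's sampling parameters.
assert num_tokens == 14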

tensorrt_llm/_torch/speculative/utils.py

Lines changed: 2 additions & 0 deletions
@@ -31,6 +31,7 @@ def get_spec_metadata(spec_config,
             mtp_num_modules=spec_config.num_nextn_predict_layers,
             max_num_requests=max_num_requests,
             mtp_hidden_states_manager=spec_resource_manager,
+            allow_advanced_sampling=spec_config.allow_advanced_sampling,
         )
     if spec_config.spec_dec_mode.is_mtp_eagle():
         return Eagle3SpecMetadata(
@@ -46,6 +47,7 @@
             eagle3_resource_manager=spec_resource_manager,
             layers_to_capture=None,
             is_mtp_eagle=True,
+            allow_advanced_sampling=spec_config.allow_advanced_sampling,
         )
     if spec_config.spec_dec_mode.is_eagle3():
         return Eagle3SpecMetadata(

tests/integration/defs/accuracy/test_llm_api_pytorch.py

Lines changed: 6 additions & 3 deletions
@@ -1317,7 +1317,6 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
     MODEL_PATH = f"{llm_models_root()}/DeepSeek-V3-Lite/bf16"

     @pytest.mark.skip_less_device_memory(60000)
-    # Chunked Prefill for MLA can only be enabled on SM100
     @parametrize_with_ids("enable_chunked_prefill", [False, True])
     @parametrize_with_ids("torch_compile", [False, True])
     @parametrize_with_ids("attention_dp,cuda_graph,overlap_scheduler",
@@ -1339,9 +1338,13 @@ def test_bfloat16(self, mtp_nextn, attention_dp, cuda_graph,
             cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
             torch_compile_config=torch_compile_config,
         )
-        mtp_config = None
+
         if mtp_nextn > 0:
-            mtp_config = MTPDecodingConfig(num_nextn_predict_layers=mtp_nextn)
+            mtp_config = MTPDecodingConfig(num_nextn_predict_layers=mtp_nextn,
+                                           allow_advanced_sampling=True)
+        else:
+            mtp_config = None
+
         with LLM(self.MODEL_PATH,
                  kv_cache_config=kv_cache_config,
                  enable_chunked_prefill=enable_chunked_prefill,
