Commit 383b13e

[None][feat] Implement sampling on 1-model EAGLE3 (NVIDIA#9885)

Authored by Mike Iovine
Signed-off-by: Mike Iovine <6158008+mikeiovine@users.noreply.github.com>
Signed-off-by: Mike Iovine <miovine@nvidia.com>

1 parent 079ef8a commit 383b13e

File tree: 9 files changed (+248 -5 lines)

examples/llm-api/quickstart_advanced.py
tensorrt_llm/_torch/pyexecutor/model_engine.py
tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
tensorrt_llm/_torch/speculative/eagle3.py
tensorrt_llm/_torch/speculative/interface.py
tensorrt_llm/_torch/speculative/one_model_sampler.py
tensorrt_llm/_torch/speculative/utils.py
tensorrt_llm/llmapi/llm_args.py
tests/integration/defs/accuracy/test_llm_api_pytorch.py

examples/llm-api/quickstart_advanced.py
6 additions, 1 deletion

@@ -143,6 +143,9 @@ def add_llm_args(parser):
                         default=False,
                         action='store_true')
     parser.add_argument('--dynamic_tree_max_topK', type=int, default=None)
+    parser.add_argument('--allow_advanced_sampling',
+                        default=False,
+                        action='store_true')
 
     # Relaxed acceptance
     parser.add_argument('--use_relaxed_acceptance_for_thinking',
@@ -210,7 +213,9 @@ def setup_llm(args, **kwargs):
             eagle3_one_model=args.use_one_model,
             eagle_choices=args.eagle_choices,
             use_dynamic_tree=args.use_dynamic_tree,
-            dynamic_tree_max_topK=args.dynamic_tree_max_topK)
+            dynamic_tree_max_topK=args.dynamic_tree_max_topK,
+            allow_advanced_sampling=args.allow_advanced_sampling)
+
     elif spec_decode_algo == "DRAFT_TARGET":
         spec_config = DraftTargetDecodingConfig(
             max_draft_len=args.spec_decode_max_draft_len,
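
For orientation, a minimal sketch (not part of the commit) of how the new switch is expected to flow from the script's argument parser into the speculative-decoding config, mirroring the wiring above. Apart from --allow_advanced_sampling itself, the draft length, the placeholder model directory, and the EagleDecodingConfig import path are assumptions for illustration only.

# Hypothetical sketch: CLI flag -> EagleDecodingConfig wiring.
import argparse

from tensorrt_llm.llmapi import EagleDecodingConfig  # import path assumed

parser = argparse.ArgumentParser()
parser.add_argument('--use_one_model', default=False, action='store_true')
parser.add_argument('--allow_advanced_sampling', default=False, action='store_true')
args = parser.parse_args(['--use_one_model', '--allow_advanced_sampling'])

spec_config = EagleDecodingConfig(
    max_draft_len=3,                                    # placeholder value
    speculative_model_dir='<eagle3-draft-model-dir>',   # placeholder path
    eagle3_one_model=args.use_one_model,
    allow_advanced_sampling=args.allow_advanced_sampling)
print(spec_config.allow_advanced_sampling)              # True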

tensorrt_llm/_torch/pyexecutor/model_engine.py
5 additions, 1 deletion

@@ -48,7 +48,8 @@
                                   get_spec_metadata,
                                   update_spec_config_from_model_config)
 from ..speculative.drafting_loops import BaseDraftingLoopWrapper
-from ..speculative.eagle3 import Eagle3ResourceManager, Eagle3SpecMetadata
+from ..speculative.eagle3 import (Eagle3OneModelSpecMetadata,
+                                  Eagle3ResourceManager, Eagle3SpecMetadata)
 from ..speculative.mtp import SampleStateTensorsMTP
 from ..speculative.utils import SpecDecodingTensor
 from ..utils import (get_model_extra_attrs,
@@ -2093,6 +2094,9 @@ def previous_seq_slots_device():
                                   num_accepted_draft_tokens)]
             if isinstance(spec_metadata, Eagle3SpecMetadata):
                 spec_metadata.request_accepted_path = request_accepted_path
+            if isinstance(spec_metadata, Eagle3OneModelSpecMetadata):
+                spec_metadata.populate_sampling_params_for_one_model(
+                    scheduled_requests.all_requests())
             spec_metadata.prepare()
             inputs['spec_metadata'] = spec_metadata
 

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
11 additions, 0 deletions

@@ -281,6 +281,17 @@ def create_py_executor(
             )
             llm_args.disable_overlap_scheduler = True
 
+    if spec_config is not None and spec_config.spec_dec_mode.use_one_engine():
+        if not spec_config.allow_advanced_sampling:
+            logger.warning(
+                f"Falling back to greedy decoding for {spec_config.decoding_type}. If you "
+                "want to use non-greedy sampling, please set allow_advanced_sampling=True."
+            )
+        elif spec_config.spec_dec_mode.is_mtp_one_model():
+            logger.warning(
+                "Advanced sampling is not supported for MTP yet - this will be added soon."
+            )
+
     if mm_encoder_only:
         llm_args.mm_encoder_only = True
         llm_args.disable_overlap_scheduler = True

tensorrt_llm/_torch/speculative/eagle3.py
41 additions, 2 deletions

@@ -14,6 +14,7 @@
 from ..pyexecutor.scheduler import ScheduledRequests
 from .interface import SpecMetadata, get_force_num_accepted_tokens
 from .mtp import MTPSampler
+from .one_model_sampler import sampling_batch_spec_dec_one_model
 from .spec_tree_manager import SpecTreeManager
 
 if TYPE_CHECKING:
@@ -493,6 +494,40 @@ def forward(self, input_ids, position_ids, hidden_states, logits,
             'next_new_tokens': next_new_tokens,
         }
 
+    def _sample_tokens_for_batch(
+        self,
+        logits: torch.Tensor,
+        spec_metadata: Eagle3OneModelSpecMetadata,
+        num_contexts: int,
+        batch_size: int,
+    ) -> torch.Tensor:
+        """
+        Sample tokens from logits using per-request sampling parameters.
+        Supports both greedy and non-greedy sampling.
+
+        Args:
+            logits: [num_tokens, vocab_size] - Logits to sample from
+            spec_metadata: Metadata containing sampling parameters
+            batch_size: Number of requests in the batch
+
+        Returns:
+            sampled_tokens: [num_tokens] - Sampled token ids
+        """
+        if spec_metadata.allow_advanced_sampling:
+            num_gens = batch_size - num_contexts
+            num_tokens = num_contexts + num_gens * (self.max_draft_len + 1)
+
+            temperatures = spec_metadata.temperatures[:num_tokens]
+            top_ks = spec_metadata.top_ks[:num_tokens]
+            top_ps = spec_metadata.top_ps[:num_tokens]
+
+            sampled_tokens = sampling_batch_spec_dec_one_model(
+                logits, temperatures, top_ks, top_ps)
+        else:
+            sampled_tokens = torch.argmax(logits, dim=-1)
+
+        return sampled_tokens
+
     def sample_and_accept_draft_tokens(
         self,
         logits: torch.Tensor,
@@ -514,8 +549,9 @@ def sample_and_accept_draft_tokens(
                                       dtype=torch.int,
                                       device=logits.device)
 
-        # Do greedy sampling for the input logits
-        target_tokens = torch.argmax(logits, dim=-1)
+        # Sample tokens using per-request sampling parameters
+        target_tokens = self._sample_tokens_for_batch(logits, spec_metadata,
+                                                      num_contexts, batch_size)
         # context
         accepted_tokens[:num_contexts, 0] = target_tokens[:num_contexts]
 
@@ -557,6 +593,9 @@ def draft_decoder(
             Draft token ids. Flattened.
         '''
 
+        # Note: using greedy for draft tokens is a bit easier to implement and
+        # faster. It doesn't affect the final output and seems to have a negligible
+        # impact on AR.
         draft_tokens = torch.argmax(logits, dim=-1)
 
         # Apply d2t (offsets between draft model dictionary and main model dictionary).
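
To make the slicing in _sample_tokens_for_batch concrete, here is a small self-contained sketch (all numbers invented) of the token layout it assumes: each context request contributes one logits row, each generation request contributes max_draft_len + 1 rows, and the per-token sampling tensors must be sliced to the same length.

import torch

# Hypothetical batch: 2 context requests, 3 generation requests, max_draft_len = 4.
num_contexts, batch_size, max_draft_len, vocab_size = 2, 5, 4, 11
num_gens = batch_size - num_contexts
num_tokens = num_contexts + num_gens * (max_draft_len + 1)  # 2 + 3 * 5 = 17

logits = torch.randn(num_tokens, vocab_size)

# Per-token sampling params, laid out one entry per context token and
# max_draft_len + 1 entries per generation request (see interface.py below).
temperatures = torch.ones(num_tokens)
top_ks = torch.full((num_tokens,), torch.iinfo(torch.int32).max, dtype=torch.int32)
top_ps = torch.ones(num_tokens)

assert logits.shape[0] == temperatures.shape[0] == num_tokens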

tensorrt_llm/_torch/speculative/interface.py
87 additions, 0 deletions

@@ -229,6 +229,13 @@ class SpecMetadata:
     # whether the spec-dec mode is a dynamic tree.
     is_spec_dec_dynamic_tree: bool = False
 
+    # For non-greedy sampling on 1-model.
+    allow_advanced_sampling: bool = False
+    # Sampling parameters for non-greedy sampling (per-request)
+    temperatures: Optional[torch.Tensor] = None
+    top_ks: Optional[torch.Tensor] = None
+    top_ps: Optional[torch.Tensor] = None
+
     def __post_init__(self):
         pass
 
@@ -264,3 +271,83 @@ def maybe_capture_hidden_states(self, layer_id: int,
         Some spec decode algorithms require hidden states from the target
         model. Use this method to record them. By default, does nothing.
         """
+
+    def populate_sampling_params_for_one_model(
+            self, requests: list["LlmRequest"]) -> None:
+        """
+        Set up topp/topk/temperatures for 1-model sampler.
+        """
+        from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequestState
+        from tensorrt_llm.sampling_params import SamplingParams
+
+        if not self.allow_advanced_sampling or not self.spec_dec_mode.use_one_engine(
+        ):
+            return
+
+        if self.temperatures is None:
+            # Ensures determinism across ranks.
+            torch.manual_seed(0)
+
+        temperatures = []
+        top_ks = []
+        top_ps = []
+
+        # Need to use a very small value for temperature when disabled to avoid division by 0
+        DISABLE_TEMP_VAL = 1e-5
+        # Very large values disable topk.
+        DISABLE_TOPK_VAL = torch.iinfo(torch.int32).max
+        DISABLE_TOPP_VAL = 1.0
+
+        for request in requests:
+            sampling_config = request.sampling_config
+            temp = sampling_config.temperature
+            temp_val = temp[0] if temp is not None and len(temp) > 0 else None
+
+            tk = sampling_config.top_k
+            tk_val = tk[0] if tk is not None and len(tk) > 0 else None
+
+            tp = sampling_config.top_p
+            tp_val = tp[0] if tp is not None and len(tp) > 0 else None
+
+            # Context requests have no draft tokens yet.
+            num_tokens = 1 + self.max_draft_len if request.state == LlmRequestState.GENERATION_IN_PROGRESS else 1
+
+            is_greedy = SamplingParams.params_imply_greedy_decoding(
+                temperature=temp_val,
+                top_k=tk_val,
+                top_p=tp_val,
+                use_beam_search=False)
+
+            temp_val = DISABLE_TEMP_VAL if is_greedy or temp_val is None or temp_val == 0 else temp_val
+            tk_val = DISABLE_TOPK_VAL if is_greedy or tk_val is None or tk_val <= 0 else tk_val
+            tp_val = DISABLE_TOPP_VAL if is_greedy or tp_val is None else tp_val
+
+            temperatures.extend(temp_val for _ in range(num_tokens))
+            top_ks.extend(tk_val for _ in range(num_tokens))
+            top_ps.extend(tp_val for _ in range(num_tokens))
+
+        if self.temperatures is None:
+            self.temperatures = torch.ones(
+                (self.max_draft_len + 1) * self.max_num_requests,
+                dtype=torch.float32,
+                device='cuda')
+            self.top_ks = torch.zeros(
+                (self.max_draft_len + 1) * self.max_num_requests,
+                dtype=torch.int32,
+                device='cuda')
+            self.top_ps = torch.ones(
+                (self.max_draft_len + 1) * self.max_num_requests,
+                dtype=torch.float32,
+                device='cuda')
+
+        self.temperatures[:len(temperatures)].copy_(torch.tensor(
+            temperatures, dtype=torch.float32, pin_memory=True),
+                                                    non_blocking=True)
+        self.top_ks[:len(top_ks)].copy_(torch.tensor(top_ks,
+                                                     dtype=torch.int32,
+                                                     pin_memory=True),
+                                        non_blocking=True)
+        self.top_ps[:len(top_ps)].copy_(torch.tensor(top_ps,
+                                                     dtype=torch.float32,
+                                                     pin_memory=True),
+                                        non_blocking=True)
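
As a quick illustration of why the DISABLE_* sentinels above make greedy requests safe to push through the mixed sampler, the following standalone sketch (toy logits, not from the commit) shows that a near-zero temperature collapses the softmax onto the argmax token, while top_k = INT32_MAX and top_p = 1.0 leave the distribution untouched.

import torch

DISABLE_TEMP_VAL = 1e-5                          # near-zero temperature ~ greedy
DISABLE_TOPK_VAL = torch.iinfo(torch.int32).max  # larger than any vocab size: keeps all tokens
DISABLE_TOPP_VAL = 1.0                           # mask probs_sum <= 1 - 1.0 is all-False: keeps all tokens

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
probs = (logits / DISABLE_TEMP_VAL).softmax(dim=-1)

print(probs)                         # ~[[1., 0., 0., 0.]]: all mass on the argmax token
print(torch.argmax(logits, dim=-1))  # tensor([0]): same winner as plain greedy decoding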
tensorrt_llm/_torch/speculative/one_model_sampler.py (new file)
91 additions, 0 deletions

@@ -0,0 +1,91 @@
+from typing import Optional
+
+import torch
+
+
+def forward_native(
+    logits: torch.Tensor,
+    k: Optional[torch.Tensor],
+    p: Optional[torch.Tensor],
+) -> torch.Tensor:
+    """
+    PyTorch-native implementation of top-k and top-p sampling.
+
+    The logits tensor may be updated in-place.
+    """
+    logits = apply_top_k_top_p(logits, k, p)
+    probs = logits.softmax(dim=-1, dtype=torch.float32)
+    return random_sample(probs)
+
+
+def random_sample(
+    probs: torch.Tensor,
+) -> torch.Tensor:
+    """Randomly sample from the probabilities.
+
+    We use this function instead of torch.multinomial because torch.multinomial
+    causes CPU-GPU synchronization.
+    """
+    q = torch.empty_like(probs).exponential_()
+    return probs.div_(q).argmax(dim=-1).view(-1)
+
+
+def apply_top_k_top_p(
+    logits: torch.Tensor,
+    k: Optional[torch.Tensor],
+    p: Optional[torch.Tensor],
+) -> torch.Tensor:
+    """Apply top-k and top-p masks to the logits.
+
+    If a top-p is used, this function will sort the logits tensor,
+    which can be slow for large batches.
+
+    The logits tensor may be updated in-place.
+    """
+    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
+    if k is not None:
+        # Apply top-k.
+        top_k_mask = logits_sort.size(1) - k.to(torch.long)  # shape: B
+        top_k_mask = top_k_mask.clamp(min=0)
+        # Get all the top_k values.
+        top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
+        top_k_mask = logits_sort < top_k_mask
+        logits_sort.masked_fill_(top_k_mask, -float("inf"))
+
+    if p is not None:
+        # Apply top-p.
+        probs_sort = logits_sort.softmax(dim=-1)
+        probs_sum = torch.cumsum(probs_sort, dim=-1, out=probs_sort)
+        top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
+        # at least one
+        top_p_mask[:, -1] = False
+        logits_sort.masked_fill_(top_p_mask, -float("inf"))
+    # Re-sort the probabilities.
+    logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort)
+    return logits
+
+
+def apply_temperature(
+    logits: torch.Tensor,
+    temp: torch.Tensor,
+) -> torch.Tensor:
+    return logits.div_(temp.unsqueeze(dim=1))
+
+
+@torch.compile(options={"max-autotune": True})
+def sampling_batch_spec_dec_one_model(
+    logits: torch.Tensor,
+    temperatures: torch.Tensor,
+    top_k: torch.Tensor,
+    top_p: torch.Tensor,
+) -> torch.Tensor:
+    """
+    CUDA-graph compatible sampling. Supports mixed sampling params.
+
+    We can't do dynamic kernel selection inside graphs, so this might
+    be slower than a torch.argmax for greedy requests. This is why advanced
+    sampling is opt-in for now.
+    """
+    logits = apply_temperature(logits, temperatures)
+    random_sampled = forward_native(logits, top_k, top_p)
+    return random_sampled
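
The least obvious piece of the new sampler is random_sample: dividing the probabilities by i.i.d. Exponential(1) noise and taking the argmax is an exponential-race (Gumbel-max style) trick, so it draws from the same categorical distribution as torch.multinomial while staying CUDA-graph friendly and avoiding a CPU-GPU sync. A self-contained sanity check on a toy distribution (not part of the commit):

import torch

torch.manual_seed(0)
probs = torch.tensor([0.1, 0.2, 0.3, 0.4])

# Exponential-race sampling, exactly as random_sample() does it: argmax(probs / Exp(1)).
n = 200_000
q = torch.empty(n, probs.numel()).exponential_()
samples = (probs / q).argmax(dim=-1)

empirical = torch.bincount(samples, minlength=probs.numel()).float() / n
print(empirical)  # close to tensor([0.1, 0.2, 0.3, 0.4])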

tensorrt_llm/_torch/speculative/utils.py
1 addition, 0 deletions

@@ -76,6 +76,7 @@ def get_spec_metadata(spec_config,
             hidden_size=model_config.hidden_size,
             max_num_tokens=max_num_tokens,
             layers_to_capture=spec_config.eagle3_layers_to_capture,
+            allow_advanced_sampling=spec_config.allow_advanced_sampling,
         )
     if spec_config.spec_dec_mode.is_save_hidden_states():
         if spec_config.eagle3_layers_to_capture is None:

tensorrt_llm/llmapi/llm_args.py
4 additions, 0 deletions

@@ -619,6 +619,10 @@ class DecodingBaseConfig(StrictBaseModel):
     # (N = acceptance_window) drops below this value.
     acceptance_length_threshold: Optional[float] = None
 
+    # Prototype. If true, allows non-greedy sampling when speculation is used. Only applicable
+    # to 1-model code paths; non-greedy sampling is always enabled on 2-model paths.
+    allow_advanced_sampling: bool = False
+
     # Validate acceptance controls at field level so they run on model creation
     @field_validator('acceptance_window')
     @classmethod
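
For context, a minimal end-to-end sketch of how the new flag is meant to be enabled through the LLM API, mirroring the accuracy-test change later in this commit. The model paths, prompt, and sampling values are placeholders, and only allow_advanced_sampling is introduced here; without it, the 1-model path logs a warning and falls back to greedy decoding.

# Sketch only: opting in to non-greedy sampling on 1-model EAGLE3.
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import EagleDecodingConfig

spec_config = EagleDecodingConfig(
    max_draft_len=3,
    speculative_model_dir='<eagle3-draft-model-dir>',  # placeholder
    eagle3_one_model=True,
    allow_advanced_sampling=True)  # opt in; default False keeps greedy behavior

llm = LLM('<target-model-dir>', speculative_config=spec_config)  # placeholder path

# Request-level sampling params now take effect on the 1-model path.
outputs = llm.generate(['The capital of France is'],
                       SamplingParams(max_tokens=32, temperature=0.8, top_p=0.95))
print(outputs[0].outputs[0].text)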

tests/integration/defs/accuracy/test_llm_api_pytorch.py
2 additions, 1 deletion

@@ -4307,7 +4307,8 @@ def test_eagle3_4gpus(self, moe_backend, one_model, overlap_scheduler,
         draft_len = 3
         spec_config = EagleDecodingConfig(max_draft_len=draft_len,
                                           speculative_model_dir=eagle_model_dir,
-                                          eagle3_one_model=one_model)
+                                          eagle3_one_model=one_model,
+                                          allow_advanced_sampling=True)
 
         max_seq_len = MAX_INPUT_LEN + MAX_OUTPUT_LEN
         llm = LLM(self.MODEL_PATH,
