
Commit 12e1f73

kris1025 authored and dominicshanshan committed
[TRTLLM-7384][feat] enable rejection sampling for CDL (NVIDIA#7731)
Signed-off-by: linquanh <linquanh@nvidia.com>
1 parent 09a9d1b commit 12e1f73

File tree

8 files changed: +211 -143 lines changed


tensorrt_llm/_torch/pyexecutor/llm_request.py

Lines changed: 2 additions & 0 deletions
@@ -456,6 +456,8 @@ def __init__(
         self.use_draft_model = is_draft
         # Whether the request is for the first forward of the draft model.
         self.py_is_first_draft = is_first_draft
+        self.d2t = None
+        self.py_draft_use_greedy_sampling = False
 
         # Chunked logits parameters
         self.py_use_chunked_generation_logits = use_chunked_generation_logits

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

Lines changed: 5 additions & 8 deletions
@@ -31,7 +31,6 @@
3131
from ..distributed import MPIDist, TorchDist
3232
from ..speculative import (get_num_extra_kv_tokens, get_spec_drafter,
3333
get_spec_resource_manager)
34-
from ..utils import _get_allow_chain_drafter
3534
from ._util import (KvCacheCreator, _adjust_torch_mem_fraction,
3635
create_py_executor_instance, instantiate_sampler, is_mla,
3736
validate_feature_combination)
@@ -344,13 +343,11 @@ def create_py_executor(
344343
_ExecutorCreationStage.MODEL_ENGINE_DRAFT):
345344
draft_spec_config = copy.copy(spec_config)
346345

347-
if _get_allow_chain_drafter():
348-
use_chain_drafter = (
349-
guided_decoding_config is None
350-
and draft_spec_config._allow_greedy_draft_tokens
351-
and pytorch_backend_config.attn_backend == "TRTLLM")
352-
else:
353-
use_chain_drafter = False
346+
use_chain_drafter = (
347+
guided_decoding_config is None
348+
and draft_spec_config._allow_chain_drafter
349+
and draft_spec_config._allow_greedy_draft_tokens
350+
and pytorch_backend_config.attn_backend == "TRTLLM")
354351

355352
logger.debug(f"USE CHAIN DRAFTER: {use_chain_drafter}")
356353
if use_chain_drafter:

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 31 additions & 7 deletions
@@ -310,10 +310,15 @@ def greedy_search_sampling_batch(
         softmax_indices: Optional[torch.IntTensor] = None
 ) -> tuple[torch.Tensor, torch.Tensor]:
     next_tokens = torch.argmax(logits, dim=-1)
+    index_to_scatter = next_tokens
     if softmax_indices is not None:
-        logits = logits[softmax_indices.to(logits.device, non_blocking=True)]
-    softmax = torch.softmax(logits, dim=-1)
-    return next_tokens, softmax
+        logits = logits[softmax_indices]
+        index_to_scatter = next_tokens[softmax_indices]
+    probs = torch.zeros_like(logits)
+    probs.scatter_(dim=-1,
+                   index=index_to_scatter.unsqueeze(-1),
+                   src=torch.ones_like(logits))
+    return next_tokens, probs
 
 
 def get_rejected_indices(draft_probs: torch.Tensor, target_probs: torch.Tensor,
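
Note: the rewritten greedy_search_sampling_batch now returns a one-hot distribution (all probability mass on the argmax token) instead of a softmax, which is what the rejection-sampling path expects for a deterministic draft. A minimal, self-contained sketch of the same construction (helper name and toy logits are illustrative, not from the repo):

import torch

def greedy_one_hot_probs(logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Pick the argmax token per position, then scatter 1.0 into that slot to
    # build a one-hot "probability" tensor. With a one-hot draft distribution,
    # the standard acceptance probability min(1, p_target / p_draft) reduces
    # to p_target(chosen token).
    next_tokens = torch.argmax(logits, dim=-1)
    probs = torch.zeros_like(logits)
    probs.scatter_(dim=-1,
                   index=next_tokens.unsqueeze(-1),
                   src=torch.ones_like(logits))
    return next_tokens, probs

# Two positions over a 5-token vocabulary.
logits = torch.tensor([[0.1, 2.0, 0.3, 0.0, -1.0],
                       [1.5, 0.2, 0.2, 3.0, 0.0]])
tokens, probs = greedy_one_hot_probs(logits)
# tokens == tensor([1, 3]); each row of probs is one-hot at that index.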
@@ -1127,20 +1132,38 @@ def _tree_sampling_batch(self, requests: list[LlmRequest],
 
         return new_draft_tokens_host
 
+    @torch.inference_mode()
     def _process_draft_tokens_rejection_sampling(
             self, request: LlmRequest, new_tokens: list[list[list[int]]],
             new_tokens_tensor: torch.Tensor) -> int:
         # FIXME: Passing a dummy vocab_size could result in unnecessary
         # filtering of vocab_size logits, out of vocab_size in
         # total. The 'sample' below should generally be avoided
         # by retaining the draft_probs during drafting (TRTLLM-7772).
-        sampling_strategy = _request_strategy(request, vocab_size=2**31)
+        draft_sampling_strategy = (
+            "greedy", None
+        ) if request.py_draft_use_greedy_sampling else _request_strategy(
+            request, vocab_size=2**31)
         generator = self.get_generator(request.py_draft_logits.device)
-        _, draft_probs = sample(sampling_strategy,
+        _, draft_probs = sample(draft_sampling_strategy,
                                 request.py_draft_logits,
                                 generator=generator)
-        draft_probs = draft_probs.squeeze(0)
         target_probs = request.py_target_probs
+        d2t = getattr(request, "d2t", None)
+        if d2t is not None:
+            vocab_d = draft_probs.shape[-1]
+            vocab_t = target_probs.shape[-1]
+            assert d2t.numel(
+            ) == vocab_d, f"d2t size mismatch: {d2t.numel()} != {vocab_d}"
+            assert d2t.device == draft_probs.device, f"d2t device mismatch: {d2t.device} != {draft_probs.device}"
+            aligned_draft_probs = torch.zeros(
+                (*draft_probs.shape[:-1], vocab_t),
+                device=draft_probs.device,
+                dtype=draft_probs.dtype)
+            source_indices = torch.arange(vocab_d, device=draft_probs.device)
+            target_indices = (source_indices + d2t) % vocab_t
+            aligned_draft_probs[..., target_indices] = draft_probs
+            draft_probs = aligned_draft_probs
         rejected_indices = get_rejected_indices(draft_probs, target_probs,
                                                 generator,
                                                 request.py_draft_tokens)
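
Note: in the new branch above, d2t is treated as a per-index offset table that maps a (smaller) draft vocabulary into the target vocabulary, so draft probabilities are scattered into a target-vocab-sized tensor before get_rejected_indices compares them with target_probs. A minimal sketch of that alignment under the same assumption (toy sizes and helper name are illustrative):

import torch

def align_draft_probs(draft_probs: torch.Tensor, d2t: torch.Tensor,
                      target_vocab_size: int) -> torch.Tensor:
    # d2t[i] is the offset mapping draft-vocab index i to its target-vocab id:
    # target_id = (i + d2t[i]) % target_vocab_size.
    vocab_d = draft_probs.shape[-1]
    aligned = torch.zeros((*draft_probs.shape[:-1], target_vocab_size),
                          device=draft_probs.device, dtype=draft_probs.dtype)
    source = torch.arange(vocab_d, device=draft_probs.device)
    target = (source + d2t) % target_vocab_size
    aligned[..., target] = draft_probs
    return aligned

# Toy example: 4-token draft vocab mapped into an 8-token target vocab.
draft_probs = torch.tensor([[0.7, 0.1, 0.1, 0.1]])
d2t = torch.tensor([0, 2, 3, 3])  # draft ids 0..3 -> target ids 0, 3, 5, 6
aligned = align_draft_probs(draft_probs, d2t, target_vocab_size=8)
# aligned[0] puts 0.7 at index 0 and 0.1 at indices 3, 5, 6; zeros elsewhere.

Once the two distributions share an index space, standard speculative rejection sampling (accept draft token x with probability min(1, p_target(x) / p_draft(x))) can be applied per position.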
@@ -1181,7 +1204,8 @@ def process_draft_tokens(
             new_tokens: list[list[list[int]]],
             new_tokens_tensor: torch.Tensor,
             resource_manager: Optional[ResourceManager] = None) -> int:
-        if request.py_draft_logits is None:
+        if _request_strategy(request, vocab_size=2**
+                             31) == GREEDY or request.py_draft_logits is None:
             spec_tree_manager = self.get_spec_tree_manager(resource_manager)
             if spec_tree_manager is not None:
                 num_accepted = self._process_draft_tokens_tree(
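
Note: the widened guard sends requests whose target sampling strategy is greedy down the existing token-matching path instead of rejection sampling; with greedy decoding, acceptance degenerates to comparing each draft token against the token the target model actually produced. A rough sketch of that equivalence (illustrative only, not the repo's implementation):

import torch

def greedy_match_accept(draft_tokens: torch.Tensor,
                        target_tokens: torch.Tensor) -> int:
    # Accept the longest prefix where the draft token equals the target
    # model's greedy choice at the same position.
    num_accepted = 0
    for d, t in zip(draft_tokens.tolist(), target_tokens.tolist()):
        if d != t:
            break
        num_accepted += 1
    return num_accepted

# Draft proposed [5, 9, 2]; the target's greedy choices were [5, 9, 7].
print(greedy_match_accept(torch.tensor([5, 9, 2]), torch.tensor([5, 9, 7])))  # 2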

tensorrt_llm/_torch/speculative/drafting_loops.py

Lines changed: 7 additions & 3 deletions
@@ -116,8 +116,7 @@ def __init__(self, max_draft_len: int, draft_model: torch.nn.Module):
 
     def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor,
                 attn_metadata: AttentionMetadata, spec_metadata: SpecMetadata,
-                **kwargs) -> torch.Tensor:
-
+                **kwargs) -> dict[str, torch.Tensor]:
         logits = self.draft_model.forward(input_ids=input_ids,
                                           position_ids=position_ids,
                                           attn_metadata=attn_metadata,
@@ -126,6 +125,7 @@ def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor,
         logits = logits[spec_metadata.gather_ids]
 
         new_draft_tokens = [self.sample(logits)]
+        draft_logits = [logits]
         with save_metadata_state(attn_metadata, spec_metadata):
             batch_size = attn_metadata.num_seqs
 
@@ -139,13 +139,17 @@ def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor,
                                                   attn_metadata=attn_metadata,
                                                   spec_metadata=spec_metadata)
                 new_draft_tokens.append(self.sample(logits))
+                draft_logits.append(logits)
                 new_position_ids += 1
                 attn_metadata.kv_lens_cuda[:batch_size] += 1
                 if i == 0 and isinstance(spec_metadata, Eagle3SpecMetadata):
                     spec_metadata.hidden_states_read_indices[:batch_size].copy_(
                         spec_metadata.hidden_states_write_indices[:batch_size])
 
-        return torch.stack(new_draft_tokens)
+        return {
+            "new_draft_tokens": torch.stack(new_draft_tokens),
+            "draft_logits": torch.stack(draft_logits)
+        }
 
     def sample(self, logits: torch.Tensor) -> torch.Tensor:
         # TODO: inject the sampler here so we can support non-greedy
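
Note: the static drafting loop now returns a dict rather than a bare tensor; consumers (see model_drafter.py below) read outputs["new_draft_tokens"] and outputs["draft_logits"], both stacked along the draft-step dimension. A sketch of the assumed shape contract (per-step logits taken to be [batch, vocab]; names and sizes are illustrative):

import torch

max_draft_len, num_seqs, vocab = 3, 2, 8

# What the loop accumulates per draft step before stacking:
new_draft_tokens = [torch.randint(0, vocab, (num_seqs,)) for _ in range(max_draft_len)]
draft_logits = [torch.randn(num_seqs, vocab) for _ in range(max_draft_len)]

outputs = {
    "new_draft_tokens": torch.stack(new_draft_tokens),  # [max_draft_len, num_seqs]
    "draft_logits": torch.stack(draft_logits),          # [max_draft_len, num_seqs, vocab]
}

# Consumer side (mirrors process_static_draft_outputs): per request, gather
# one logits row per draft step and stack them into py_draft_logits.
req_idx = 0
py_draft_logits = torch.stack(
    [outputs["draft_logits"][step][req_idx] for step in range(max_draft_len)])
assert py_draft_logits.shape == (max_draft_len, vocab)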

tensorrt_llm/_torch/speculative/model_drafter.py

Lines changed: 34 additions & 21 deletions
@@ -231,6 +231,12 @@ def _prepare_draft_batch(
             ScheduledRequests: The prepared draft batch
         """
         try:
+            for req in scheduled_requests.all_requests():
+                draft_model = self.draft_model_engine.model.draft_model if self.use_static_draft_loop else self.draft_model_engine.model
+                if hasattr(draft_model.model, "d2t"):
+                    req.d2t = draft_model.model.d2t.data
+                req.py_draft_use_greedy_sampling = self.use_static_draft_loop
+
             draft_batch = ScheduledRequests()
 
             for request in scheduled_requests.context_requests:
@@ -530,7 +536,8 @@ def _setup_draft_batch_and_resources(
         return draft_batch, req_id_to_old_request
 
     def process_static_draft_outputs(
-            self, outputs: torch.Tensor | SampleState,
+            self,
+            outputs: dict[str, torch.Tensor] | tuple[torch.Tensor, SampleState],
             draft_batch: ScheduledRequests,
             req_id_to_old_request: Dict[int, LlmRequest]) -> None:
         """
@@ -541,23 +548,26 @@ def process_static_draft_outputs(
             draft_batch: The draft batch that was processed
             req_id_to_old_request: Mapping from draft request ID to original request
         """
-        if isinstance(outputs, torch.Tensor):
-            # For non-overlap scheduler path.
-            outputs_host = outputs.cpu()
+
+        if isinstance(outputs, dict):
+            draft_tokens_host = outputs["new_draft_tokens"].cpu()
+            draft_logits = outputs["draft_logits"]
         else:
-            outputs_host = outputs.host.new_tokens
-            outputs.sampler_event.synchronize()
-
-        for token_idx in range(self.max_draft_tokens):
-            for req_idx, req in enumerate(draft_batch.all_requests()):
-                target_model_req = req_id_to_old_request[req.py_request_id]
-                if target_model_req.state != LlmRequestState.GENERATION_IN_PROGRESS:
-                    # Chunked prefill request in progress; no need to append draft tokens
-                    continue
+            draft_logits = outputs[0]
+            draft_tokens_host = outputs[1].host.new_tokens
+            outputs[1].sampler_event.synchronize()
 
-                target_req = req_id_to_old_request[req.py_request_id]
-                target_req.py_draft_tokens.append(
-                    outputs_host[token_idx][req_idx])
+        for req_idx, req in enumerate(draft_batch.all_requests()):
+            target_model_req = req_id_to_old_request[req.py_request_id]
+            if target_model_req.state != LlmRequestState.GENERATION_IN_PROGRESS:
+                # Chunked prefill request in progress; no need to append draft tokens
+                continue
+            py_draft_logits = []
+            for token_idx in range(self.max_draft_tokens):
+                target_model_req.py_draft_tokens.append(
+                    draft_tokens_host[token_idx][req_idx])
+                py_draft_logits.append(draft_logits[token_idx][req_idx])
+            target_model_req.py_draft_logits = torch.stack(py_draft_logits)
 
         # Clean up draft resources
         for req in draft_batch.all_requests():
@@ -710,23 +720,26 @@ def generate_draft_tokens_with_overlap(
                 # Only update target inputs, cleanup will be done in executor loop
                 self._update_target_inputs_with_draft_tokens(
                     target_inputs,
-                    outputs,
+                    outputs["new_draft_tokens"],
                     draft_position=0,
                     draft_length=self.max_draft_tokens,
                     draft_batch=draft_batch,
                     req_id_to_old_request=req_id_to_old_request)
 
-                new_tokens_host = outputs.to(device='cpu', non_blocking=True)
+                new_tokens_host = outputs["new_draft_tokens"].to(device='cpu',
+                                                                 non_blocking=True)
                 sampler_event = torch.cuda.Event()
                 sampler_event.record()
 
-                outputs = SampleState(
+                sample_state = SampleState(
                     scheduled_requests=draft_batch,
-                    device=SampleStateTensors(new_tokens=outputs),
+                    device=SampleStateTensors(
+                        new_tokens=outputs["new_draft_tokens"]),
                     host=SampleStateTensors(new_tokens=new_tokens_host),
                     sampler_event=sampler_event)
 
-                return target_inputs, outputs, draft_batch
+                return target_inputs, (outputs["draft_logits"],
+                                       sample_state), draft_batch
 
             # Handle guided decoder and sampling for non-static loop
             if self.guided_decoder is not None:

tensorrt_llm/_torch/utils.py

Lines changed: 0 additions & 6 deletions
@@ -308,12 +308,6 @@ def create_lm_head_tp_mapping(mapping: Mapping, token_count: int) -> Mapping:
308308
)
309309

310310

311-
# Development function to control chain drafter feature.
312-
# It's here so that unit tests can mock it and turn it off.
313-
def _get_allow_chain_drafter() -> bool:
314-
return True
315-
316-
317311
def get_device_uuid(device_idx: int) -> str:
318312
"""Get the UUID of a CUDA device using torch cuda api"""
319313

tensorrt_llm/llmapi/llm_args.py

Lines changed: 2 additions & 0 deletions
@@ -366,6 +366,8 @@ class DecodingBaseConfig(StrictBaseModel):
 
     load_format: Optional[str] = None
 
+    # If set, drafting is allowed to use chain drafter.
+    _allow_chain_drafter: bool = PrivateAttr(True)
     # If set, drafting uses greedy sampling, irrespective of sampling parameters.
     _allow_greedy_draft_tokens: bool = PrivateAttr(True)
 