Commit 62010c0

[None][feat] Return topk logprobs in torch backend (#7976)
Signed-off-by: Cao Dong <[email protected]>
1 parent cdce68c commit 62010c0

File tree: 7 files changed, +53 −66 lines

tensorrt_llm/_torch/pyexecutor/llm_request.py

Lines changed: 4 additions & 0 deletions
@@ -311,6 +311,7 @@ def __init__(
             is_draft: bool = False,
             seq_slot: Optional[int] = None,
             target_seq_slot: Optional[int] = None,
+            num_logprobs: int = 0,
             is_first_draft: bool = False,
             **kwargs):
 
@@ -356,6 +357,7 @@ def __init__(
             tensorrt_llm.bindings.internal.runtime.
             TaskLayerModuleConfig] | None = None
 
+        self.py_num_logprobs = num_logprobs
         self.py_return_log_probs = return_log_probs
         self.py_return_context_logits = return_context_logits
         self.py_return_generation_logits = return_generation_logits
@@ -565,6 +567,8 @@ def executor_request_to_llm_request(
         mrope_position_deltas=mrope_position_deltas,
         lookahead_config=None,
         return_log_probs=executor_request.output_config.return_log_probs,
+        num_logprobs=executor_request.py_num_logprobs if hasattr(
+            executor_request, "py_num_logprobs") else 0,
         return_context_logits=executor_request.output_config.
         return_context_logits,
         return_perf_metrics=executor_request.output_config.return_perf_metrics,
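
The requested top-k count rides on the executor request as a dynamically attached `py_num_logprobs` attribute (set in `base_worker.py` below), so the conversion guards with `hasattr` and falls back to 0 when it is absent. A minimal sketch of that defaulting pattern, using a hypothetical stand-in request class:

```python
# Sketch only: _FakeExecutorRequest is a stand-in, not a TRT-LLM class.
class _FakeExecutorRequest:
    pass

req = _FakeExecutorRequest()
req.py_num_logprobs = 3  # attached dynamically by the worker (see base_worker.py below)

# Same guard as in executor_request_to_llm_request: default to 0 (no top-k logprobs).
num_logprobs = req.py_num_logprobs if hasattr(req, "py_num_logprobs") else 0
assert num_logprobs == 3
```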

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 33 additions & 61 deletions
@@ -9,6 +9,7 @@
 from typing import Any, List, Literal, Optional, cast
 
 import torch
+import torch.nn.functional as F
 
 from tensorrt_llm._torch.pyexecutor.make_decoding_batch_input_output import \
     MakeDecodingBatchInputOutput
@@ -891,13 +892,16 @@ def handle_logprobs(self, request: LlmRequest, state: SampleState, *,
                         beam: int, count: int):
         current_slice = slice(0, count), request.py_seq_slot, beam
         if request.py_return_log_probs:
-            assert state.host.log_probs is not None
-            log_probs = state.host.log_probs[request.py_seq_slot][beam][:count]
-            current_tokens = state.host.new_tokens[current_slice]
+            topk_log_probs_vals = request.py_topk_logprobs_vals[:count]
+            topk_log_probs_indices = request.py_topk_logprobs_indices[:count]
 
             token_log_probs = [{
-                int(token): Logprob(logprob=logprob, rank=1)
-            } for token, logprob in zip(current_tokens, log_probs.tolist())]
+                token: Logprob(logprob=logprob, rank=rank + 1)
+                for rank, (token, logprob) in enumerate(
+                    zip(topk_token.tolist(), topk_logprob.tolist()))
+            }
+                for topk_token, topk_logprob in zip(
+                    topk_log_probs_indices, topk_log_probs_vals)]
             assert beam == 0, "The following call relies on beam_width to be 1 - hence the list with a single element"
             request.py_result.append_log_probs([token_log_probs])
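
The new comprehension turns the per-request top-k value/index tensors into one `{token_id: Logprob(logprob, rank)}` dict per decoded token, with rank 1 being the most likely token. A self-contained sketch with toy tensors and a namedtuple standing in for TRT-LLM's `Logprob`:

```python
from collections import namedtuple

import torch

Logprob = namedtuple("Logprob", ["logprob", "rank"])  # stand-in for TRT-LLM's Logprob

# Toy per-step top-k tensors: 2 decode steps, k=2.
topk_log_probs_vals = torch.tensor([[-0.1, -2.3], [-0.4, -1.9]])
topk_log_probs_indices = torch.tensor([[11, 42], [7, 99]])

token_log_probs = [{
    token: Logprob(logprob=logprob, rank=rank + 1)
    for rank, (token, logprob) in enumerate(
        zip(topk_token.tolist(), topk_logprob.tolist()))
} for topk_token, topk_logprob in zip(topk_log_probs_indices, topk_log_probs_vals)]

# One dict per step; rank follows the top-k order.
assert token_log_probs[0][11].rank == 1 and token_log_probs[0][42].rank == 2
```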

@@ -1162,13 +1166,8 @@ def log_probs_host(
             self,
             scheduled_requests: ScheduledRequests) -> Optional[torch.Tensor]:
         """Shape: In lockstep with TRTLLMSampler: https://github.com/NVIDIA/TensorRT-LLM/blob/cea5dd1e3883b18bf50901a7f196f50a9544c28c/cpp/include/tensorrt_llm/runtime/decoderState.h#L103"""
-        if any(req.py_return_log_probs
-               for req in scheduled_requests.all_requests()):
-            return torch.empty(
-                (self.max_num_sequences, self.MAX_BEAM_WIDTH, self.max_tokens),
-                device="cpu",
-                pin_memory=True)
-        return None
+        return any(req.py_return_log_probs
+                   for req in scheduled_requests.all_requests())
 
     @override
     @torch.inference_mode()
@@ -1198,8 +1197,7 @@ def sample_async(
         sampler_event.record()
         return SampleState(scheduled_requests=scheduled_requests,
                            device=SampleStateTensors(new_tokens=new_tokens),
-                           host=SampleStateTensors(new_tokens=new_tokens_host,
-                                                   log_probs=log_probs_host),
+                           host=SampleStateTensors(new_tokens=new_tokens_host),
                            sampler_event=sampler_event)
 
     @staticmethod
@@ -1308,12 +1306,22 @@ def _sample_batched_by_strategy(
         model_outputs: dict[str, torch.Tensor],
         *,
         cuda_device: torch.device,
-        log_probs_host: torch.Tensor | None = None,
+        log_probs_host: bool = False,
         req_num_steps: torch.Tensor,
         req_offsets: torch.Tensor,
         steps_dim_size: int,
         token_dtype: torch.dtype,
     ) -> _BatchedSamplingResult:
+        if log_probs_host:
+            assert logits_cuda.dim() == 2, "logits should be 2D"
+            logprobs = F.log_softmax(logits_cuda.to("cuda",
+                                                    dtype=torch.float32),
+                                     dim=-1)
+            topk_vals, topk_indices = torch.topk(logprobs,
+                                                 k=max(req.py_num_logprobs
+                                                       for req in requests),
+                                                 dim=-1)
+
         requests_by_strategy = _group_requests_by_sampling_strategy(
             requests, pin_memory=True)
         generator_cuda = self.get_generator(cuda_device)
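
The added block computes log-probabilities once for the whole batch: a float32 log-softmax over the sampled-token logits, then a single `torch.topk` using the largest `py_num_logprobs` requested by any request in the batch; each request is later sliced down to its own k. A CPU-only sketch of the same pattern, with toy shapes and `.float()` standing in for the CUDA cast:

```python
import torch
import torch.nn.functional as F

# Toy stand-ins: 4 sampled-token rows, small vocab; the real code runs on CUDA.
logits = torch.randn(4, 1000)
requested_k = [1, 3, 2, 3]  # per-request py_num_logprobs

logprobs = F.log_softmax(logits.float(), dim=-1)
topk_vals, topk_indices = torch.topk(logprobs, k=max(requested_k), dim=-1)

# Each request later keeps only the first k columns of the shared top-k result.
per_request = [(topk_vals[i, :k], topk_indices[i, :k])
               for i, k in enumerate(requested_k)]
assert per_request[0][0].shape == (1,) and per_request[1][0].shape == (3,)
```
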
@@ -1357,12 +1365,20 @@ def _sample_batched_by_strategy(
         # softmax_grp_indices: Indices of 'speculation_group_indices' entries requesting probs
         # speculation_softmax_indices: Indices of 'softmax_grp_indices' entries corresponding
         #   to requests with draft logits.
-        if log_probs_host is not None:
+        if log_probs_host:
             softmax_req_indices = group_req_indices
             softmax_grp_indices = torch.arange(len(group_req_indices),
                                                dtype=torch.int32)
             speculation_softmax_indices = torch.tensor(
                 speculation_group_indices, dtype=torch.int32)
+            for req_id in group_req_indices:
+                req = requests[req_id]
+                req.py_topk_logprobs_vals = topk_vals[
+                    logits_cuda_indexer[req_id], :req.py_num_logprobs].to(
+                        device="cpu", non_blocking=True)
+                req.py_topk_logprobs_indices = topk_indices[
+                    logits_cuda_indexer[req_id], :req.py_num_logprobs].to(
+                        device="cpu", non_blocking=True)
         else:
             speculation_group_indices_tensor = torch.tensor(
                 speculation_group_indices, dtype=torch.int32)
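
The per-request top-k slices are copied device-to-host with `non_blocking=True`, so they are only safe to read once the stream work has completed; in this sampler the recorded `sampler_event` (see `sample_async` above) plays that role before `handle_logprobs` consumes the values. A hedged sketch of the copy-then-synchronize pattern, not tied to the sampler's own event handling:

```python
import torch

# Toy tensors; on a machine without CUDA this degenerates to a plain CPU copy.
src = torch.randn(3, 4, device="cuda") if torch.cuda.is_available() else torch.randn(3, 4)
dst = src.to(device="cpu", non_blocking=True)

if torch.cuda.is_available():
    evt = torch.cuda.Event()
    evt.record()       # recorded after the async copy was enqueued
    evt.synchronize()  # only after this is dst safe to read on the host
print(dst.sum())
```
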
@@ -1462,7 +1478,7 @@ def _unbatch_sampling_results(
         new_tokens_cuda: torch.Tensor,
         req_num_steps: torch.Tensor,
         seq_slots: torch.Tensor,
-        log_probs_host: torch.Tensor | None = None,
+        log_probs_host: bool = False,
     ) -> torch.Tensor:
         beam = self.BEAM
         assert beam == 0, "beam_width != 1 not supported"
@@ -1479,17 +1495,6 @@ def _dims_canonically_ordered(t: torch.Tensor) -> bool:
         # Assert destination tensor dimensions are canonically ordered ("row"-major); this
         # matters for element ordering in the .view(...).scatter_(...) calls below.
         assert _dims_canonically_ordered(new_tokens_cuda)
-        assert log_probs_host is None or _dims_canonically_ordered(
-            log_probs_host)
-
-        # new_tokens_cuda indexed by
-        #   slice(0, steps), slot, beam
-        # log_probs_host indexed by
-        #   slot, beam, slice(0, steps)
-        # batch_... tensors indexed by slice(batch_req_index, batch_req_index + steps)
-        #
-        if log_probs_host is not None:
-            assert new_tokens_cuda.size(0) == log_probs_host.size(-2)
 
         # Construct index mapping from slice indices of computed tensors
         # (packed request_idx and step dimensions) to linearized indices
@@ -1511,39 +1516,6 @@ def _dims_canonically_ordered(t: torch.Tensor) -> bool:
             0, batch_dest_indices_1d_cuda,
             batch_next_tokens_cuda_int)
         new_tokens_host = new_tokens_cuda.to("cpu", non_blocking=True)
-        # NB: In order to avoid a scatter_ on the host and the necessary D2H copy + synchronization,
-        # the 'step' and 'seq_slot' dimensions are unpacked on GPU and later asynchronously
-        # copied into the destination buffer. Note that this overwrites all 'step' and token slots for the
-        # requests in 'requests' (passed to _process_requests). In fact, the current implementation
-        # even overwrites the destination tensors completely (including slices corresponding to request
-        # slots not present in 'requests', cf. 'FIXME' below).
-        if log_probs_host is not None:
-            # FIXME: If log_probs_host were indexed by request indices, rather than request slots, this
-            # tensor could be packed densely along the request axis.
-            log_probs_cuda = torch.empty_like(
-                log_probs_host, device=batch_dest_indices_1d_cuda.device)
-            # FIXME: Needs a separate indexer because tensor layout differs from new_tokens_cuda
-            batch_dest_probs_cuda_indexer = _UnpackedStepIndexer(
-                seq_slots=seq_slots[batch_req_indices],
-                num_steps=req_num_steps[batch_req_indices],
-                steps_dim_size=new_tokens_cuda.size(0),
-                slots_dim_size=new_tokens_cuda.size(1),
-                dim_order=_UnpackedStepIndexer.DimOrder.SLOT_MAJOR,
-                index_dtype=torch.int64,  # enforced by Tensor.scatter_
-            )
-            batch_dest_probs_indices_cuda = batch_dest_probs_cuda_indexer[:].to(
-                batch_softmax_cuda.device, non_blocking=True)
-            # NB: torch.arange is needed to enable "advanced indexing",
-            # cf. https://numpy.org/devdocs/user/basics.indexing.html#integer-array-indexing
-            batch_token_probs = batch_softmax_cuda[
-                torch.arange(batch_softmax_cuda.size(0),
-                             device=batch_softmax_cuda.device,
-                             dtype=torch.int32), batch_next_tokens_cuda_int]
-            log_probs_cuda[:, beam,
-                           ...].view(-1, *log_probs_cuda.shape[3:]).scatter_(
-                               0, batch_dest_probs_indices_cuda,
-                               torch.log(batch_token_probs))
-            log_probs_host.copy_(log_probs_cuda, non_blocking=True)
         # For requests with LlmRequest.py_draft_logits, return py_target_probs
         for request, batch_softmax_index_cuda in py_draft_logits_indices:
             request.py_target_probs = batch_softmax_cuda[

tensorrt_llm/executor/base_worker.py

Lines changed: 1 addition & 0 deletions
@@ -480,6 +480,7 @@ def _deduce_max_tokens(request: GenerationRequest,
             context_phase_params=context_phase_params,
             type=request_type,
             cache_salt_id=request.cache_salt_id)
+        executor_request.py_num_logprobs = request.sampling_params.logprobs
         executor_request.py_lora_path = py_lora_path
 
         if self._is_pytorch_backend and request.multimodal_params is not None:

tensorrt_llm/llmapi/llm.py

Lines changed: 0 additions & 4 deletions
@@ -598,10 +598,6 @@ def _check_arguments(self, prompt_len: int, query_len: int,
                          is_gen_only: bool) -> None:
 
         if self.args.backend in ["pytorch", "_autodeploy"]:
-            if sampling_params.logprobs and sampling_params.logprobs > 1:
-                raise ValueError(
-                    f"PyTorch backend currently only supports `logprobs=1`. Received `logprobs={sampling_params.logprobs}` (Top{sampling_params.logprobs} logprobs). Please set `logprobs=1` in `sampling_params` instead."
-                )
             # Check prompt length and query length against max_num_tokens to filter illegal requests.
             # Skip check for gen-only requests
             if self.args.backend == "pytorch" and not self.args.enable_chunked_prefill and not is_gen_only:
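
With this guard removed, the PyTorch backend accepts `logprobs > 1` through the LLM API. A hedged usage sketch (the model path is a placeholder, not part of this change):

```python
from tensorrt_llm import LLM, SamplingParams

llm = LLM(model="/path/to/hf/model")  # placeholder model path
params = SamplingParams(max_tokens=8, logprobs=3)  # request top-3 logprobs per generated token

for output in llm.generate(["The capital of France is"], params):
    step0 = output.outputs[0].logprobs[0]  # dict: token_id -> Logprob(logprob, rank)
    print(sorted(step0.values(), key=lambda lp: lp.rank))
```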

tensorrt_llm/scaffolding/worker.py

Lines changed: 2 additions & 1 deletion
@@ -180,7 +180,8 @@ def convert_task_params(self, task: GenerationTask):
             temperature=task.temperature,
             top_p=task.top_p,
             top_k=task.top_k,
-            return_context_logits=task.return_context_logits)
+            return_context_logits=task.return_context_logits,
+            logprobs=task.num_logprobs)
         return sampling_params
 
     async def generation_handler(self, task: GenerationTask) -> TaskStatus:

tests/unittest/llmapi/test_llm.py

Lines changed: 12 additions & 0 deletions
@@ -1862,6 +1862,18 @@ def llm_return_logprobs_test_harness(prompt_logprobs: Optional[int],
             logprobs_result[0].keys()) in {logprobs, logprobs + 1}
         # Most contain log prob of the sample token, even if it's not within K
         assert token_ids[0] in logprobs_result[0].keys()
+        for step_logprobs in logprobs_result:
+            assert len(step_logprobs) == logprobs
+            logprob_items = [(logprob_obj.logprob, logprob_obj.rank)
+                             for logprob_obj in step_logprobs.values()]
+            sorted_by_rank = sorted(logprob_items, key=lambda x: x[1])
+
+            for i in range(logprobs - 1):
+                current_logprob, current_rank = sorted_by_rank[i]
+                next_logprob, next_rank = sorted_by_rank[i + 1]
+                assert current_logprob >= next_logprob
+                assert current_rank == i + 1
+                assert next_rank == current_rank + 1
         print("logprobs[0]: ", logprobs_result[0])
 
         if streaming:

tests/unittest/llmapi/test_llm_pytorch.py

Lines changed: 1 addition & 0 deletions
@@ -915,6 +915,7 @@ def test_llm_return_logprobs(prompt_logprobs: Optional[int],
     (2, None, True, False),  # prompt_logprobs with context_logits
     (2, None, False, False),  # prompt_logprobs only
     (2, 1, False, False),  # both prompt and generation logprobs
+    (2, 3, False, False),  # both prompt and generation logprobs
 ])
 def test_llm_return_logprobs_streaming(prompt_logprobs, logprobs,
                                        return_context_logits,
