
Commit 2485128

sampled logprob
fix step; remove dependency on return_generation_logits; align API across backends; add tests; fix group sampling strategy
1 parent c9771eb commit 2485128

9 files changed: +366 −57 lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
@@ -1022,7 +1022,7 @@ common-files: &common_files |
     tests/unittest/_torch/ray_orchestrator/single_gpu/test_cache_transceiver_comm.py |
     tests/unittest/_torch/sampler/test_beam_search.py |
     tests/unittest/_torch/sampler/test_best_of_n.py |
-    tests/unittest/_torch/sampler/test_return_logits.py |
+    tests/unittest/_torch/sampler/test_logits_logprobs.py |
     tests/unittest/_torch/sampler/test_torch_multi_arange.py |
     tests/unittest/_torch/sampler/test_trtllm_sampler.py |
     tests/unittest/_torch/speculative/test_draft_target.py |

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1062,7 +1062,7 @@ exclude = [
     "tests/unittest/_torch/ray_orchestrator/single_gpu/test_cache_transceiver_comm.py",
     "tests/unittest/_torch/sampler/test_beam_search.py",
     "tests/unittest/_torch/sampler/test_best_of_n.py",
-    "tests/unittest/_torch/sampler/test_return_logits.py",
+    "tests/unittest/_torch/sampler/test_logits_logprobs.py",
     "tests/unittest/_torch/sampler/test_torch_multi_arange.py",
     "tests/unittest/_torch/sampler/test_trtllm_sampler.py",
     "tests/unittest/_torch/speculative/test_draft_target.py",

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 150 additions & 48 deletions
@@ -986,18 +986,66 @@ def handle_logprobs(
             topk_log_probs_indices = self.store.new_tokens[0, request.py_seq_slot].view(
                 beam_width, count, -1
             )
+            token_log_probs = self._convert_logprobs_tensor_to_list(
+                topk_log_probs_indices, topk_log_probs_vals
+            )
         else:
             assert beam_width == 1, "beam width must be 1 for non-beam search"
-            topk_log_probs_vals = request.py_topk_logprobs_vals[: count * beam_width].view(
-                beam_width, count, -1
-            )
-            topk_log_probs_indices = request.py_topk_logprobs_indices[
-                : count * beam_width
-            ].view(beam_width, count, -1)
+
+            sampled_tokens = request.get_tokens(0)[-count:]
+
+            if request.py_num_logprobs == 0:
+                # Return only the sampled token's logprob
+                # Compute at least top-1 to determine rank
+                if hasattr(request, 'py_sampled_logprobs') and request.py_sampled_logprobs is not None:
+                    sampled_logprobs = request.py_sampled_logprobs[:count]
+                    topk_log_probs_vals = request.py_topk_logprobs_vals[:count]  # At least k=1
+                    topk_log_probs_indices = request.py_topk_logprobs_indices[:count]
+
+                    token_log_probs = []
+                    for step, (sampled_token, sampled_logprob, topk_tokens, topk_logprobs) in enumerate(
+                        zip(sampled_tokens, sampled_logprobs, topk_log_probs_indices, topk_log_probs_vals)
+                    ):
+                        topk_tokens_list = topk_tokens.tolist()
+                        if sampled_token in topk_tokens_list:
+                            # Sampled token is in top-K, use its rank
+                            rank = topk_tokens_list.index(sampled_token) + 1
+                        else:
+                            # TODO: fix rank
+                            rank = 2
 
-        token_log_probs = self._convert_logprobs_tensor_to_list(
-            topk_log_probs_indices, topk_log_probs_vals
-        )
+                        step_dict = {sampled_token: Logprob(logprob=sampled_logprob.item(), rank=rank)}
+                        token_log_probs.append(step_dict)
+                else:
+                    raise ValueError("py_sampled_logprobs not available when py_num_logprobs == 0")
+            else:
+                # Return top-K logprobs + logprob of sampled token
+                sampled_logprobs = request.py_sampled_logprobs[:count]
+                topk_log_probs_vals = request.py_topk_logprobs_vals[:count]
+                topk_log_probs_indices = request.py_topk_logprobs_indices[:count]
+
+                token_log_probs = []
+                for step, (sampled_token, sampled_logprob, topk_tokens, topk_logprobs) in enumerate(
+                    zip(sampled_tokens, sampled_logprobs, topk_log_probs_indices, topk_log_probs_vals)
+                ):
+                    step_dict = {}
+                    topk_tokens_list = topk_tokens.tolist()
+                    topk_logprobs_list = topk_logprobs.tolist()
+
+                    for rank_idx, (token, logprob) in enumerate(zip(topk_tokens_list, topk_logprobs_list), start=1):
+                        step_dict[token] = Logprob(logprob=logprob, rank=rank_idx)
+
+                    if sampled_token not in step_dict:
+                        # TODO: fix rank
+                        step_dict[sampled_token] = Logprob(
+                            logprob=sampled_logprob.item(),
+                            rank=len(topk_tokens_list) + 1
+                        )
+                    token_log_probs.append(step_dict)
+
+            # Wrap in list for non-beam search (beam_width=1)
+            token_log_probs = [token_log_probs]
+
         request.py_result.append_log_probs(token_log_probs)
 
     def finish_if_reason(
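
For reference, a minimal sketch of the per-step structure the rewritten non-beam branch appends; the token IDs and values are invented, and the Logprob dataclass below only stands in for the library's own Logprob record:

    from dataclasses import dataclass

    @dataclass
    class Logprob:  # stand-in for the library's Logprob record (illustration only)
        logprob: float
        rank: int

    # logprobs=0: each step carries exactly one entry, the sampled token itself.
    step_sampled_only = {1234: Logprob(logprob=-0.31, rank=1)}

    # logprobs=2: top-K entries first, plus the sampled token if it missed the top-K.
    step_with_topk = {
        1234: Logprob(logprob=-0.31, rank=1),
        987: Logprob(logprob=-1.70, rank=2),
        4321: Logprob(logprob=-3.05, rank=3),  # sampled token outside the top-2
    }

    # handle_logprobs wraps the per-step dicts once more because beam_width == 1.
    token_log_probs = [[step_sampled_only, step_with_topk]]
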
@@ -2461,47 +2509,55 @@ def _process_requests(
             assert logits_cuda.dim() == 2, "logits should be 2D"
 
             logprobs_req_indices = [
-                req_id for req_id, req in enumerate(requests) if req.py_num_logprobs
+                req_id for req_id, req in enumerate(requests) if req.py_num_logprobs is not None
             ]
-            logprobs_logit_indices = logits_cuda_indexer[logprobs_req_indices]
-            logprobs_logit_indices_cuda = logprobs_logit_indices.to(
-                device=logits_cuda.device, non_blocking=True
-            )
-            logprobs_cuda = F.log_softmax(
-                logits_cuda[logprobs_logit_indices_cuda].to(dtype=torch.float32, non_blocking=True),
-                dim=-1,
-            )
-            topk_vals_cuda, topk_indices_cuda = torch.topk(
-                logprobs_cuda, k=max(req.py_num_logprobs for req in requests), dim=-1
-            )
-            # Use a single D2H copy to reduce overheads
-            topk_vals = torch.empty_like(topk_vals_cuda, device="cpu", pin_memory=True)
-            topk_indices = torch.empty_like(topk_indices_cuda, device="cpu", pin_memory=True)
-            topk_vals.copy_(topk_vals_cuda, non_blocking=True)
-            topk_indices.copy_(topk_indices_cuda, non_blocking=True)
-            current_offset = 0
-            for req_id, steps in zip(
-                logprobs_req_indices, req_num_generated_tokens[logprobs_req_indices].tolist()
-            ):
-                req = requests[req_id]
-                next_offset = current_offset + steps
-                # NB: Assigning views on memory which is being filled asynchronously
-                req.py_topk_logprobs_vals = topk_vals[
-                    current_offset:next_offset, : req.py_num_logprobs
-                ]
-                req.py_topk_logprobs_indices = topk_indices[
-                    current_offset:next_offset, : req.py_num_logprobs
-                ]
 
-                # context requests do not have multiple input beams, but they need multiple output beams
-                if req.is_context_init_state:
-                    req.py_topk_logprobs_vals = req.py_topk_logprobs_vals.expand(
-                        req.sampling_config.beam_width, -1
-                    )
-                    req.py_topk_logprobs_indices = req.py_topk_logprobs_indices.expand(
-                        req.sampling_config.beam_width, -1
-                    )
-                current_offset = next_offset
+            if logprobs_req_indices:
+                logprobs_logit_indices = logits_cuda_indexer[logprobs_req_indices]
+                logprobs_logit_indices_cuda = logprobs_logit_indices.to(
+                    device=logits_cuda.device, non_blocking=True
+                )
+                logprobs_cuda = F.log_softmax(
+                    logits_cuda[logprobs_logit_indices_cuda].to(dtype=torch.float32, non_blocking=True),
+                    dim=-1,
+                )
+
+                max_k = max(max(1, req.py_num_logprobs) for req in requests if req.py_num_logprobs is not None)
+                topk_vals_cuda, topk_indices_cuda = torch.topk(
+                    logprobs_cuda,
+                    k=max_k,
+                    dim=-1
+                )
+                # Use a single D2H copy to reduce overheads
+                topk_vals = torch.empty_like(topk_vals_cuda, device="cpu", pin_memory=True)
+                topk_indices = torch.empty_like(topk_indices_cuda, device="cpu", pin_memory=True)
+                topk_vals.copy_(topk_vals_cuda, non_blocking=True)
+                topk_indices.copy_(topk_indices_cuda, non_blocking=True)
+                current_offset = 0
+                for req_id, steps in zip(
+                    logprobs_req_indices, req_num_generated_tokens[logprobs_req_indices].tolist()
+                ):
+                    req = requests[req_id]
+                    next_offset = current_offset + steps
+                    # Store at least k=1 for all requests (including logprobs=0) to compute ranks
+                    k_for_req = max(1, req.py_num_logprobs)
+                    # NB: Assigning views on memory which is being filled asynchronously
+                    req.py_topk_logprobs_vals = topk_vals[
+                        current_offset:next_offset, : k_for_req
+                    ]
+                    req.py_topk_logprobs_indices = topk_indices[
+                        current_offset:next_offset, : k_for_req
+                    ]
+
+                    # context requests do not have multiple input beams, but they need multiple output beams
+                    if req.is_context_init_state:
+                        req.py_topk_logprobs_vals = req.py_topk_logprobs_vals.expand(
+                            req.sampling_config.beam_width, -1
+                        )
+                        req.py_topk_logprobs_indices = req.py_topk_logprobs_indices.expand(
+                            req.sampling_config.beam_width, -1
+                        )
+                    current_offset = next_offset
 
         # Perform sampling in batches
         batched_sampling_result = self._sample_batched_by_strategy(
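
The hunk above keeps the single pinned-memory D2H copy but clamps every logprobs request to at least k=1 so a rank can still be derived when logprobs=0. A standalone sketch of that pattern, with invented shapes and requiring a CUDA device (the empty_like/copy_ calls mirror the ones in the diff):

    import torch
    import torch.nn.functional as F

    logits_cuda = torch.randn(6, 32000, device="cuda")       # 6 sampled steps, toy vocab
    logprobs_cuda = F.log_softmax(logits_cuda.float(), dim=-1)

    max_k = 4                                                 # max(1, py_num_logprobs) over requests
    topk_vals_cuda, topk_indices_cuda = torch.topk(logprobs_cuda, k=max_k, dim=-1)

    # Pinned host buffers let a single non-blocking D2H copy overlap with GPU work.
    topk_vals = torch.empty_like(topk_vals_cuda, device="cpu", pin_memory=True)
    topk_indices = torch.empty_like(topk_indices_cuda, device="cpu", pin_memory=True)
    topk_vals.copy_(topk_vals_cuda, non_blocking=True)
    topk_indices.copy_(topk_indices_cuda, non_blocking=True)
    torch.cuda.current_stream().synchronize()                 # wait before reading on the CPU

    # A request with logprobs=0 still receives a k=1 view so its rank can be computed later.
    per_req_vals = topk_vals[0:2, :1]
    print(per_req_vals.shape)                                 # torch.Size([2, 1])
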
@@ -2517,6 +2573,52 @@ def _process_requests(
             token_dtype=new_tokens_cuda.dtype,
         )
 
+        if return_log_probs and logprobs_req_indices:
+            sampled_tokens_cuda = batched_sampling_result.batch_next_tokens_cuda_int
+            batch_req_indices = batched_sampling_result.batch_req_indices
+            logprobs_req_set = set(logprobs_req_indices)
+            sampled_logprobs_list = []
+
+            # Build offsets for the GROUPED order
+            grouped_num_steps = req_num_steps[batch_req_indices]
+            grouped_offsets = torch.cat([
+                torch.zeros((1,), dtype=torch.int32, pin_memory=True),
+                grouped_num_steps.cumsum(dim=0)[:-1]
+            ])
+
+            # Reverse mapping: original_req_id → position in grouped result
+            req_to_grouped_pos = {
+                orig_id.item(): grouped_pos
+                for grouped_pos, orig_id in enumerate(batch_req_indices)
+            }
+
+            for req_id in range(len(requests)):
+                if req_id in logprobs_req_set:
+                    logprobs_idx = logprobs_req_indices.index(req_id)
+
+                    if logprobs_idx == 0:
+                        start_offset = 0
+                    else:
+                        start_offset = sum(req_num_steps[logprobs_req_indices[:logprobs_idx]].tolist())
+
+                    num_steps_this_req = req_num_steps[req_id].item()
+                    end_offset = start_offset + num_steps_this_req
+
+                    grouped_pos = req_to_grouped_pos[req_id]
+                    grouped_start = grouped_offsets[grouped_pos].item()
+                    grouped_end = grouped_start + grouped_num_steps[grouped_pos].item()
+
+                    sampled_tokens_this_req = sampled_tokens_cuda[grouped_start:grouped_end]
+
+                    step_indices = torch.arange(start_offset, end_offset, device=logprobs_cuda.device)
+                    sampled_logprobs_cuda = logprobs_cuda[step_indices, sampled_tokens_this_req.long()]
+
+                    sampled_logprobs_cpu = sampled_logprobs_cuda.to(device="cpu", non_blocking=True)
+                    sampled_logprobs_list.append((req_id, sampled_logprobs_cpu))
+
+            for req_id, sampled_logprobs in sampled_logprobs_list:
+                requests[req_id].py_sampled_logprobs = sampled_logprobs
+
         # Fill results into output buffers
         new_tokens_host = self._unbatch_sampling_results(
             batched_sampling_result,
tensorrt_llm/executor/executor.py

Lines changed: 1 addition & 1 deletion
@@ -224,7 +224,7 @@ def _get_logprob_params(
             self, request: GenerationRequest) -> Optional[LogprobParams]:
         """Store logprobs-related fields from request for the later logprob calculation."""
         logprob_params = None
-        if request.sampling_params.logprobs or request.sampling_params.prompt_logprobs:
+        if request.sampling_params.logprobs is not None or request.sampling_params.prompt_logprobs:
             logprob_params = LogprobParams(
                 logprobs=request.sampling_params.logprobs,
                 prompt_logprobs=request.sampling_params.prompt_logprobs,

tensorrt_llm/executor/result.py

Lines changed: 16 additions & 1 deletion
@@ -1033,6 +1033,21 @@ def _topk_logprobs(logits: torch.Tensor, top_k: int,
             logits = logits[:len(tokens)]
 
         logprobs = F.log_softmax(logits.to("cuda", dtype=torch.float32), dim=-1)
+
+        # only return sampled token
+        if top_k == 0:
+            results: TokenLogprobs = []
+            if tokens is not None:
+                for t in range(logprobs.size(0)):
+                    token_id = tokens[t]
+                    token_logprob = logprobs[t, token_id].item()
+                    rank = (logprobs[t] > token_logprob).sum().item() + 1
+                    token_dict = {
+                        token_id: Logprob(logprob=token_logprob, rank=rank)
+                    }
+                    results.append(token_dict)
+            return results
+
         topk_vals, topk_indices = torch.topk(logprobs, k=top_k, dim=-1)
 
         results: TokenLogprobs = []
@@ -1061,7 +1076,7 @@ def _topk_logprobs(logits: torch.Tensor, top_k: int,
         None) if k_prompt_logprobs and context_logits is not None else None
     generation_logprobs = _topk_logprobs(
         generation_logits, k_logprobs, output_token_ids
-    ) if k_logprobs and generation_logits is not None else None
+    ) if k_logprobs is not None and generation_logits is not None else None
 
     return LogProbsResult(prompt=prompt_logprobs,
                           generation=generation_logprobs)
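
The new top_k == 0 path derives the sampled token's rank without calling topk: the rank is 1 plus the number of vocabulary entries with a strictly higher log-probability. A small CPU sketch of that rule over a toy vocabulary of 5 tokens:

    import torch
    import torch.nn.functional as F

    logprobs = F.log_softmax(torch.tensor([2.0, 0.5, 1.0, -1.0, 0.0]), dim=-1)
    token_id = 2
    token_logprob = logprobs[token_id].item()
    rank = (logprobs > token_logprob).sum().item() + 1
    print(rank)  # 2: only token 0 scores higher than token 2
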

tensorrt_llm/llmapi/llm.py

Lines changed: 2 additions & 2 deletions
@@ -632,7 +632,7 @@ def _prepare_sampling_params(
         if sampling_params.prompt_logprobs and not sampling_params.return_context_logits:
             sampling_params.return_context_logits = True
             sampling_params._context_logits_auto_enabled = True
-        if sampling_params.logprobs and not sampling_params.return_generation_logits:
+        if sampling_params.logprobs is not None and not sampling_params.return_generation_logits:
            sampling_params.return_generation_logits = True
            sampling_params._generation_logits_auto_enabled = True
 
@@ -703,7 +703,7 @@ def _check_arguments(self, prompt_len: int, query_len: int,
                 f"Example: LLM(..., build_config=BuildConfig(gather_context_logits=True))."
             )
 
-        if sampling_params.logprobs and not self.args.gather_generation_logits:
+        if sampling_params.logprobs is not None and not self.args.gather_generation_logits:
             raise ValueError(
                 f"`sampling_params.logprobs={sampling_params.logprobs}` requires `gather_generation_logits=True` "
                 f"to be passed explicitly to the `LLM()` constructor.")

tensorrt_llm/sampling_params.py

Lines changed: 3 additions & 2 deletions
@@ -172,7 +172,8 @@ class SamplingParams:
         min_p (float, optional): scale the most likely token to determine the minimum token probability. None means using C++ runtime default 0.0. Defaults to None.
         beam_width_array (List[int], optional): The array of beam width using in Variable-Beam-Width-Search. Defaults to None.
 
-        logprobs (int, optional): Number of log probabilities to return per output token. Defaults to None.
+        logprobs (int, optional): Number of log probabilities to return per output token. When set to 0, return only the sampled token's log probability.
+            When set to K>0, return the top-K log probabilities plus the sampled token's log probability (as the last entry) if it is not in the top-K. Defaults to None.
         prompt_logprobs (int, optional): Number of log probabilities to return per prompt token. Defaults to None.
         return_context_logits (bool): Controls if Result should contain the context logits. Defaults to False.
         return_generation_logits (bool): Controls if Result should contain the generation logits. Defaults to False.
@@ -501,7 +502,7 @@ def _get_output_config(self, is_pytorch_backend: bool = False) -> tllme.OutputCo
         config_kwargs = {f: getattr(self, f) for f in fields}
 
         if is_pytorch_backend:
-            config_kwargs["return_log_probs"] = bool(self.logprobs)
+            config_kwargs["return_log_probs"] = self.logprobs is not None
             if self.prompt_logprobs and not self.return_context_logits:
                 logger.info(
                     "Since prompt_logprobs is requested but return_context_logits is False, "

tests/integration/test_lists/test-db/l0_a30.yml

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@ l0_a30:
   - unittest/_torch/modeling -k "modeling_starcoder2"
   - unittest/_torch/auto_deploy/unit/singlegpu
   - unittest/_torch/sampler/test_beam_search.py
-  - unittest/_torch/sampler/test_return_logits.py
+  - unittest/_torch/sampler/test_logits_logprobs.py
   - test_e2e.py::test_openai_completions_with_logit_bias[torch_sampler]
   - test_e2e.py::test_openai_chat_with_logit_bias[torch_sampler]
   - test_e2e.py::test_openai_completions_with_logit_bias[trtllm_sampler]
