Commit 0b21e18

[None][feat] Add processed logprobs functionality to TorchSampler
- Introduce a new optional parameter, logprobs_mode, to the SamplingParams and LlmRequest classes, allowing users to specify the mode of log probabilities to return.
- Create a process_logprobs function to move the logprobs processing code out of process_requests.
- Add batching based on logprobs_mode to sample_batched_by_strategy.
- Additionally return the processed logits from sampling.

Signed-off-by: Stefan Niebler <[email protected]>
1 parent 09beaa5 commit 0b21e18
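
For context, a minimal usage sketch of the new option (a hedged example, not part of the commit: the model path and prompt are placeholders, and output access follows the existing tensorrt_llm LLM API; "processed" returns logprobs after temperature scaling and top-k/top-p masking, while the default "raw" keeps the previous behavior):

# Sketch only; logprobs_mode is the parameter introduced by this commit.
from tensorrt_llm import LLM, SamplingParams

llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct")  # placeholder model
params = SamplingParams(
    max_tokens=16,
    logprobs=5,                 # request top-5 logprobs per generated token
    logprobs_mode="processed",  # new in this commit; "raw" is the default
)
outputs = llm.generate(["The capital of France is"], params)
print(outputs[0].outputs[0].logprobs)  # list of {token_id: Logprob} dicts, one per step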

File tree

5 files changed, +161 -68 lines changed

5 files changed

+161
-68
lines changed

tensorrt_llm/_torch/pyexecutor/llm_request.py

Lines changed: 6 additions & 1 deletion
@@ -460,6 +460,7 @@ def __init__(
             is_first_draft: bool = False,
             use_chunked_generation_logits: bool = True,
             logits_chunk_size: int = 8,
+            logprobs_mode: str = "raw",
             **kwargs):

         self.py_logits_post_processors = kwargs.pop("py_logits_post_processors",
@@ -538,6 +539,8 @@ def __init__(
         # currently, keep py_stop_words_list as python list, rather than tensor.
         self.py_stop_words_list = stop_words_list

+        self.py_logprobs_mode = logprobs_mode
+
         self.py_result = PyResult(
             prompt_len=self.py_prompt_len,
             max_new_tokens=self.py_max_new_tokens,
@@ -797,7 +800,9 @@ def executor_request_to_llm_request(
         arrival_time=getattr(executor_request, "py_arrival_time", None),
         py_multimodal_data=getattr(executor_request, "py_multimodal_data",
                                    None),
-        kv_cache_retention_config=executor_request.kv_cache_retention_config)
+        kv_cache_retention_config=executor_request.kv_cache_retention_config,
+        logprobs_mode=getattr(executor_request, "py_logprobs_mode", "raw"),
+    )
     if child_req_ids:
         for child_id in child_req_ids:
             llm_request.create_child_request(child_id)
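
As a quick illustration of the fallback above, getattr with a default keeps executor requests that never set py_logprobs_mode on the previous "raw" behavior (the ExecutorRequest stand-in below is hypothetical):

# Hypothetical minimal stand-in for an executor request object.
class ExecutorRequest:
    pass

old_style_request = ExecutorRequest()            # predates this commit: no py_logprobs_mode
new_style_request = ExecutorRequest()
new_style_request.py_logprobs_mode = "processed"

print(getattr(old_style_request, "py_logprobs_mode", "raw"))  # -> raw
print(getattr(new_style_request, "py_logprobs_mode", "raw"))  # -> processed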

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 143 additions & 66 deletions
@@ -202,12 +202,13 @@ class SampleStateWithMMResult:
 class RequestGroupKey(Generic[GenericStrategyKeyType]):
     strategy: GenericStrategyKeyType
     speculation_needs_probs: bool
+    need_processed_logprobs: bool

     def __iter__(self):
-        return iter((self.strategy, self.speculation_needs_probs))
+        return iter((self.strategy, self.speculation_needs_probs, self.need_processed_logprobs))

     def __len__(self):
-        return 2
+        return 3


 class RequestGroupValue(NamedTuple):
@@ -338,13 +339,19 @@ def _group_requests_by_strategy_key(
             # process_draft_tokens.
             TorchSampler._speculation_could_use_rejection_sampling(req, strategy)
         )
-        strategy_key = strategy_to_key(strategy, speculation_needs_probs)
-        group_dict_entry = group_dict[(strategy_key, speculation_needs_probs)]
+        need_processed_logprobs = req.py_logprobs_mode == "processed"
+        need_probs = speculation_needs_probs or need_processed_logprobs
+        strategy_key = strategy_to_key(strategy, need_probs)
+        group_dict_entry = group_dict[
+            (strategy_key, speculation_needs_probs, need_processed_logprobs)
+        ]
         group_dict_entry[0].append(req_index)
         group_dict_entry[1].append(strategy)
     return {
         RequestGroupKey(
-            strategy=group_key[0], speculation_needs_probs=group_key[1]
+            strategy=group_key[0],
+            speculation_needs_probs=group_key[1],
+            need_processed_logprobs=group_key[2],
         ): RequestGroupValue(
             indices=torch.tensor(indices, pin_memory=pin_memory, dtype=torch.int32),
             strategies=strategies,
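
For illustration, a standalone sketch of the grouping change above: requests whose logprobs_mode is "processed" now contribute a third component to the group key, so they are batched in groups that request probabilities from the sampler. Names and the request encoding below are hypothetical; the real code also folds the strategy into strategy_to_key.

from collections import defaultdict

def group_requests(requests):
    """requests: list of (strategy, speculation_needs_probs, logprobs_mode) tuples."""
    groups = defaultdict(list)
    for idx, (strategy, speculation_needs_probs, logprobs_mode) in enumerate(requests):
        need_processed_logprobs = logprobs_mode == "processed"
        key = (strategy, speculation_needs_probs, need_processed_logprobs)
        groups[key].append(idx)
    return dict(groups)

example = [("greedy", False, "raw"), ("top_p", False, "processed"), ("top_p", False, "raw")]
print(group_requests(example))
# {('greedy', False, False): [0], ('top_p', False, True): [1], ('top_p', False, False): [2]}
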
@@ -374,6 +381,8 @@ class _BatchedSamplingResult:
     batch_req_indices: torch.Tensor
     # Next tokens for all requests:
     batch_next_tokens_cuda_int: torch.Tensor
+    # Logits for all requests:
+    batch_logits_cuda: torch.Tensor | None = None


 # Helper class for _PackedStepIndexer and _UnpackedStepIndexer, facilitating the
@@ -942,34 +951,55 @@ def _convert_logprobs_tensor_to_list(
         self,
         token_tensor: torch.Tensor,
         logprobs_tensor: torch.Tensor,
+        sampled_log_probs_indices: torch.Tensor | None,
+        sampled_log_probs_vals: torch.Tensor | None,
+        sampled_log_probs_rank: torch.Tensor | None,
     ) -> list[list[dict[int, Logprob]]]:
         """Convert the logprobs tensor to a list of lists of dictionaries of Logprob objects

         Logprobs storage expects logprobs as a list[list[dict[int, Logprob]]] object

         args:
+            token_tensor: torch.Tensor. Shape: beam_width, num_tokens, num_logprobs
             logprobs_tensor: torch.Tensor. Shape: beam_width, num_tokens, num_logprobs
+            sampled_log_probs_indices: torch.Tensor | None. Shape: num_tokens
+            sampled_log_probs_vals: torch.Tensor | None. Shape: num_tokens
+            sampled_log_probs_rank: torch.Tensor | None. Shape: num_tokens
         output:
             list[list[dict[int, Logprob]]]. Shape: beam_width, num_tokens, dict with num_logprobs keys
         """
         assert token_tensor.dim() == 3 and logprobs_tensor.dim() == 3, (
             f"Token and logprobs tensors must have 3 dimensions (beam_width, num_tokens, num_logprobs). \
                 Got shapes (token_tensor) {token_tensor.shape} and (logprobs_tensor) {logprobs_tensor.shape} instead"
         )
-        return [
-            [
-                {
+
+        token_log_probs: list[list[dict[int, Logprob]]] = []
+        for beam_idx in range(token_tensor.shape[0]):
+            beam_token_log_probs: list[dict[int, Logprob]] = []
+            for step_idx, (topk_token, topk_logprob) in enumerate(
+                zip(token_tensor[beam_idx], logprobs_tensor[beam_idx])
+            ):
+                logprobs = {
                     token: Logprob(logprob=logprob, rank=rank + 1)
                     for rank, (token, logprob) in enumerate(
                         zip(topk_token.tolist(), topk_logprob.tolist())
                     )
                 }
-                for topk_token, topk_logprob in zip(
-                    token_tensor[beam_idx], logprobs_tensor[beam_idx]
-                )
-            ]
-            for beam_idx in range(token_tensor.shape[0])
-        ]
+                if sampled_log_probs_indices is not None:
+                    assert beam_idx == 0, (
+                        "beam search does not need to explicitly handle sampled log probs"
+                    )
+                    if sampled_log_probs_indices[step_idx] not in logprobs:
+                        logprobs[sampled_log_probs_indices[step_idx].item()] = Logprob(
+                            logprob=sampled_log_probs_vals[step_idx].item(),
+                            rank=max(
+                                token_tensor.shape[2] + 1, sampled_log_probs_rank[step_idx].item()
+                            ),
+                        )
+                beam_token_log_probs.append(logprobs)
+            token_log_probs.append(beam_token_log_probs)
+
+        return token_log_probs

     def handle_logprobs(
         self,
@@ -986,6 +1016,10 @@ def handle_logprobs(
             topk_log_probs_indices = self.store.new_tokens[0, request.py_seq_slot].view(
                 beam_width, count, -1
             )
+            sampled_log_probs_vals = None
+            sampled_log_probs_indices = None
+            # correct the rank to be 1-indexed
+            sampled_log_probs_rank = None
         else:
             assert beam_width == 1, "beam width must be 1 for non-beam search"
             topk_log_probs_vals = request.py_topk_logprobs_vals[: count * beam_width].view(
@@ -994,9 +1028,17 @@ def handle_logprobs(
             topk_log_probs_indices = request.py_topk_logprobs_indices[
                 : count * beam_width
             ].view(beam_width, count, -1)
+            sampled_log_probs_vals = request.py_sampled_logprobs_vals[:count]
+            sampled_log_probs_indices = request.py_sampled_logprobs_indices[:count]
+            # correct the rank to be 1-indexed
+            sampled_log_probs_rank = request.py_sampled_logprobs_rank[:count] + 1

         token_log_probs = self._convert_logprobs_tensor_to_list(
-            topk_log_probs_indices, topk_log_probs_vals
+            topk_log_probs_indices,
+            topk_log_probs_vals,
+            sampled_log_probs_indices,
+            sampled_log_probs_vals,
+            sampled_log_probs_rank,
         )
         request.py_result.append_log_probs(token_log_probs)

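To make the target layout concrete, a hedged example of the list[list[dict[int, Logprob]]] structure that handle_logprobs appends (Logprob is simplified to a namedtuple here; the outer list indexes beams, the inner list indexes generated tokens):

from collections import namedtuple

Logprob = namedtuple("Logprob", ["logprob", "rank"])  # stand-in for the real class

token_log_probs = [  # single beam (beam_width == 1 for non-beam search)
    [
        # step 0: the sampled token was already among the top-k entries
        {1023: Logprob(logprob=-0.12, rank=1), 77: Logprob(logprob=-2.31, rank=2)},
        # step 1: the sampled token (512) fell outside the top-k, so it is merged in
        # with its own rank, as done in _convert_logprobs_tensor_to_list
        {88: Logprob(logprob=-0.40, rank=1), 512: Logprob(logprob=-3.95, rank=7)},
    ]
]
assert token_log_probs[0][1][512].rank == 7
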
@@ -1865,6 +1907,7 @@ def _sample_batched_by_strategy(
         seq_slots: torch.Tensor,
         seq_lens: Optional[torch.Tensor] = None,
         token_dtype: torch.dtype,
+        return_log_probs: bool,
     ) -> _BatchedSamplingResult:
         grouped_requests = _group_requests_by_strategy_key(
             requests,
@@ -1894,9 +1937,16 @@ def _sample_batched_by_strategy(
         batch_next_tokens_cuda_int = torch.empty(
             (logits_cuda.size(0), self.max_beam_width), device=cuda_device, dtype=token_dtype
         )
+        batch_logits_cuda = (
+            torch.empty(
+                (logits_cuda.size(0), logits_cuda.size(1)), device=cuda_device, dtype=torch.float32
+            )
+            if return_log_probs
+            else None
+        )
         batch_req_idx_offset_start = 0
         batch_next_tokens_offset_start = 0
-        for (strategy_key, speculation_needs_probs), (
+        for (strategy_key, speculation_needs_probs, need_processed_logprobs), (
             group_req_indices,
             group_strategies,
             group_metadata,
@@ -1943,7 +1993,7 @@
                 group_strategies_per_step,
                 group_logits_cuda,
                 generator=generator_cuda,
-                return_probs=speculation_needs_probs,
+                return_probs=speculation_needs_probs or need_processed_logprobs,
                 group_logit_indices=logit_indices_for_sampler,
                 group_metadata=group_metadata,
             )
@@ -1958,6 +2008,20 @@
                 batch_next_tokens_offset_start:batch_next_tokens_offset_end
            ].copy_(group_next_tokens_cuda, non_blocking=True)

+            if return_log_probs:
+                if need_processed_logprobs:
+                    # if softmax is 0, then the logit was masked out => set to -inf
+                    group_tgt_logits_cuda = torch.where(
+                        group_softmax_cuda != 0, group_logits_cuda, float("-inf")
+                    )
+                    batch_logits_cuda[
+                        batch_next_tokens_offset_start:batch_next_tokens_offset_end
+                    ].copy_(group_tgt_logits_cuda, non_blocking=True)
+                else:
+                    batch_logits_cuda[
+                        batch_next_tokens_offset_start:batch_next_tokens_offset_end
+                    ].copy_(group_logits_cuda, non_blocking=True)
+
             # Set LlmRequest.py_target_probs
             if speculation_needs_probs:
                 assert group_softmax_cuda is not None
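
A small standalone sketch of the masking trick above: positions whose post-processing softmax is exactly zero were filtered out by top-k/top-p, so their logits become -inf and the subsequent log_softmax in _process_logprobs reports them as -inf logprobs (shapes and values below are illustrative only):

import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
# Pretend top-k masking (k=2) zeroed out the probabilities of the last two tokens.
probs_after_masking = torch.tensor([[0.7, 0.3, 0.0, 0.0]])

processed_logits = torch.where(probs_after_masking != 0, logits, float("-inf"))
print(F.log_softmax(processed_logits, dim=-1))
# masked positions come out as -inf, so "processed" logprobs reflect the sampling filter
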
@@ -1986,6 +2050,7 @@
         return _BatchedSamplingResult(
             batch_req_indices=batch_req_indices,
             batch_next_tokens_cuda_int=batch_next_tokens_cuda_int,
+            batch_logits_cuda=batch_logits_cuda,
         )

     def _unbatch_sampling_results(
@@ -2385,6 +2450,63 @@ def request_stop_words(request: LlmRequest, new_tokens: torch.Tensor):
                     per_step[step, request_idx, beam_idx] = True
         return per_step

+    @nvtx_range("_process_logprobs")
+    def _process_logprobs(
+        self,
+        batched_sampling_result: _BatchedSamplingResult,
+        requests: list[LlmRequest],
+        req_num_steps: torch.Tensor,
+    ):
+        group_logprobs_cuda = F.log_softmax(batched_sampling_result.batch_logits_cuda, dim=-1)
+        all_req_indices = batched_sampling_result.batch_req_indices
+        group_next_tokens_cuda = batched_sampling_result.batch_next_tokens_cuda_int
+        group_req_indices = [
+            req_gid.item()
+            for req_gid in all_req_indices
+            if requests[req_gid].py_num_logprobs is not None
+        ]
+        topk_vals_cuda, topk_indices_cuda = torch.topk(
+            group_logprobs_cuda,
+            k=max(requests[req_id].py_num_logprobs for req_id in group_req_indices),
+            dim=-1,
+        )
+
+        sampled_vals_cuda = torch.gather(
+            group_logprobs_cuda, dim=-1, index=group_next_tokens_cuda.view(-1, 1)
+        )
+        sampled_indices_cuda = group_next_tokens_cuda
+
+        # NB: we do not need group logprobs anymore, we can reuse the storage
+        # We only provide 0 based rank, it will be corrected to 1-indexed in handle logprobs
+        group_logprobs_cuda.greater_(sampled_vals_cuda)
+        sampled_rank_cuda = group_logprobs_cuda.sum(dim=-1)
+
+        # Use a single D2H copy to reduce overheads
+        topk_vals = torch.empty_like(topk_vals_cuda, device="cpu", pin_memory=False)
+        topk_indices = torch.empty_like(topk_indices_cuda, device="cpu", pin_memory=False)
+        sampled_vals = torch.empty_like(sampled_vals_cuda, device="cpu", pin_memory=False)
+        sampled_indices = torch.empty_like(sampled_indices_cuda, device="cpu", pin_memory=False)
+        sampled_rank = torch.empty_like(sampled_rank_cuda, device="cpu", pin_memory=False)
+
+        topk_vals.copy_(topk_vals_cuda, non_blocking=True)
+        topk_indices.copy_(topk_indices_cuda, non_blocking=True)
+        sampled_vals.copy_(sampled_vals_cuda, non_blocking=True)
+        sampled_indices.copy_(sampled_indices_cuda, non_blocking=True)
+        sampled_rank.copy_(sampled_rank_cuda, non_blocking=True)
+        current_offset = 0
+        for req_id, steps in zip(group_req_indices, req_num_steps[group_req_indices].tolist()):
+            req = requests[req_id]
+            next_offset = current_offset + steps
+            # NB: Assigning views on memory which is being filled asynchronously
+            req.py_topk_logprobs_vals = topk_vals[current_offset:next_offset, : req.py_num_logprobs]
+            req.py_sampled_logprobs_vals = sampled_vals[current_offset:next_offset]
+            req.py_topk_logprobs_indices = topk_indices[
+                current_offset:next_offset, : req.py_num_logprobs
+            ]
+            req.py_sampled_logprobs_indices = sampled_indices[current_offset:next_offset]
+            req.py_sampled_logprobs_rank = sampled_rank[current_offset:next_offset]
+            current_offset = next_offset
+
     @nvtx_range("_process_requests")
     def _process_requests(
         self,
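
The 0-indexed rank above comes from an in-place comparison followed by a sum: count how many vocabulary entries score strictly higher than the sampled token. A minimal reproduction of that trick (vocabulary size and shapes are made up):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
logprobs = F.log_softmax(torch.randn(3, 10), dim=-1)         # (num_tokens, vocab_size)
sampled = torch.randint(0, 10, (3, 1))                       # sampled token ids
sampled_vals = torch.gather(logprobs, dim=-1, index=sampled)

# greater_ overwrites logprobs in place (its storage is no longer needed),
# mirroring how _process_logprobs reuses group_logprobs_cuda.
logprobs.greater_(sampled_vals)
rank0 = logprobs.sum(dim=-1)  # 0-indexed rank; handle_logprobs adds 1 later
print(rank0)
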
@@ -2454,55 +2576,6 @@ def _process_requests(
             req_offsets=req_offsets,
         )

-        # Handle top-k logprobs. This is done outside the sampling loop,
-        # because the returned logprobs are specified to not reflect temperature scaling,
-        # top-k/top-p masking, etc.
-        if return_log_probs:
-            assert logits_cuda.dim() == 2, "logits should be 2D"
-
-            logprobs_req_indices = [
-                req_id for req_id, req in enumerate(requests) if req.py_num_logprobs
-            ]
-            logprobs_logit_indices = logits_cuda_indexer[logprobs_req_indices]
-            logprobs_logit_indices_cuda = logprobs_logit_indices.to(
-                device=logits_cuda.device, non_blocking=True
-            )
-            logprobs_cuda = F.log_softmax(
-                logits_cuda[logprobs_logit_indices_cuda].to(dtype=torch.float32, non_blocking=True),
-                dim=-1,
-            )
-            topk_vals_cuda, topk_indices_cuda = torch.topk(
-                logprobs_cuda, k=max(req.py_num_logprobs for req in requests), dim=-1
-            )
-            # Use a single D2H copy to reduce overheads
-            topk_vals = torch.empty_like(topk_vals_cuda, device="cpu", pin_memory=True)
-            topk_indices = torch.empty_like(topk_indices_cuda, device="cpu", pin_memory=True)
-            topk_vals.copy_(topk_vals_cuda, non_blocking=True)
-            topk_indices.copy_(topk_indices_cuda, non_blocking=True)
-            current_offset = 0
-            for req_id, steps in zip(
-                logprobs_req_indices, req_num_generated_tokens[logprobs_req_indices].tolist()
-            ):
-                req = requests[req_id]
-                next_offset = current_offset + steps
-                # NB: Assigning views on memory which is being filled asynchronously
-                req.py_topk_logprobs_vals = topk_vals[
-                    current_offset:next_offset, : req.py_num_logprobs
-                ]
-                req.py_topk_logprobs_indices = topk_indices[
-                    current_offset:next_offset, : req.py_num_logprobs
-                ]
-
-                # context requests do not have multiple input beams, but they need multiple output beams
-                if req.is_context_init_state:
-                    req.py_topk_logprobs_vals = req.py_topk_logprobs_vals.expand(
-                        req.sampling_config.beam_width, -1
-                    )
-                    req.py_topk_logprobs_indices = req.py_topk_logprobs_indices.expand(
-                        req.sampling_config.beam_width, -1
-                    )
-                current_offset = next_offset
-
         # Perform sampling in batches
         batched_sampling_result = self._sample_batched_by_strategy(
             logits_cuda,
@@ -2515,8 +2588,12 @@
             seq_lens=seq_lens,
             req_num_generated_tokens=req_num_generated_tokens,
             token_dtype=new_tokens_cuda.dtype,
+            return_log_probs=return_log_probs,
         )

+        if return_log_probs:
+            self._process_logprobs(batched_sampling_result, requests, req_num_steps)
+
         # Fill results into output buffers
         new_tokens_host = self._unbatch_sampling_results(
             batched_sampling_result,

tensorrt_llm/executor/base_worker.py

Lines changed: 1 addition & 0 deletions
@@ -561,6 +561,7 @@ def _deduce_max_tokens(request: GenerationRequest,
             cache_salt_id=request.cache_salt_id)
         executor_request.py_num_logprobs = request.sampling_params.logprobs
         executor_request.py_lora_path = py_lora_path
+        executor_request.py_logprobs_mode = request.sampling_params.logprobs_mode

         if self._is_pytorch_backend and request.multimodal_params is not None:
             if request.multimodal_params.multimodal_data is not None:
