Commit b04e7e6

[TRTLLM-9686][feat] Fix issues with processed logprobs functionality.
- Expand test_logits_logprobs to perform a check for processed logprobs
- Fix processed logprobs for greedy sampling and when using temperature

Signed-off-by: Stefan Niebler <82932102+stnie@users.noreply.github.com>

1 parent: 5db31a5

5 files changed: +207 −38 lines

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 24 additions & 16 deletions
@@ -1127,7 +1127,7 @@ def _convert_logprobs_tensor_to_list(
                 assert beam_idx == 0, (
                     "beam search does not need to explicitly handle sampled log probs"
                 )
-                if sampled_log_probs_indices[step_idx] not in logprobs:
+                if sampled_log_probs_indices[step_idx].item() not in logprobs:
                     logprobs[sampled_log_probs_indices[step_idx].item()] = Logprob(
                         logprob=sampled_log_probs_vals[step_idx].item(),
                         rank=max(
@@ -1380,7 +1380,7 @@ def _process_draft_tokens_rejection_sampling(
             else _request_strategy(request, vocab_size=2**31)
         )
         generator = self.get_generator(request.py_draft_logits.device)
-        _, draft_probs = sample(
+        _, draft_probs, _ = sample(
             draft_sampling_strategy,
             request.py_draft_logits,
             generator=generator,
@@ -2160,7 +2160,7 @@ def _sample_batched_by_strategy(
                 for _ in range(steps)
             ]

-            group_next_tokens_cuda, group_softmax_cuda = (
+            group_next_tokens_cuda, group_softmax_cuda, group_temperature_cuda = (
                 self._grouped_sampler_cls.sample_grouped_strategies(
                     strategy_key,
                     group_strategies_per_step,
@@ -2182,18 +2182,29 @@
             ].copy_(group_next_tokens_cuda, non_blocking=True)

             if return_log_probs:
+                # select the logits for the current group
+                current_group_logits_cuda = (
+                    group_logits_cuda
+                    if logit_indices_for_sampler is None
+                    else group_logits_cuda[logit_indices_for_sampler]
+                )
                 if need_processed_logprobs:
                     # if softmax is 0, then the logit was masked out => set to -inf
-                    group_tgt_logits_cuda = torch.where(
-                        group_softmax_cuda != 0, group_logits_cuda, float("-inf")
-                    )
+                    # apply masking to the logits and store in batch_logits_cuda
                     batch_logits_cuda[
                         batch_next_tokens_offset_start:batch_next_tokens_offset_end
-                    ].copy_(group_tgt_logits_cuda, non_blocking=True)
+                    ] = torch.where(
+                        group_softmax_cuda > 0, current_group_logits_cuda, float("-inf")
+                    )
+                    # apply temperature to the logits
+                    if group_temperature_cuda is not None:
+                        batch_logits_cuda[
+                            batch_next_tokens_offset_start:batch_next_tokens_offset_end
+                        ] /= group_temperature_cuda
                 else:
                     batch_logits_cuda[
                         batch_next_tokens_offset_start:batch_next_tokens_offset_end
-                    ].copy_(group_logits_cuda, non_blocking=True)
+                    ].copy_(current_group_logits_cuda, non_blocking=True)

             # Set LlmRequest.py_target_probs
             if speculation_needs_probs:
@@ -2697,7 +2708,7 @@ def _process_logprobs(
         # NB: we do not need group logprobs anymore, we can reuse the storage
         # We only provide 0 based rank, it will be corrected to 1-indexed in handle logprobs
         group_logprobs_cuda.greater_(sampled_vals_cuda)
-        sampled_rank_cuda = group_logprobs_cuda.sum(dim=-1)
+        sampled_rank_cuda = group_logprobs_cuda.sum(dim=-1).to(torch.int32)

         # Use a single D2H copy to reduce overheads
         topk_vals = torch.empty_like(topk_vals_cuda, device="cpu", pin_memory=False)
@@ -2768,12 +2779,7 @@ def _process_requests(
             req_offsets=sampling_requests_metadata.req_offsets,
         )

-        self._handle_log_probs(
-            requests,
-            logits_cuda,
-            logits_cuda_indexer=logits_cuda_indexer,
-            req_num_generated_tokens=sampling_requests_metadata.req_num_generated_tokens,
-        )
+        return_log_probs = self._return_log_probs(requests)

         # Perform sampling in batches
         batched_sampling_result = self._sample_batched_by_strategy(
@@ -2792,7 +2798,9 @@
         )

         if return_log_probs:
-            self._process_logprobs(batched_sampling_result, requests, req_num_steps)
+            self._process_logprobs(
+                batched_sampling_result, requests, sampling_requests_metadata.req_num_steps
+            )

         # Fill results into output buffers
         new_tokens_host = self._unbatch_sampling_results(
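For illustration, here is a minimal standalone sketch of what the reworked `need_processed_logprobs` branch computes; the tensor names and shapes are stand-ins for the module's actual buffers, not the real code path. Entries whose post-sampling probability is zero are treated as masked out and forced to `-inf`, and the group's temperature is applied to the surviving logits before logprobs are derived from them:

```python
import torch

# Illustrative only: 4 requests in the group, vocab of 8.
group_logits_cuda = torch.randn(4, 8)
group_softmax_cuda = torch.softmax(group_logits_cuda, dim=-1)
group_softmax_cuda[:, 5:] = 0.0  # pretend these entries were filtered by top-k/top-p
group_temperature_cuda = torch.full((4, 1), 0.7)

# Probability 0 means the logit was masked out => set it to -inf.
processed_logits = torch.where(
    group_softmax_cuda > 0, group_logits_cuda, float("-inf")
)
# Apply the same temperature the sampler used.
processed_logits /= group_temperature_cuda

processed_logprobs = torch.log_softmax(processed_logits, dim=-1)
assert torch.isinf(processed_logprobs[:, 5:]).all()  # masked tokens get logprob -inf
```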

tensorrt_llm/_torch/pyexecutor/sampling_utils.py

Lines changed: 5 additions & 3 deletions
@@ -266,7 +266,8 @@ def greedy_search_sampling_batch(
     next_tokens = torch.argmax(logits, dim=-1)
     softmax: Optional[torch.Tensor] = None
     if return_probs:
-        softmax = torch.softmax(logits, dim=-1)
+        softmax = torch.zeros_like(logits)
+        softmax.scatter_(1, next_tokens.unsqueeze(-1), 1.0)
     return next_tokens, softmax


@@ -474,7 +475,7 @@ def sample(
     generator: Optional[torch.Generator] = None,
     group_metadata: StrategyMetadata | None = None,
     return_probs: bool = True,
-) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[float]]:
     match strategy:
         case ("top_k", top_k, temperature):
             tokens, softmax = top_k_sampling_batch(
@@ -506,6 +507,7 @@
             )
         case ("greedy", None):
             tokens, softmax = greedy_search_sampling_batch(logits, return_probs=return_probs)
+            temperature = None
         case ("beam_search", beam_width_in, beam_width_out, temperature):
             assert group_metadata is not None and isinstance(group_metadata, BeamSearchMetadata), (
                 "BeamSearchMetadata is required for beam_search_sampling_batch"
@@ -519,7 +521,7 @@
             generator=generator,
             return_probs=return_probs,
         )
-    return tokens, softmax
+    return tokens, softmax, temperature


 GenericStrategyKeyType = TypeVar("GenericStrategyKeyType")
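A small sketch of the new greedy behavior (values are made up for illustration): instead of returning a full softmax, `greedy_search_sampling_batch` now reports a one-hot distribution on the argmax token, built with `scatter_`:

```python
import torch

# Hypothetical batch of 3 rows over a vocab of 5.
logits = torch.tensor([[2.0, 0.5, 1.0, -1.0, 0.0],
                       [0.1, 3.0, 0.2, 0.3, 0.4],
                       [1.0, 1.0, 4.0, 0.0, 2.0]])

next_tokens = torch.argmax(logits, dim=-1)

# One-hot "probabilities" on the greedily selected token.
softmax = torch.zeros_like(logits)
softmax.scatter_(1, next_tokens.unsqueeze(-1), 1.0)

print(next_tokens)  # tensor([0, 1, 2])
print(softmax[0])   # tensor([1., 0., 0., 0., 0.])
```

Because greedy decoding selects the argmax deterministically, the one-hot form lets the sampler-side masking above treat every non-selected token as filtered out, rather than deriving processed logprobs from a full softmax as the old `torch.softmax` call did.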

tensorrt_llm/_torch/pyexecutor/sampling_utils_flashinfer.py

Lines changed: 18 additions & 4 deletions
@@ -141,8 +141,9 @@ def _sample_greedy_with_probs(
         *,
         group_logit_indices: Optional[torch.Tensor],
     ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
-        probs = self._prepare_probs_with_temperature(logits, group_logit_indices, None)
-        new_tokens, _ = greedy_search_sampling_batch(probs, return_probs=False)
+        if group_logit_indices is not None:
+            logits = torch.index_select(logits, 0, group_logit_indices)  # ensures copy
+        new_tokens, probs = greedy_search_sampling_batch(logits, return_probs=True)
         return new_tokens, probs

     @classmethod
@@ -240,6 +241,9 @@ def computes_probs(cls) -> bool:
         return True

 class GreedyWithProbs(StrategyImplWithProbs):
+    def __init__(self):
+        self._temperature = None
+
     @override
     @classmethod
     def from_strategies(
@@ -425,6 +429,9 @@ def computes_probs(cls) -> bool:
         return False

 class GreedySampleOnly(StrategyImplSampleOnly):
+    def __init__(self):
+        self._temperature = None
+
     @override
     @classmethod
     def from_strategies(
@@ -722,7 +729,7 @@ def sample_grouped_strategies(
         generator: Optional[torch.Generator] = None,
         return_probs: bool,
         group_metadata: StrategyMetadata | None = None,
-    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
         if hasattr(group_key, "static_beam_width_in"):
             beam_width_in = group_key.static_beam_width_in
         else:
@@ -735,9 +742,16 @@
         assert return_probs == group_key.computes_probs()

         strategy_impl_cls = group_key
-        return strategy_impl_cls.from_strategies(strategies, cuda_device=logits.device).sample(
+        sampling_object = strategy_impl_cls.from_strategies(strategies, cuda_device=logits.device)
+        next_tokens, softmax = sampling_object.sample(
             logits,
             group_logit_indices=group_logit_indices,
             generator=generator,
             group_metadata=group_metadata,
         )
+        temperature = (
+            sampling_object._temperature.unsqueeze(-1)
+            if sampling_object._temperature is not None
+            else None
+        )
+        return next_tokens, softmax, temperature
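The trailing `unsqueeze(-1)` on the returned temperature matters for the division in sampler.py: it turns a per-request vector into a column so it broadcasts across the vocab dimension. A minimal sketch (shapes and values are illustrative, not the actual flashinfer sampler state):

```python
import torch

# Hypothetical per-request temperatures for a group of 3 requests.
temperature = torch.tensor([0.7, 1.0, 1.3])
logits = torch.randn(3, 32000)

# (3,) -> (3, 1) so the division broadcasts over the vocab dimension.
scaled = logits / temperature.unsqueeze(-1)
assert scaled.shape == logits.shape
```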

tensorrt_llm/sampling_params.py

Lines changed: 3 additions & 1 deletion
@@ -337,7 +337,9 @@ def _validate(self):
         if self.guided_decoding is not None:
             self.guided_decoding._validate()

-        # correct types as users might pass in logprob=True for Top-1 logprobs
+        # correct types as users might pass in logprob=True for Top-1 logprobs and logprobs=False for no logprobs
+        if self.logprobs is False:
+            self.logprobs = None
         self.logprobs = self.logprobs and int(self.logprobs)
         self.prompt_logprobs = self.prompt_logprobs and int(self.prompt_logprobs)
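The normalization this `_validate()` change performs can be summarized in a standalone helper; `normalize_logprobs` below is a hypothetical function written only to illustrate the behavior, not part of SamplingParams:

```python
from typing import Optional, Union

def normalize_logprobs(logprobs: Union[bool, int, None]) -> Optional[int]:
    # Users may pass logprobs=True for Top-1 logprobs or logprobs=False for none.
    if logprobs is False:
        return None
    return logprobs and int(logprobs)

assert normalize_logprobs(True) == 1    # True means "give me top-1 logprobs"
assert normalize_logprobs(False) is None  # False now means "no logprobs" instead of 0
assert normalize_logprobs(None) is None
assert normalize_logprobs(5) == 5
```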
