
Commit 18531f1

[TRTLLM-9688][feat] Refactor processed logprobs to skip logprob calculation when not needed

- Added LogprobMode class to define modes for log probabilities: RAW and PROCESSED.
- Updated SamplingParams and LlmRequest to utilize LogprobMode for the logprobs_mode parameter.
- Enhanced validation to check logprobs_mode against LogprobMode values.
- Modified TorchSampler and related classes to support the new logprobs_mode functionality.
- Modified TorchSampler to only calculate logprobs when a request needs it.
- Updated tests to cover the new logprobs_mode behavior and ensure correct processing of log probabilities.

Signed-off-by: Stefan Niebler <[email protected]>
1 parent 160f461 commit 18531f1
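The RAW and PROCESSED modes are only named, not defined, on this page. By analogy with logprobs modes in other inference engines, RAW presumably means log probabilities taken over the unmodified model logits, while PROCESSED means log probabilities taken after sampling transformations (e.g. temperature scaling) are applied. Below is a minimal sketch under that assumption; the enum and compute_logprobs are illustrative stand-ins, not the classes added by this commit:

```python
from enum import Enum

import torch
import torch.nn.functional as F


class LogprobMode(Enum):  # stand-in for the enum this commit adds to sampling_params
    RAW = "raw"
    PROCESSED = "processed"


def compute_logprobs(logits: torch.Tensor, temperature: float,
                     mode: LogprobMode) -> torch.Tensor:
    # Assumed semantics: RAW ignores sampling transforms, PROCESSED applies them.
    if mode is LogprobMode.RAW:
        return F.log_softmax(logits, dim=-1)
    return F.log_softmax(logits / temperature, dim=-1)
```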

File tree

7 files changed: +265, -81 lines

tensorrt_llm/_torch/pyexecutor/llm_request.py

Lines changed: 13 additions & 3 deletions
@@ -8,6 +8,7 @@
 from tensorrt_llm._torch.shared_tensor import SharedTensorContainer
 from tensorrt_llm.bindings import executor as tllm_executor
 from tensorrt_llm.executor.result import TokenLogprobs
+from tensorrt_llm.sampling_params import LogprobMode
 
 SamplingConfig = tensorrt_llm.bindings.SamplingConfig
 '''
@@ -460,7 +461,7 @@ def __init__(
             is_first_draft: bool = False,
             use_chunked_generation_logits: bool = True,
             logits_chunk_size: int = 8,
-            logprobs_mode: str = "raw",
+            logprobs_mode: LogprobMode | None = None,
             **kwargs):
 
         self.py_logits_post_processors = kwargs.pop("py_logits_post_processors",
@@ -539,7 +540,7 @@ def __init__(
         # currently, keep py_stop_words_list as python list, rather than tensor.
         self.py_stop_words_list = stop_words_list
 
-        self.py_logprobs_mode = logprobs_mode
+        self.py_logprobs_mode = LogprobMode.RAW if logprobs_mode is None else logprobs_mode
 
         self.py_result = PyResult(
             prompt_len=self.py_prompt_len,
@@ -568,6 +569,15 @@ def set_exclude_last_generation_logits(
         self.py_result.set_exclude_last_generation_logits(
             exclude_last_generation_logits)
 
+    def validate_logprobs_mode(self):
+        if self.py_logprobs_mode not in [
+                LogprobMode.RAW, LogprobMode.PROCESSED
+        ]:
+            raise ValueError(
+                f"Invalid logprobs_mode: {self.py_logprobs_mode} "
+                f"Expected one of {LogprobMode.RAW.value}, {LogprobMode.PROCESSED.value}"
+            )
+
     @property
     def cached_tokens(self) -> int:
         return self._cached_tokens
@@ -801,7 +811,7 @@ def executor_request_to_llm_request(
         py_multimodal_data=getattr(executor_request, "py_multimodal_data",
                                    None),
         kv_cache_retention_config=executor_request.kv_cache_retention_config,
-        logprobs_mode=getattr(executor_request, "py_logprobs_mode", "raw"),
+        logprobs_mode=getattr(executor_request, "py_logprobs_mode", None),
     )
     if child_req_ids:
         for child_id in child_req_ids:
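Taken together, the hunks above replace the old "raw" string default with None, resolve None to LogprobMode.RAW at construction time, and reject any value that is not a LogprobMode member. A self-contained toy sketch of that behavior follows; ToyRequest is a hypothetical stand-in for LlmRequest, and the enum mirrors the one imported from tensorrt_llm.sampling_params:

```python
from enum import Enum


class LogprobMode(Enum):  # mirrors the imported LogprobMode
    RAW = "raw"
    PROCESSED = "processed"


class ToyRequest:
    def __init__(self, logprobs_mode: LogprobMode | None = None):
        # None now means "use the RAW default" rather than the old "raw" string.
        self.py_logprobs_mode = (LogprobMode.RAW
                                 if logprobs_mode is None else logprobs_mode)

    def validate_logprobs_mode(self):
        if self.py_logprobs_mode not in [LogprobMode.RAW, LogprobMode.PROCESSED]:
            raise ValueError(f"Invalid logprobs_mode: {self.py_logprobs_mode}")


ToyRequest().validate_logprobs_mode()                       # default resolves to RAW
ToyRequest(LogprobMode.PROCESSED).validate_logprobs_mode()  # valid member: passes

bad = ToyRequest()
bad.py_logprobs_mode = "processed"  # plain strings no longer validate
try:
    bad.validate_logprobs_mode()
except ValueError as err:
    print(err)
```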

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 2 additions & 0 deletions
@@ -1718,6 +1718,8 @@ def _validate_request(self, request: LlmRequest):
                 f"Request beam width {sampling_config.beam_width} "
                 f"is not equal to max_beam_width {self.max_beam_width}. This is not supported!"
             )
+        # Validate logprobs mode
+        request.validate_logprobs_mode()
 
         # Check token ID ranges
         if isinstance(self.model_engine.model, DecoderModelForCausalLM):
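The TorchSampler changes mentioned in the commit message are not part of this page's diff. The sketch below, using assumed names, illustrates the gating the message describes: the log-softmax over the full vocabulary is skipped entirely when no request in the batch asked for log probabilities.

```python
import torch
import torch.nn.functional as F


def sample_with_optional_logprobs(logits: torch.Tensor,
                                  requests_need_logprobs: list[bool]):
    # Sampling always happens.
    probs = torch.softmax(logits, dim=-1)
    next_tokens = torch.multinomial(probs, num_samples=1)

    # Skip the potentially expensive log-softmax over the whole vocabulary
    # when no request in the batch requested log probabilities.
    logprobs = None
    if any(requests_need_logprobs):
        logprobs = F.log_softmax(logits, dim=-1)

    return next_tokens, logprobs
```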
