Commit d66c983

[TRTLLM-9689][feat] Introduce max_topk_logprobs parameter and enhance logprobs handling
- Added max_topk_logprobs parameter to AutoDeployConfig and LlmRequest to control the number of top-k logprobs storable for each token.
- Updated TorchSampler to accommodate max_topk_logprobs in logprobs processing and storage.
- Enhanced logprobs handling in the sampling process to support both sampled and top-k logprobs.
- Enabled batched processing of logprobs to improve performance.
- Modified tests to validate the new max_topk_logprobs functionality and ensure correct logprobs output.

Signed-off-by: Stefan <[email protected]>
1 parent aabd9f4 commit d66c983
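
For context, a minimal usage sketch of the new knob (hypothetical values; it assumes the public LLM API shape visible in the test changes below, where max_topk_logprobs is passed to the LLM constructor and per-request logprobs are requested through SamplingParams):

    from tensorrt_llm import LLM, SamplingParams

    # Hypothetical values: cap stored top-k logprobs at 3 per token on the server side,
    # mirroring max_topk_logprobs=3 in tests/unittest/_torch/sampler/test_logits_logprobs.py.
    llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", max_topk_logprobs=3)

    # With the semantics changed in this commit, logprobs appears to count the extra
    # top-k entries returned in addition to the sampled token's logprob
    # (logprobs=0 keeps only the sampled token's logprob).
    params = SamplingParams(max_tokens=16, logprobs=2)

    for output in llm.generate(["The capital of France is"], params):
        print(output.outputs[0].logprobs)  # per-token logprob info for the generated tokens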

File tree

10 files changed: +347 −183 lines

tensorrt_llm/_torch/auto_deploy/llm_args.py

Lines changed: 3 additions & 0 deletions
@@ -200,6 +200,9 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):
     max_num_tokens: Optional[int] = Field(default=None, description="The maximum number of tokens.")
     max_seq_len: int = Field(default=512, ge=1, description="The maximum sequence length.")
     max_batch_size: int = Field(default=8, ge=1, description="The maximum batch size.")
+    max_topk_logprobs: int = Field(
+        default=0, description="The maximum number of top-k logprobs to store for each token."
+    )
     attn_page_size: int = Field(
         default=64,
         ge=1,

tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py

Lines changed: 1 addition & 0 deletions
@@ -770,6 +770,7 @@ def instantiate_sampler(
         max_num_sequences=max_num_sequences,
         max_beam_width=ad_config.max_beam_width,
         disable_overlap_scheduler=ad_config.disable_overlap_scheduler,
+        max_topk_logprobs=ad_config.max_topk_logprobs,
     )
     sampler = TorchSampler(sampler_args)

tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 6 additions & 1 deletion
@@ -857,6 +857,7 @@ def create_torch_sampler_args(
         disable_overlap_scheduler: bool,
         disable_flashinfer_sampling: bool,
         enable_async_worker: bool,
+        max_topk_logprobs: int,
 ):
     max_num_sequences = max_batch_size * mapping.pp_size
     max_draft_len = (0 if speculative_config is None else
@@ -872,7 +873,9 @@ def create_torch_sampler_args(
         max_beam_width=max_beam_width,
         disable_flashinfer_sampling=disable_flashinfer_sampling,
         disable_overlap_scheduler=disable_overlap_scheduler,
-        enable_async_worker=enable_async_worker)
+        enable_async_worker=enable_async_worker,
+        max_topk_logprobs=max_topk_logprobs,
+    )


 def instantiate_sampler(
@@ -888,6 +891,7 @@ def instantiate_sampler(
         decoding_config: trtllm.DecodingConfig,
         kv_cache_config: KvCacheConfig,
         disable_flashinfer_sampling: bool,
+        max_topk_logprobs: int,
 ):
     enable_async_worker = (confidential_compute_enabled()
                            or llm_args.sampler_force_async_worker)
@@ -901,6 +905,7 @@ def instantiate_sampler(
         disable_overlap_scheduler=llm_args.disable_overlap_scheduler,
         disable_flashinfer_sampling=disable_flashinfer_sampling,
         enable_async_worker=enable_async_worker,
+        max_topk_logprobs=max_topk_logprobs,
     )
     decoding_mode = get_decoding_mode(decoding_config=decoding_config,
                                       max_beam_width=max_beam_width)

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

Lines changed: 1 addition & 0 deletions
@@ -534,6 +534,7 @@ def drafting_loop_wrapper(model):
         decoding_config=decoding_config,
         kv_cache_config=kv_cache_config,
         disable_flashinfer_sampling=llm_args.disable_flashinfer_sampling,
+        max_topk_logprobs=llm_args.max_topk_logprobs,
     )
     logger.info(f"Using Sampler: {type(sampler).__name__}")

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 315 additions & 175 deletions
Large diffs are not rendered by default.

tensorrt_llm/llmapi/llm_args.py

Lines changed: 7 additions & 0 deletions
@@ -2848,6 +2848,13 @@ class TorchLlmArgs(BaseLlmArgs):
         status="prototype",
     )

+    max_topk_logprobs: int = Field(
+        default=0,
+        description=
+        "The maximum number of top-k logprobs per request to calculate each step. This does not affect the number of sampled logprobs.",
+        status="prototype",
+    )
+
     @property
     def quant_config(self) -> QuantConfig:
         if self._quant_config is None:

tensorrt_llm/sampling_params.py

Lines changed: 3 additions & 2 deletions
@@ -344,10 +344,11 @@ def _validate(self):
         if self.guided_decoding is not None:
             self.guided_decoding._validate()

-        # correct types as users might pass in logprob=True for Top-1 logprobs and logprobs=False for no logprobs
+        # correct types as users might pass in logprob=True for Top-0 logprobs and logprobs=False for no logprobs
         if self.logprobs is False:
             self.logprobs = None
-        self.logprobs = self.logprobs and int(self.logprobs)
+        if self.logprobs is True:
+            self.logprobs = 0
         self.prompt_logprobs = self.prompt_logprobs and int(self.prompt_logprobs)

     # NB: Static, because downstream code only holds instances of
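
As a standalone sketch (not the library code itself) of the normalization that _validate() now performs, and of how the new "Top-0" comment reads together with the max_topk_logprobs description, the boolean convenience values map onto the integer count of additional top-k logprobs like this:

    from typing import Optional, Union

    def normalize_logprobs(logprobs: Union[bool, int, None]) -> Optional[int]:
        """Mirror of the _validate() snippet above: False -> no logprobs,
        True -> 0 (sampled token's logprob only, no extra top-k),
        int k -> sampled logprob plus k top-k logprobs (presumably bounded by max_topk_logprobs)."""
        if logprobs is False:
            return None
        if logprobs is True:
            return 0
        return logprobs

    assert normalize_logprobs(False) is None
    assert normalize_logprobs(True) == 0
    assert normalize_logprobs(None) is None
    assert normalize_logprobs(5) == 5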

tests/unittest/_torch/sampler/test_beam_search.py

Lines changed: 6 additions & 5 deletions
@@ -125,7 +125,7 @@ def check_generation_logits(beam: CompletionOutput,
 def check_logprobs(beam: CompletionOutput, sampling_params: SamplingParams,
                    valid_tokens: int | None) -> None:
     """Check if the logprobs have the correct shape"""
-    if sampling_params.logprobs:
+    if sampling_params.logprobs is not None:
         generated_tokens = valid_tokens if valid_tokens is not None else sampling_params.max_tokens
         assert len(
             beam.logprobs
@@ -345,7 +345,7 @@ class GeneralTestParams:
     prompt_len = len(input_tokens)
     num_generated_tokens = 5
     seq_len = prompt_len + num_generated_tokens
-    num_logprobs = 1
+    num_logprobs = 0
     seq_slot = 4
     end_id = 99
     batch_size = 2
@@ -541,7 +541,7 @@ def create_default_request(test_params: GeneralTestParams) -> LlmRequest:
         end_id=test_params.end_id,
         sampling_config=SamplingConfig(
             sampling_params._get_sampling_config()),
-        return_log_probs=test_params.num_logprobs > 0,
+        return_log_probs=test_params.num_logprobs >= 0,
         num_logprobs=test_params.num_logprobs,
         is_streaming=False)

@@ -590,7 +590,7 @@ def test_create_beam_history():
     num_generated_tokens = test_params.num_generated_tokens
     seq_slot = test_params.seq_slot
     vocab_size = test_params.vocab_size
-    num_logprobs = test_params.num_logprobs
+    num_logprobs = test_params.num_logprobs + 1
     cache_indirection = sampler.store.cache_indirection
     original_tokens = sampler.store.original_tokens
     original_logprobs = torch.zeros(
@@ -663,7 +663,8 @@ def test_create_beam_history():
     # set the new log probs and tokens for the beam search sampling
     sampler.store.new_log_probs[
         seq_slot, :beam_width] = original_logprobs[:beam_width,
-                                                   num_generated_tokens - 1, 0]
+                                                   num_generated_tokens - 1,
+                                                   0:1]
     sampler.store.new_tokens[
         0,
         seq_slot, :beam_width] = original_logprob_indices[:beam_width,
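
The num_logprobs + 1 and 0:1 adjustments suggest that the sampler's logprob storage now holds the sampled token's logprob alongside the requested top-k entries. A small shape sketch under that reading (illustrative names and values only, not the sampler's actual buffers):

    import torch

    beam_width = 4              # illustrative beam width
    num_generated_tokens = 5    # as in GeneralTestParams
    num_logprobs = 0            # extra top-k logprobs requested per token (new test default)

    # One slot for the sampled token's logprob plus num_logprobs top-k entries per step.
    original_logprobs = torch.zeros(beam_width, num_generated_tokens, num_logprobs + 1)
    print(original_logprobs.shape)  # torch.Size([4, 5, 1])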

tests/unittest/_torch/sampler/test_logits_logprobs.py

Lines changed: 1 addition & 0 deletions
@@ -115,6 +115,7 @@ def simple_llm(request) -> LLM:
                                  "TinyLlama-1.1B-Chat-v1.0"),
         max_batch_size=8,
         disable_flashinfer_sampling=disable_flashinfer_sampling,
+        max_topk_logprobs=3,
     )
     return llm

tests/unittest/api_stability/references/llm.yaml

Lines changed: 4 additions & 0 deletions
@@ -223,6 +223,10 @@ methods:
         annotation: Optional[Dict[str, str]]
         default: null
         status: prototype
+      max_topk_logprobs:
+        annotation: int
+        default: 0
+        status: prototype
     return_annotation: None
   generate:
     parameters:
