Merged

20 commits
27487ff  [None][feat] Add processed logprobs functionality to TorchSampler (stnie, Dec 3, 2025)
952f136  sampled logprob (hchings, Nov 22, 2025)
fb46e26  update (hchings, Dec 2, 2025)
b703015  [TRTLLM-9686][feat] Fix issues with processed logprobs functionality. (stnie, Dec 5, 2025)
55baf72  [TRTLLM-9687][fix] Enable pinned memory for tensor allocations in Tor… (stnie, Dec 8, 2025)
934d713  [TRTLLM-9688][feat] Refactor processed logprobs to skip logprob calcu… (stnie, Dec 9, 2025)
cc0adf3  [TRTLLM-9689][feat] Introduce max_topk_logprobs parameter and enhance… (stnie, Dec 12, 2025)
923574c  cache sampling_strategy (tongyuantongyu, Dec 16, 2025)
8581194  fix apply_embedding_bias perf (tongyuantongyu, Dec 16, 2025)
71fec36  fix sample_batched_by_strategy perf (tongyuantongyu, Dec 16, 2025)
6a1f0c1  [TRTLLM-9686][chore] Added skip to testcase and fixed _request_strate… (stnie, Dec 16, 2025)
83e8856  [TRTLLM-9686][chore] Refactored speculation_needs_probs to use reques… (stnie, Dec 17, 2025)
3071e84  Refactor logprob assertions in test_processed_logprobs_e2e to use tor… (stnie, Dec 19, 2025)
726c227  Nits, Update to ruff for test file (tongyuantongyu, Dec 24, 2025)
2aff133  Update logprobs_mode definition (tongyuantongyu, Dec 30, 2025)
9973f42  Avoid unnecessary initialization for runtime buffers (tongyuantongyu, Dec 30, 2025)
3310db1  Address reviews (tongyuantongyu, Jan 2, 2026)
c18f1a7  Address reviews 2 (tongyuantongyu, Jan 5, 2026)
640ad01  Fix annotations for api_stability tests (stnie, Jan 9, 2026)
8c6bea1  Added testcases to waives.txt to continue skipping these testcases af… (stnie, Jan 13, 2026)
1 change: 0 additions & 1 deletion .pre-commit-config.yaml

@@ -1022,7 +1022,6 @@ common-files: &common_files |
   tests/unittest/_torch/ray_orchestrator/single_gpu/test_cache_transceiver_comm.py |
   tests/unittest/_torch/sampler/test_beam_search.py |
   tests/unittest/_torch/sampler/test_best_of_n.py |
-  tests/unittest/_torch/sampler/test_return_logits.py |
   tests/unittest/_torch/sampler/test_torch_multi_arange.py |
   tests/unittest/_torch/sampler/test_trtllm_sampler.py |
   tests/unittest/_torch/speculative/test_draft_target.py |
1 change: 0 additions & 1 deletion pyproject.toml

@@ -1063,7 +1063,6 @@ exclude = [
   "tests/unittest/_torch/ray_orchestrator/single_gpu/test_cache_transceiver_comm.py",
   "tests/unittest/_torch/sampler/test_beam_search.py",
   "tests/unittest/_torch/sampler/test_best_of_n.py",
-  "tests/unittest/_torch/sampler/test_return_logits.py",
   "tests/unittest/_torch/sampler/test_torch_multi_arange.py",
   "tests/unittest/_torch/sampler/test_trtllm_sampler.py",
   "tests/unittest/_torch/speculative/test_draft_target.py",
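Taken together, the two deletions above remove tests/unittest/_torch/sampler/test_return_logits.py from both lint exclude lists, so ruff and the pre-commit hooks now check that test file (see the "Nits, Update to ruff for test file" commit above).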
10 changes: 9 additions & 1 deletion tensorrt_llm/_torch/pyexecutor/llm_request.py

@@ -8,6 +8,7 @@
 from tensorrt_llm._torch.shared_tensor import SharedTensorContainer
 from tensorrt_llm.bindings import executor as tllm_executor
 from tensorrt_llm.executor.result import TokenLogprobs
+from tensorrt_llm.sampling_params import LogprobMode

 SamplingConfig = tensorrt_llm.bindings.SamplingConfig
 '''
@@ -485,6 +486,7 @@ def __init__(
             is_first_draft: bool = False,
             use_chunked_generation_logits: bool = True,
             logits_chunk_size: int = 8,
+            logprobs_mode: LogprobMode = LogprobMode.RAW,
             **kwargs):

         self.py_logits_post_processors = kwargs.pop("py_logits_post_processors",
@@ -566,6 +568,9 @@ def __init__(
         # currently, keep py_stop_words_list as python list, rather than tensor.
         self.py_stop_words_list = stop_words_list

+        self.py_logprobs_mode = LogprobMode(
+            logprobs_mode)  # handle a raw string being passed
+
         self.py_result = PyResult(
             prompt_len=self.py_prompt_len,
             max_new_tokens=self.py_max_new_tokens,
@@ -825,7 +830,10 @@ def executor_request_to_llm_request(
         arrival_time=getattr(executor_request, "py_arrival_time", None),
         py_multimodal_data=getattr(executor_request, "py_multimodal_data",
                                    None),
-        kv_cache_retention_config=executor_request.kv_cache_retention_config)
+        kv_cache_retention_config=executor_request.kv_cache_retention_config,
+        logprobs_mode=getattr(executor_request, "py_logprobs_mode",
+                              LogprobMode.RAW),
+    )
     if child_req_ids:
         for child_id in child_req_ids:
             llm_request.create_child_request(child_id)
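For readers skimming the diff: the constructor stores LogprobMode(logprobs_mode) so callers may pass either an enum member or its raw string form, and executor_request_to_llm_request reads the attribute through a getattr default so older executor requests keep working. Below is a minimal, self-contained sketch of both patterns; the stand-in LogprobMode members and values and the _FakeExecutorRequest class are assumptions for illustration, not the real tensorrt_llm definitions.

from enum import Enum


class LogprobMode(str, Enum):
    # Stand-in for tensorrt_llm.sampling_params.LogprobMode; the member
    # names and string values here are assumed for illustration.
    RAW = "raw"
    PROCESSED = "processed"


# The diff's `self.py_logprobs_mode = LogprobMode(logprobs_mode)` relies on
# the enum call normalizing both input forms:
assert LogprobMode("processed") is LogprobMode.PROCESSED  # raw string input
assert LogprobMode(LogprobMode.RAW) is LogprobMode.RAW    # enum member input


class _FakeExecutorRequest:
    # Hypothetical request object predating this change, i.e. one without
    # a py_logprobs_mode attribute.
    pass


# Mirrors the getattr-with-default in executor_request_to_llm_request:
# a missing attribute falls back to LogprobMode.RAW.
mode = getattr(_FakeExecutorRequest(), "py_logprobs_mode", LogprobMode.RAW)
assert mode is LogprobMode.RAW

Defaulting through getattr rather than requiring the attribute keeps the new parameter backward compatible with request objects created before this field existed.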