Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1022,7 +1022,6 @@ common-files: &common_files |
tests/unittest/_torch/ray_orchestrator/single_gpu/test_cache_transceiver_comm.py |
tests/unittest/_torch/sampler/test_beam_search.py |
tests/unittest/_torch/sampler/test_best_of_n.py |
tests/unittest/_torch/sampler/test_return_logits.py |
tests/unittest/_torch/sampler/test_torch_multi_arange.py |
tests/unittest/_torch/sampler/test_trtllm_sampler.py |
tests/unittest/_torch/speculative/test_draft_target.py |
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -1063,7 +1063,6 @@ exclude = [
"tests/unittest/_torch/ray_orchestrator/single_gpu/test_cache_transceiver_comm.py",
"tests/unittest/_torch/sampler/test_beam_search.py",
"tests/unittest/_torch/sampler/test_best_of_n.py",
"tests/unittest/_torch/sampler/test_return_logits.py",
"tests/unittest/_torch/sampler/test_torch_multi_arange.py",
"tests/unittest/_torch/sampler/test_trtllm_sampler.py",
"tests/unittest/_torch/speculative/test_draft_target.py",
Expand Down
10 changes: 9 additions & 1 deletion tensorrt_llm/_torch/pyexecutor/llm_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from tensorrt_llm._torch.shared_tensor import SharedTensorContainer
from tensorrt_llm.bindings import executor as tllm_executor
from tensorrt_llm.executor.result import TokenLogprobs
from tensorrt_llm.sampling_params import LogprobMode

SamplingConfig = tensorrt_llm.bindings.SamplingConfig
'''
Expand Down Expand Up @@ -485,6 +486,7 @@ def __init__(
is_first_draft: bool = False,
use_chunked_generation_logits: bool = True,
logits_chunk_size: int = 8,
logprobs_mode: LogprobMode = LogprobMode.RAW,
**kwargs):

self.py_logits_post_processors = kwargs.pop("py_logits_post_processors",
Expand Down Expand Up @@ -566,6 +568,9 @@ def __init__(
# currently, keep py_stop_words_list as python list, rather than tensor.
self.py_stop_words_list = stop_words_list

self.py_logprobs_mode = LogprobMode(
logprobs_mode) # handle passed a raw string

self.py_result = PyResult(
prompt_len=self.py_prompt_len,
max_new_tokens=self.py_max_new_tokens,
Expand Down Expand Up @@ -825,7 +830,10 @@ def executor_request_to_llm_request(
arrival_time=getattr(executor_request, "py_arrival_time", None),
py_multimodal_data=getattr(executor_request, "py_multimodal_data",
None),
kv_cache_retention_config=executor_request.kv_cache_retention_config)
kv_cache_retention_config=executor_request.kv_cache_retention_config,
logprobs_mode=getattr(executor_request, "py_logprobs_mode",
LogprobMode.RAW),
)
if child_req_ids:
for child_id in child_req_ids:
llm_request.create_child_request(child_id)
Expand Down
Loading