
Commit db1c637

[#11992][fix] Support include_stop_token_in_output in gRPC request manager (#12517)
Signed-off-by: Chang Su <chang.s.su@oracle.com>
1 parent 7273bad commit db1c637
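
In effect (a toy illustration, not code from this repo, assuming the usual TRT-LLM/vLLM-style convention where generation halts at a stop sequence and this flag decides whether the matched sequence survives in the returned text):

def finalize_text(generated: str, stop: str, include_stop_str_in_output: bool) -> str:
    # Generation halts once `stop` appears; the flag decides whether the
    # matched stop string is kept in the text handed back to the client.
    idx = generated.find(stop)
    if idx == -1:
        return generated
    end = idx + len(stop) if include_stop_str_in_output else idx
    return generated[:end]

print(finalize_text("Hello world<|eot|>", "<|eot|>", False))  # -> "Hello world"
print(finalize_text("Hello world<|eot|>", "<|eot|>", True))   # -> "Hello world<|eot|>"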

2 files changed: +4, −0 lines changed

tensorrt_llm/grpc/grpc_request_manager.py

Lines changed: 3 additions & 0 deletions
@@ -246,6 +246,7 @@ def create_sampling_params_from_proto(
     bad_token_ids: Optional[List[int]] = None,
     guided_decoding: Optional[pb2.GuidedDecodingParams] = None,
     embedding_bias: Optional[List[float]] = None,
+    include_stop_token_in_output: bool = False,
 ) -> SamplingParams:
     """Convert protobuf configuration to TensorRT-LLM SamplingParams.

@@ -332,6 +333,8 @@ def create_sampling_params_from_proto(
         kwargs["stop_token_ids"] = stop_token_ids
     if ignore_eos:
         kwargs["ignore_eos"] = True
+    if include_stop_token_in_output:
+        kwargs["include_stop_str_in_output"] = True
 
     # Bad words (TRT-LLM's _setup() tokenizes bad word strings)
     if bad:
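
For clarity, a self-contained sketch (not part of the commit) of the flag-to-kwarg translation this hunk performs. Note the gRPC field name include_stop_token_in_output maps onto the SamplingParams keyword include_stop_str_in_output; the helper name here is hypothetical:

from typing import Any, Dict

def build_sampling_kwargs(include_stop_token_in_output: bool = False) -> Dict[str, Any]:
    kwargs: Dict[str, Any] = {}
    # Mirror of the hunk above: the proto-level flag name differs from the
    # TRT-LLM keyword, and the manager bridges the two only when it is set.
    if include_stop_token_in_output:
        kwargs["include_stop_str_in_output"] = True
    return kwargs

assert build_sampling_kwargs(True) == {"include_stop_str_in_output": True}
assert build_sampling_kwargs(False) == {}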

tensorrt_llm/grpc/grpc_servicer.py

Lines changed: 1 addition & 0 deletions
@@ -106,6 +106,7 @@ async def Generate(
             if request.HasField("guided_decoding")
             else None,
             embedding_bias=list(request.embedding_bias) if request.embedding_bias else None,
+            include_stop_token_in_output=request.include_stop_token_in_output,
         )
 
         # Build LoRA request if present
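
A minimal sketch of the forwarding above, with a dataclass standing in for the protobuf request (the real message type lives in the generated pb2 module; this stand-in and the helper name are assumptions for illustration):

from dataclasses import dataclass

@dataclass
class FakeGenerateRequest:
    # Stand-in for the proto field read by Generate above; a proto3 bool
    # defaults to False, preserving the pre-change trimming behavior.
    include_stop_token_in_output: bool = False

def sampling_kwargs_from_request(request: FakeGenerateRequest) -> dict:
    # The servicer passes the field through verbatim to the request manager.
    return {"include_stop_token_in_output": request.include_stop_token_in_output}

print(sampling_kwargs_from_request(FakeGenerateRequest(include_stop_token_in_output=True)))
# -> {'include_stop_token_in_output': True}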
