tensorrt_llm/executor/result.py (11 additions, 0 deletions)
@@ -624,6 +624,9 @@ def __init__(self,
self._streaming = streaming

def _handle_response(self, response: "GenerationExecutor.Response"):
# Save token lengths before processing to detect which outputs received new tokens
prev_token_lens = {id(o): len(o.token_ids) for o in self._outputs}

GenerationResultBase._handle_response(self, response)

# The postprocess has been performed, return directly
@@ -638,7 +641,15 @@ def _handle_response(self, response: "GenerationExecutor.Response"):
}
if self.sampling_params.detokenize and self.tokenizer is not None:
for beam_output in self.outputs:
# Always update _last_text_len to prevent stale text_diff
beam_output._last_text_len = len(beam_output.text)
# For n > 1 streaming: only detokenize outputs that received new tokens
# to prevent re-decoding the same tokens multiple times
output_received_new_tokens = len(
beam_output.token_ids) != prev_token_lens.get(
id(beam_output), 0)
if not output_received_new_tokens:
continue
Collaborator comment:
This could lead to wrong outputs when using beam search.
If one beam exits early, it may, in the next iteration, swap its tokens with another non-exited beam without adding a new token. Because the number of tokens did not change, the swapped tokens will not be decoded, which could result in a wrong output.

You may adjust this line to

if not output_received_new_tokens and not self.sampling_params.use_beam_search:
    continue

if hasattr(
self.tokenizer, 'decode_incrementally'
) and self._streaming and not self.sampling_params.use_beam_search:
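
To make the change easier to follow outside the diff, here is a minimal, standalone sketch of the guard under discussion. It is not the TensorRT-LLM implementation: FakeOutput, handle_streaming_step, and new_tokens are invented stand-ins, and real incremental detokenization is replaced by a trivial join. It also folds in the collaborator's suggested use_beam_search escape hatch, since beams can swap tokens without changing their count.

# Standalone sketch (not the TensorRT-LLM code): skip re-decoding outputs whose
# token count did not grow, except under beam search. All names below are
# simplified stand-ins that only mirror the diff.
from dataclasses import dataclass, field
from typing import List


@dataclass
class FakeOutput:
    token_ids: List[int] = field(default_factory=list)
    text: str = ""
    _last_text_len: int = 0


def handle_streaming_step(outputs: List[FakeOutput],
                          new_tokens: List[List[int]],
                          use_beam_search: bool) -> None:
    # Snapshot token counts before new tokens arrive, keyed by object id,
    # so we can later tell which outputs actually received something new.
    prev_token_lens = {id(o): len(o.token_ids) for o in outputs}

    for out, toks in zip(outputs, new_tokens):
        out.token_ids.extend(toks)

    for out in outputs:
        # Always refresh _last_text_len so a later text diff is not stale.
        out._last_text_len = len(out.text)
        received_new_tokens = len(out.token_ids) != prev_token_lens[id(out)]
        # Skip unchanged outputs for n > 1 streaming, but never skip under
        # beam search, where tokens can be swapped between beams without
        # changing the count (the reviewer's concern above).
        if not received_new_tokens and not use_beam_search:
            continue
        # Stand-in for incremental detokenization of the new tokens.
        out.text = " ".join(str(t) for t in out.token_ids)


if __name__ == "__main__":
    outs = [FakeOutput(), FakeOutput()]
    handle_streaming_step(outs, [[1, 2], []], use_beam_search=False)
    print([o.text for o in outs])  # second output skipped: no new tokens
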
tests/unittest/llmapi/apps/_test_openai_completions.py (34 additions, 0 deletions)
@@ -5,6 +5,7 @@

import openai
import pytest
from utils.util import similar

from ..test_llm import get_model_path
from .openai_server import RemoteOpenAIServer
@@ -204,6 +205,39 @@ async def test_batch_completions_streaming(async_client: openai.AsyncOpenAI,
assert texts[0] == texts[1]


@pytest.mark.asyncio(loop_scope="module")
@pytest.mark.parametrize("prompts", [["Hello, my name is"] * 2])
async def test_batch_completions_with_option_n_streaming(
async_client: openai.AsyncOpenAI, model_name, prompts):
# Use non-stream single generation as reference
completion_ref = await async_client.completions.create(
model=model_name,
prompt=prompts[0],
max_tokens=5,
temperature=0.0001,
)
text_ref = completion_ref.choices[0].text

# test n>1 with streaming
batch = await async_client.completions.create(
model=model_name,
prompt=prompts,
n=3, # number of completions to generate for each prompt.
max_tokens=5,
temperature=0.0001,
stream=True,
)
texts = [""] * 6 # 2 prompts × 3 generations per prompt = 6 choices
async for chunk in batch:
assert len(chunk.choices) == 1
choice = chunk.choices[0]
texts[choice.index] += choice.text

# Check all generations are consistent with the reference
for text in texts:
assert similar(text, text_ref, threshold=0.8)


@pytest.mark.asyncio(loop_scope="module")
async def test_completion_stream_options(async_client: openai.AsyncOpenAI,
model_name: str):
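
The new test asserts with similar from utils.util, whose implementation is not shown in this diff. As a rough stand-in (an assumption, not the repository's helper), a ratio-based check matching the call similar(text, text_ref, threshold=0.8) could look like the sketch below; a fuzzy comparison is presumably used because streamed n > 1 completions can differ from the non-streaming reference by a token or two even at near-zero temperature.

# Assumed stand-in for utils.util.similar -- not the repository's helper.
# It only needs to match the call pattern similar(text, text_ref, threshold=0.8).
from difflib import SequenceMatcher


def similar(a: str, b: str, threshold: float = 0.8) -> bool:
    """Return True if the two strings' similarity ratio meets the threshold."""
    return SequenceMatcher(None, a, b).ratio() >= threshold


if __name__ == "__main__":
    print(similar("Hello, my name is John", "Hello, my name is Jane", threshold=0.8))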