Commit bc4136f

[https://nvbugs/5427043][fix] cherrypick: request length exceeds max_num_tokens (#7718)
Signed-off-by: Superjomn <[email protected]>
Parent commit: ce6ebf6

File tree: 3 files changed, +30 −1 lines

tensorrt_llm/executor/worker.py

Lines changed: 3 additions & 1 deletion
@@ -932,7 +932,9 @@ def handle_for_ipc_batched(self, responses: List[tllm.Response]) -> None:
 
         for response in responses:
 
-            if self.worker._has_background_error():
+            if isinstance(response, ErrorResponse):
+                pass  # send ErrorResponse directly
+            elif self.worker._has_background_error():
                 response = self.worker._create_error_response(response)
             elif response.has_error():
                 # Convert to ErrorResponse, because tllm.Response cannot be
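Note: the effect of this hunk is that responses which already arrive as ErrorResponse objects are forwarded to the client unchanged, instead of being re-wrapped by the background-error or has_error branches. Below is a minimal, self-contained sketch of that dispatch pattern; the Response/ErrorResponse stand-ins and the handle_batched helper are illustrative assumptions, not the actual tensorrt_llm.executor classes.

# Sketch of the dispatch pattern above (stand-in types, not the real
# tensorrt_llm.executor classes).
from dataclasses import dataclass
from typing import List, Optional, Union


@dataclass
class Response:
    text: str
    error: Optional[str] = None

    def has_error(self) -> bool:
        return self.error is not None


@dataclass
class ErrorResponse:
    error_msg: str


def handle_batched(responses: List[Union[Response, ErrorResponse]],
                   background_error: Optional[str] = None) -> list:
    out = []
    for response in responses:
        if isinstance(response, ErrorResponse):
            pass  # already an error payload; forward it unchanged
        elif background_error is not None:
            # a worker-level failure overrides an otherwise healthy response
            response = ErrorResponse(error_msg=background_error)
        elif response.has_error():
            # per-request failures are converted into the error form
            response = ErrorResponse(error_msg=response.error)
        out.append(response)
    return out


# Example: the first item passes through as-is, the second stays a Response.
print(handle_batched([ErrorResponse("prompt too long"), Response("ok")]))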

tests/unittest/llmapi/apps/_test_openai_completions.py

Lines changed: 12 additions & 0 deletions
@@ -80,6 +80,18 @@ def test_single_completion(client: openai.OpenAI, model_name):
     assert len(completion.choices[0].text) >= 1
 
 
+def test_single_completion_with_too_long_prompt(client: openai.OpenAI,
+                                                model_name):
+    completion = client.completions.create(
+        model=model_name,
+        prompt="Hello, my name is" * 100,
+        max_tokens=5,
+        temperature=0.0,
+    )
+
+    print(completion)
+
+
 @pytest.mark.asyncio(loop_scope="module")
 @pytest.mark.parametrize("echo", [True, False])
 async def test_completion_streaming(async_client: openai.AsyncOpenAI,
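The new test submits a prompt of "Hello, my name is" * 100, which overruns the server's token budget, and simply prints whatever the endpoint returns; with this fix the worker reports the violation instead of failing. As a usage note, a caller can also guard against this case before submitting. The sketch below is an assumption for illustration only: the tokenizer name and the 100-token budget are placeholders, not values taken from the commit.

# Hypothetical client-side guard: count prompt tokens before calling the
# completions endpoint. Tokenizer name and budget are illustrative only.
from transformers import AutoTokenizer

MAX_NUM_TOKENS = 100  # assumed to match the server's max_num_tokens

tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")


def fits_budget(prompt: str, max_tokens: int) -> bool:
    # Leave room for the requested completion tokens.
    return len(tokenizer.encode(prompt)) + max_tokens <= MAX_NUM_TOKENS


prompt = "Hello, my name is" * 100
if not fits_budget(prompt, max_tokens=5):
    print("Prompt exceeds the server's token budget; shorten it first.")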

tests/unittest/llmapi/test_llm_pytorch.py

Lines changed: 15 additions & 0 deletions
@@ -1,3 +1,4 @@
+import random
 from contextlib import contextmanager, nullcontext
 
 import pytest

@@ -821,3 +822,17 @@ def test_llm_with_proxy_error():
                        match="Mock GenerationExecutorWorker initialization failed"):
         llm = LLM(model=llama_model_path,
                   kv_cache_config=global_kvcache_config)
+
+
+class TestLlmError:
+
+    def test_max_num_token_check(self):
+        """ LLM should raise error when got prompt length exceed the valid range. """
+        llm = LLM(llama_model_path,
+                  kv_cache_config=global_kvcache_config,
+                  max_num_tokens=100)
+
+        with pytest.raises(ValueError,
+                           match="should not exceed max_num_tokens"):
+            ids = [random.randint(10, 100) for _ in range(101)]
+            llm.generate([ids])
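This test pins down the contract: with max_num_tokens=100, a prompt of 101 token ids must make generate raise a ValueError whose message contains "should not exceed max_num_tokens". A minimal sketch of that kind of length check follows; the check_prompt_length helper is a hypothetical stand-in, while the real validation lives inside TensorRT-LLM's request handling.

# Hypothetical stand-in for the length validation the test exercises.
from typing import List


def check_prompt_length(prompt_token_ids: List[int], max_num_tokens: int) -> None:
    # Reject requests whose prompt alone already exceeds the token budget.
    if len(prompt_token_ids) > max_num_tokens:
        raise ValueError(
            f"prompt length ({len(prompt_token_ids)}) should not exceed "
            f"max_num_tokens ({max_num_tokens})")


# Mirrors the test: 101 ids against a budget of 100 raises ValueError.
try:
    check_prompt_length(list(range(101)), max_num_tokens=100)
except ValueError as err:
    print(err)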
