
Commit 831b124

qandrew and Andrew Xia authored
[responsesAPI] add better error messaging for long prompts (vllm-project#25724)
Signed-off-by: Andrew Xia <[email protected]>
Signed-off-by: Andrew Xia <[email protected]>
Co-authored-by: Andrew Xia <[email protected]>
1 parent c1ffcb5 commit 831b124

File tree

2 files changed: +84 -1 lines changed

  tests/entrypoints/openai/test_serving_responses.py
  vllm/entrypoints/openai/serving_responses.py
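From an API consumer's point of view, the effect of this change is that a prompt whose tokenization does not fit within max_model_len now fails fast with an explicit invalid_request_error (HTTP 400) instead of a less descriptive failure. A minimal client-side sketch, assuming a locally running vLLM server, the official openai Python client, and a placeholder model name (none of which are part of this commit):

    # Hypothetical client-side view of the improved error; the base_url, model
    # name, and oversized prompt are illustrative, not part of the commit.
    import openai

    client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

    try:
        client.responses.create(
            model="my-model",
            input="word " * 1_000_000,  # tokenizes to far more than max_model_len
        )
    except openai.BadRequestError as exc:
        # With this commit, the response body explains that the engine prompt
        # length exceeds max_model_len rather than surfacing a generic engine error.
        print(exc)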

tests/entrypoints/openai/test_serving_responses.py

Lines changed: 62 additions & 1 deletion
@@ -8,9 +8,10 @@
 import pytest_asyncio
 
 from vllm.entrypoints.context import ConversationContext
-from vllm.entrypoints.openai.protocol import ResponsesRequest
+from vllm.entrypoints.openai.protocol import ErrorResponse, ResponsesRequest
 from vllm.entrypoints.openai.serving_responses import OpenAIServingResponses
 from vllm.entrypoints.tool_server import ToolServer
+from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
 
 
 class MockConversationContext(ConversationContext):
@@ -127,3 +128,63 @@ async def test_initialize_tool_sessions(self, serving_responses_instance,
 
         # Verify that init_tool_sessions was called
         assert mock_context.init_tool_sessions_called
+
+
+class TestValidateGeneratorInput:
+    """Test class for the _validate_generator_input method"""
+
+    @pytest_asyncio.fixture
+    async def serving_responses_instance(self):
+        """Create a real OpenAIServingResponses instance for testing"""
+        # Create minimal mocks for required dependencies
+        engine_client = MagicMock()
+        engine_client.get_model_config = AsyncMock()
+
+        model_config = MagicMock()
+        model_config.hf_config.model_type = "test"
+        model_config.get_diff_sampling_param.return_value = {}
+
+        models = MagicMock()
+
+        # Create the actual instance
+        instance = OpenAIServingResponses(
+            engine_client=engine_client,
+            model_config=model_config,
+            models=models,
+            request_logger=None,
+            chat_template=None,
+            chat_template_content_format="auto",
+        )
+
+        # Set max_model_len for testing
+        instance.max_model_len = 100
+
+        return instance
+
+    def test_validate_generator_input(self, serving_responses_instance):
+        """Test _validate_generator_input with valid and invalid prompt lengths"""
+        # Create an engine prompt with valid length (less than max_model_len)
+        valid_prompt_token_ids = list(range(5))  # 5 tokens < 100 max_model_len
+        engine_prompt = EngineTokensPrompt(
+            prompt_token_ids=valid_prompt_token_ids)
+
+        # Call the method
+        result = serving_responses_instance._validate_generator_input(
+            engine_prompt)
+
+        # Should return None for valid input
+        assert result is None
+
+        # Create an invalid engine prompt
+        invalid_prompt_token_ids = list(
+            range(200))  # 200 tokens >= 100 max_model_len
+        engine_prompt = EngineTokensPrompt(
+            prompt_token_ids=invalid_prompt_token_ids)
+
+        # Call the method
+        result = serving_responses_instance._validate_generator_input(
+            engine_prompt)
+
+        # Should return an ErrorResponse
+        assert result is not None
+        assert isinstance(result, ErrorResponse)
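As a usage note, the new tests can be run in isolation. A small sketch using pytest's Python entry point rather than a shell command (the -k filter and path simply mirror the file touched above):

    # Run only the new validator tests; pytest.main accepts the same
    # arguments as the pytest command line.
    import pytest

    pytest.main([
        "tests/entrypoints/openai/test_serving_responses.py",
        "-k", "TestValidateGeneratorInput",
    ])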

vllm/entrypoints/openai/serving_responses.py

Lines changed: 22 additions & 0 deletions
@@ -192,6 +192,23 @@ def __init__(
 
         self.tool_server = tool_server
 
+    def _validate_generator_input(
+            self,
+            engine_prompt: EngineTokensPrompt) -> Optional[ErrorResponse]:
+        """Add validations to the input to the generator here."""
+        if self.max_model_len <= len(engine_prompt["prompt_token_ids"]):
+            error_message = (
+                "The engine prompt length"
+                f" {len(engine_prompt['prompt_token_ids'])} "
+                f"exceeds the max_model_len {self.max_model_len}. "
+                "Please reduce prompt.")
+            return self.create_error_response(
+                err_type="invalid_request_error",
+                message=error_message,
+                status_code=HTTPStatus.BAD_REQUEST,
+            )
+        return None
+
     async def create_responses(
         self,
         request: ResponsesRequest,
@@ -287,8 +304,13 @@ async def create_responses(
         available_tools = []
         try:
             for i, engine_prompt in enumerate(engine_prompts):
+                maybe_error = self._validate_generator_input(engine_prompt)
+                if maybe_error is not None:
+                    return maybe_error
+
                 default_max_tokens = self.max_model_len - len(
                     engine_prompt["prompt_token_ids"])
+
                 sampling_params = request.to_sampling_params(
                     default_max_tokens, self.default_sampling_params)
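The boundary of the new check is worth spelling out: a prompt is rejected when its token count is greater than or equal to max_model_len, because an exactly full prompt would leave default_max_tokens at zero a few lines later. A standalone sketch of that comparison (the helper name and the default max_model_len of 100, borrowed from the test fixture, are illustrative and not part of vLLM):

    # Illustrative restatement of the check performed by _validate_generator_input.
    def prompt_too_long(prompt_token_ids: list[int], max_model_len: int = 100) -> bool:
        # Mirrors `self.max_model_len <= len(engine_prompt["prompt_token_ids"])`.
        return max_model_len <= len(prompt_token_ids)

    assert prompt_too_long(list(range(5))) is False    # 5 tokens: accepted
    assert prompt_too_long(list(range(100))) is True   # 100 tokens: rejected at the boundary
    assert prompt_too_long(list(range(200))) is True   # 200 tokens: rejected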
