@@ -8,9 +8,10 @@
 import pytest_asyncio
 
 from vllm.entrypoints.context import ConversationContext
-from vllm.entrypoints.openai.protocol import ResponsesRequest
+from vllm.entrypoints.openai.protocol import ErrorResponse, ResponsesRequest
 from vllm.entrypoints.openai.serving_responses import OpenAIServingResponses
 from vllm.entrypoints.tool_server import ToolServer
+from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
 
 
 class MockConversationContext(ConversationContext):
@@ -127,3 +128,63 @@ async def test_initialize_tool_sessions(self, serving_responses_instance,
 
         # Verify that init_tool_sessions was called
         assert mock_context.init_tool_sessions_called
+
+
+class TestValidateGeneratorInput:
+    """Test class for the _validate_generator_input method."""
+
+    @pytest_asyncio.fixture
+    async def serving_responses_instance(self):
+        """Create a real OpenAIServingResponses instance for testing."""
+        # Create minimal mocks for required dependencies.
+        engine_client = MagicMock()
+        engine_client.get_model_config = AsyncMock()
+
+        model_config = MagicMock()
+        model_config.hf_config.model_type = "test"
+        model_config.get_diff_sampling_param.return_value = {}
+
+        models = MagicMock()
+
+        # Create the actual instance.
+        instance = OpenAIServingResponses(
+            engine_client=engine_client,
+            model_config=model_config,
+            models=models,
+            request_logger=None,
+            chat_template=None,
+            chat_template_content_format="auto",
+        )
+
+        # Set a small max_model_len for testing.
+        instance.max_model_len = 100
+
+        return instance
+
+    def test_validate_generator_input(self, serving_responses_instance):
+        """Test _validate_generator_input with valid and invalid prompt lengths."""
+        # Create an engine prompt with a valid length (below max_model_len).
+        valid_prompt_token_ids = list(range(5))  # 5 tokens < 100 max_model_len
+        engine_prompt = EngineTokensPrompt(
+            prompt_token_ids=valid_prompt_token_ids)
+
+        # Call the method.
+        result = serving_responses_instance._validate_generator_input(
+            engine_prompt)
+
+        # Should return None for valid input.
+        assert result is None
+
+        # Create an engine prompt that exceeds max_model_len.
+        invalid_prompt_token_ids = list(
+            range(200))  # 200 tokens >= 100 max_model_len
+        engine_prompt = EngineTokensPrompt(
+            prompt_token_ids=invalid_prompt_token_ids)
+
+        # Call the method again.
+        result = serving_responses_instance._validate_generator_input(
+            engine_prompt)
+
+        # Should return an ErrorResponse for the over-length prompt.
+        assert result is not None
+        assert isinstance(result, ErrorResponse)
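
For context, here is a minimal sketch of the kind of check these tests exercise, assuming `_validate_generator_input` does nothing more than compare the prompt length against `max_model_len`. It reuses the `ErrorResponse` and `EngineTokensPrompt` imports above; the method body and the field values passed to `ErrorResponse` are illustrative assumptions, not the actual vLLM implementation:

```python
# Hypothetical sketch of the method under test; not the actual vLLM source.
from typing import Optional


def _validate_generator_input(
        self, engine_prompt: EngineTokensPrompt) -> Optional[ErrorResponse]:
    # EngineTokensPrompt is a TypedDict, so the token ids are read by key.
    prompt_len = len(engine_prompt["prompt_token_ids"])
    if prompt_len >= self.max_model_len:
        # Message, type, and code are illustrative assumptions.
        return ErrorResponse(
            message=f"Prompt has {prompt_len} tokens, which exceeds "
            f"max_model_len of {self.max_model_len}.",
            type="BadRequestError",
            code=400,
        )
    # Valid input: no error to report.
    return None
```

The `>=` boundary matches the comments in the test: a prompt of exactly `max_model_len` tokens is treated as invalid.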