|
21 | 21 | import pytest_asyncio |
22 | 22 |
|
23 | 23 | # Local |
24 | | -from vllm_detector_adapter.generative_detectors.base import ChatCompletionDetectionBase |
| 24 | +from vllm_detector_adapter.generative_detectors.base import ( |
| 25 | + ChatCompletionDetectionBase, |
| 26 | + ErrorResponse, |
| 27 | +) |
25 | 28 | from vllm_detector_adapter.protocol import ( |
26 | 29 | ContentsDetectionRequest, |
27 | 30 | ContentsDetectionResponse, |
@@ -204,3 +207,35 @@ def test_content_analysis_success(detection_base, completion_response): |
204 | 207 | assert detections[1][0]["start"] == 0 |
205 | 208 | assert detections[1][0]["end"] == len(content_request.contents[1]) |
206 | 209 | assert detections[1][0]["metadata"] == {} |
| 210 | + |
| 211 | + |
| 212 | +def test_content_analysis_errorresponse_verification(detection_base): |
| 213 | + """Test that content_analysis properly propagates an ErrorResponse when a choice has empty content.""" |
| 214 | + base_instance = asyncio.run(detection_base) |
| 215 | + content_request = ContentsDetectionRequest(contents=["Where do I find geese?"]) |
| 216 | + |
| 217 | + # Simulate what the model would produce for a request: empty content triggers error |
| 218 | + choice = ChatCompletionResponseChoice( |
| 219 | + index=0, |
| 220 | + message=ChatMessage(role="assistant", content=""), |
| 221 | + logprobs=ChatCompletionLogProbs(content=[]), |
| 222 | + ) |
| 223 | + completion_response = ChatCompletionResponse( |
| 224 | + model="test-model", |
| 225 | + choices=[choice], |
| 226 | + usage=UsageInfo(prompt_tokens=1, total_tokens=2, completion_tokens=1), |
| 227 | + ) |
| 228 | + scores = [0.5] |
| 229 | + detection_type = "risk" |
| 230 | + response = (completion_response, scores, detection_type) |
| 231 | + |
| 232 | + # Patch the process_chat_completion_with_scores to return response |
| 233 | + with patch( |
| 234 | + "vllm_detector_adapter.generative_detectors.base.ChatCompletionDetectionBase.process_chat_completion_with_scores", |
| 235 | + return_value=response, |
| 236 | + ): |
| 237 | + result = asyncio.run(base_instance.content_analysis(content_request)) |
| 238 | + |
| 239 | + assert isinstance(result, ErrorResponse) |
| 240 | + assert result.type == "BadRequestError" |
| 241 | + assert "does not have content" in result.message |
0 commit comments