
Commit 830896b

Merge pull request #74 from swith004/content_detection_response_validation
Add an isinstance validation check for ErrorResponse on new_result in content_analysis, and handle empty content responses from the detector
2 parents: 7e90171 + a47970c

5 files changed: +87 -3 lines changed
tests/generative_detectors/test_base.py

Lines changed: 36 additions & 1 deletion
@@ -21,7 +21,10 @@
 import pytest_asyncio

 # Local
-from vllm_detector_adapter.generative_detectors.base import ChatCompletionDetectionBase
+from vllm_detector_adapter.generative_detectors.base import (
+    ChatCompletionDetectionBase,
+    ErrorResponse,
+)
 from vllm_detector_adapter.protocol import (
     ContentsDetectionRequest,
     ContentsDetectionResponse,
@@ -204,3 +207,35 @@ def test_content_analysis_success(detection_base, completion_response):
     assert detections[1][0]["start"] == 0
     assert detections[1][0]["end"] == len(content_request.contents[1])
     assert detections[1][0]["metadata"] == {}
+
+
+def test_content_analysis_errorresponse_verification(detection_base):
+    """Test that content_analysis properly propagates an ErrorResponse when a choice has empty content."""
+    base_instance = asyncio.run(detection_base)
+    content_request = ContentsDetectionRequest(contents=["Where do I find geese?"])
+
+    # Simulate what the model would produce for the request: empty content triggers the error
+    choice = ChatCompletionResponseChoice(
+        index=0,
+        message=ChatMessage(role="assistant", content=""),
+        logprobs=ChatCompletionLogProbs(content=[]),
+    )
+    completion_response = ChatCompletionResponse(
+        model="test-model",
+        choices=[choice],
+        usage=UsageInfo(prompt_tokens=1, total_tokens=2, completion_tokens=1),
+    )
+    scores = [0.5]
+    detection_type = "risk"
+    response = (completion_response, scores, detection_type)
+
+    # Patch process_chat_completion_with_scores to return the simulated response
+    with patch(
+        "vllm_detector_adapter.generative_detectors.base.ChatCompletionDetectionBase.process_chat_completion_with_scores",
+        return_value=response,
+    ):
+        result = asyncio.run(base_instance.content_analysis(content_request))
+
+    assert isinstance(result, ErrorResponse)
+    assert result.type == "BadRequestError"
+    assert "does not have content" in result.message

tests/test_protocol.py

Lines changed: 34 additions & 0 deletions
@@ -1,5 +1,6 @@
 # Standard
 from http import HTTPStatus
+from unittest.mock import patch

 # Third Party
 from vllm.entrypoints.openai.protocol import (
@@ -357,3 +358,36 @@ def test_response_from_completion_response_missing_content():
         in detection_response.message
     )
     assert detection_response.code == HTTPStatus.BAD_REQUEST.value
+
+
+def test_response_from_empty_string_content_detection():
+    choice_content_0 = ChatCompletionResponseChoice(
+        index=0,
+        message=ChatMessage(
+            role="assistant",
+            content="",
+        ),
+    )
+    chat_response_0 = ChatCompletionResponse(
+        model=MODEL_NAME,
+        choices=[choice_content_0],
+        usage=UsageInfo(prompt_tokens=136, total_tokens=140, completion_tokens=4),
+    )
+
+    contents = ["sample sentence 1"]
+    # scores for each content are a list of scores (for multi-label)
+    scores = [[0.9]]
+    detection_type = "risk"
+
+    detection_response = ContentsDetectionResponse.from_chat_completion_response(
+        [
+            (chat_response_0, scores[0], detection_type),
+        ],
+        contents,
+    )
+    assert type(detection_response) == ErrorResponse
+    assert (
+        "Choice 0 from chat completion does not have content"
+        in detection_response.message
+    )
+    assert detection_response.code == HTTPStatus.BAD_REQUEST.value

vllm_detector_adapter/generative_detectors/base.py

Lines changed: 7 additions & 0 deletions
@@ -447,6 +447,13 @@ async def content_analysis(
             )
         )

+        # If new_result is an ErrorResponse, return it so the error propagates
+        if isinstance(new_result, ErrorResponse):
+            logger.debug(
+                f"[content_analysis] ErrorResponse returned: {repr(new_result)}"
+            )
+            return new_result
+
         processed_result.append(new_result)

     return ContentsDetectionResponse(root=processed_result)
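The effect of this guard is fail-fast behavior: without it, an ErrorResponse produced while processing one content item would be appended to processed_result and serialized inside the ContentsDetectionResponse root. A minimal standalone sketch of the pattern (the ErrorResponse class and collect function below are simplified stand-ins for illustration, not the adapter's real types):

```python
from typing import List, Union

class ErrorResponse:  # simplified stand-in for vllm's ErrorResponse
    def __init__(self, message: str):
        self.message = message

def collect(results: List[Union[dict, ErrorResponse]]):
    processed = []
    for new_result in results:
        # Fail fast: surface the first error instead of appending it
        # to the aggregate response.
        if isinstance(new_result, ErrorResponse):
            return new_result
        processed.append(new_result)
    return processed

assert isinstance(collect([{"ok": 1}, ErrorResponse("no content")]), ErrorResponse)
assert collect([{"ok": 1}]) == [{"ok": 1}]
```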

vllm_detector_adapter/generative_detectors/llama_guard.py

Lines changed: 7 additions & 0 deletions
@@ -195,6 +195,13 @@ async def content_analysis(
             )
         )

+        # If new_result is an ErrorResponse, return it so the error propagates
+        if isinstance(new_result, ErrorResponse):
+            logger.debug(
+                f"[content_analysis] ErrorResponse returned: {repr(new_result)}"
+            )
+            return new_result
+
         processed_result.append(new_result)

     return ContentsDetectionResponse(root=processed_result)

vllm_detector_adapter/protocol.py

Lines changed: 3 additions & 2 deletions
@@ -82,7 +82,7 @@ def from_chat_completion_response(
             # NOTE: for providing spans, we currently consider the entire generated text as a span.
             # This is because, at the time of writing, the generative guardrail models do not
             # provide specific information about the input text that can be used to deduce spans.
-            if content and isinstance(content, str):
+            if isinstance(content, str) and content.strip():
                 response_object = ContentsDetectionResponseObject(
                     detection_type=detection_type,
                     detection=content.strip(),
@@ -93,6 +93,7 @@ def from_chat_completion_response(
                     metadata=metadata_per_choice[i] if metadata_per_choice else {},
                 ).model_dump()
                 detection_responses.append(response_object)
+
             else:
                 # This case should be unlikely but we handle it since a detection
                 # can't be returned without the content
@@ -335,7 +336,7 @@ def from_chat_completion_response(
         detection_responses = []
         for i, choice in enumerate(response.choices):
            content = choice.message.content
-            if content and isinstance(content, str):
+            if isinstance(content, str) and content.strip():
                 response_object = DetectionResponseObject(
                     detection_type=detection_type,
                     detection=content.strip(),
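Beyond reordering for readability, the new condition changes behavior for whitespace-only strings: `content` alone is truthy for `"   "`, while `content.strip()` is not, so such responses now fall through to the ErrorResponse branch alongside empty strings. A quick illustrative comparison (standalone sketch, not the adapter's code):

```python
def old_check(content) -> bool:
    return bool(content and isinstance(content, str))

def new_check(content) -> bool:
    return bool(isinstance(content, str) and content.strip())

for content in ["unsafe", "", "   ", None]:
    print(repr(content), old_check(content), new_check(content))
# 'unsafe' True True
# ''       False False
# '   '    True False   <- whitespace-only content now rejected
# None     False False
```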
