Commit d70b5cf

:white_check_mark: Push test for base class and fix response massaging issue
Signed-off-by: Gaurav-Kumbhat <Gaurav.Kumbhat@ibm.com>
1 parent 21c5103 commit d70b5cf

2 files changed (+103, -1 lines changed)


tests/generative_detectors/test_base.py

Lines changed: 100 additions & 0 deletions
@@ -1,16 +1,31 @@
 # Standard
 from dataclasses import dataclass
 from typing import Optional
+from unittest.mock import patch
 import asyncio

 # Third Party
 from vllm.config import MultiModalConfig
+from vllm.entrypoints.openai.protocol import (
+    ChatCompletionLogProb,
+    ChatCompletionLogProbs,
+    ChatCompletionLogProbsContent,
+    ChatCompletionResponse,
+    ChatCompletionResponseChoice,
+    ChatMessage,
+    UsageInfo,
+)
 from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
 import jinja2
+import pytest
 import pytest_asyncio

 # Local
 from vllm_detector_adapter.generative_detectors.base import ChatCompletionDetectionBase
+from vllm_detector_adapter.protocol import (
+    ContentsDetectionRequest,
+    ContentsDetectionResponse,
+)

 MODEL_NAME = "openai-community/gpt2"
 CHAT_TEMPLATE = "Dummy chat template for testing {}"
@@ -82,6 +97,55 @@ async def detection_base():
     return _async_serving_detection_completion_init()


+@pytest.fixture(scope="module")
+def granite_completion_response():
+    log_probs_content_no = ChatCompletionLogProbsContent(
+        token="no",
+        logprob=-0.0013,
+        # 5 logprobs requested for scoring, skipping bytes for conciseness
+        top_logprobs=[
+            ChatCompletionLogProb(token="no", logprob=-0.053),
+            ChatCompletionLogProb(token="0", logprob=-6.61),
+            ChatCompletionLogProb(token="1", logprob=-16.90),
+            ChatCompletionLogProb(token="2", logprob=-17.39),
+            ChatCompletionLogProb(token="3", logprob=-17.61),
+        ],
+    )
+    log_probs_content_yes = ChatCompletionLogProbsContent(
+        token="yes",
+        logprob=-0.0013,
+        # 5 logprobs requested for scoring, skipping bytes for conciseness
+        top_logprobs=[
+            ChatCompletionLogProb(token="yes", logprob=-0.0013),
+            ChatCompletionLogProb(token="0", logprob=-6.61),
+            ChatCompletionLogProb(token="1", logprob=-16.90),
+            ChatCompletionLogProb(token="2", logprob=-17.39),
+            ChatCompletionLogProb(token="3", logprob=-17.61),
+        ],
+    )
+    choice_0 = ChatCompletionResponseChoice(
+        index=0,
+        message=ChatMessage(
+            role="assistant",
+            content="no",
+        ),
+        logprobs=ChatCompletionLogProbs(content=[log_probs_content_no]),
+    )
+    choice_1 = ChatCompletionResponseChoice(
+        index=1,
+        message=ChatMessage(
+            role="assistant",
+            content="yes",
+        ),
+        logprobs=ChatCompletionLogProbs(content=[log_probs_content_yes]),
+    )
+    yield ChatCompletionResponse(
+        model=MODEL_NAME,
+        choices=[choice_0, choice_1],
+        usage=UsageInfo(prompt_tokens=200, total_tokens=206, completion_tokens=6),
+    )
+
+
 ### Tests #####################################################################

@@ -97,3 +161,39 @@ def test_async_serving_detection_completion_init(detection_base):
     output_template = detection_completion.output_template
     assert type(output_template) == jinja2.environment.Template
     assert output_template.render(({"text": "moose"})) == "bye moose"
+
+
+def test_content_analysis_success(detection_base, granite_completion_response):
+    base_instance = asyncio.run(detection_base)
+
+    content_request = ContentsDetectionRequest(
+        contents=["Where do I find geese?", "You could go to Canada"]
+    )
+
+    scores = [0.9, 0.1, 0.21, 0.54, 0.33]
+    response = (granite_completion_response, scores, "risk")
+    with patch(
+        "vllm_detector_adapter.generative_detectors.base.ChatCompletionDetectionBase.process_chat_completion_with_scores",
+        return_value=response,
+    ):
+        result = asyncio.run(base_instance.content_analysis(content_request))
+        assert isinstance(result, ContentsDetectionResponse)
+        detections = result.model_dump()
+        assert len(detections) == 2
+        print(detections)
+        # For the first content
+        assert detections[0][0]["detection"] == "no"
+        assert detections[0][0]["score"] == 0.9
+        assert detections[0][0]["start"] == 0
+        assert detections[0][0]["end"] == len(content_request.contents[0])
+        # 2nd choice as 2nd label
+        assert detections[0][1]["detection"] == "yes"
+        assert detections[0][1]["score"] == 0.1
+        assert detections[0][1]["start"] == 0
+        assert detections[0][1]["end"] == len(content_request.contents[0])
+        # For the 2nd content, only the 1st detection is checked for simplicity.
+        # Note: the detection is the same because of how the mock works.
+        assert detections[1][0]["detection"] == "no"
+        assert detections[1][0]["score"] == 0.9
+        assert detections[1][0]["start"] == 0
+        assert detections[1][0]["end"] == len(content_request.contents[1])
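
Why both contents produce identical detections in the new test: unittest.mock's patch(..., return_value=...) replaces process_chat_completion_with_scores for the whole with-block and hands back the same (response, scores, detection_type) tuple on every call, regardless of which content is being analyzed. A minimal, self-contained illustration of that behavior (the Detector class below is hypothetical and not from this repository):

from unittest.mock import patch

class Detector:
    def analyze(self, text):
        raise NotImplementedError("replaced by the patch below")

# return_value makes every call return the same object, so each content in a
# request would see identical detections -- exactly what the test comment
# "detection is the same because of how the mock works" refers to.
with patch.object(Detector, "analyze", return_value=("no", 0.9)):
    detector = Detector()
    assert detector.analyze("Where do I find geese?") == ("no", 0.9)
    assert detector.analyze("You could go to Canada") == ("no", 0.9)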

vllm_detector_adapter/generative_detectors/base.py

Lines changed: 3 additions & 1 deletion
@@ -333,4 +333,6 @@ async def content_analysis(
         if isinstance(result, ErrorResponse):
             return result

-        return ContentsDetectionResponse.from_chat_completion_response(results)
+        return ContentsDetectionResponse.from_chat_completion_response(
+            results, request.contents
+        )
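
The one-line fix above threads request.contents through to ContentsDetectionResponse.from_chat_completion_response so the response massaging can attach per-content character offsets, which is what the new test asserts (start == 0, end == len(content)). The actual classmethod lives in vllm_detector_adapter/protocol.py and is not shown in this diff; the sketch below is only an assumed, simplified illustration of that mapping, with the helper name invented for the example:

from typing import List, Tuple

def contents_detections_sketch(
    results: List[Tuple[object, List[float], str]],
    contents: List[str],
) -> List[List[dict]]:
    # Hypothetical sketch, not the repository implementation.
    # One (chat_response, scores, detection_type) tuple is assumed per analyzed
    # content string; each choice/score pair becomes a detection covering the
    # whole content, since no token-level alignment is done.
    detections = []
    for content, (chat_response, scores, _detection_type) in zip(contents, results):
        per_content = []
        for choice, score in zip(chat_response.choices, scores):
            per_content.append(
                {
                    "detection": choice.message.content,  # e.g. "no" / "yes"
                    "score": score,
                    "start": 0,
                    "end": len(content),
                }
            )
        detections.append(per_content)
    return detections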
