99 changes: 99 additions & 0 deletions tests/generative_detectors/test_base.py
@@ -1,16 +1,31 @@
# Standard
from dataclasses import dataclass
from typing import Optional
from unittest.mock import patch
import asyncio

# Third Party
from vllm.config import MultiModalConfig
from vllm.entrypoints.openai.protocol import (
ChatCompletionLogProb,
ChatCompletionLogProbs,
ChatCompletionLogProbsContent,
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatMessage,
UsageInfo,
)
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
import jinja2
import pytest
import pytest_asyncio

# Local
from vllm_detector_adapter.generative_detectors.base import ChatCompletionDetectionBase
from vllm_detector_adapter.protocol import (
ContentsDetectionRequest,
ContentsDetectionResponse,
)

MODEL_NAME = "openai-community/gpt2"
CHAT_TEMPLATE = "Dummy chat template for testing {}"
@@ -82,6 +97,55 @@ async def detection_base():
return _async_serving_detection_completion_init()


@pytest.fixture(scope="module")
def completion_response():
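"""Chat completion response fixture with two choices ("no" and "yes"), each carrying top logprobs used for scoring."""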
log_probs_content_no = ChatCompletionLogProbsContent(
token="no",
logprob=-0.0013,
# 5 logprobs requested for scoring, skipping bytes for conciseness
top_logprobs=[
ChatCompletionLogProb(token="no", logprob=-0.053),
ChatCompletionLogProb(token="0", logprob=-6.61),
ChatCompletionLogProb(token="1", logprob=-16.90),
ChatCompletionLogProb(token="2", logprob=-17.39),
ChatCompletionLogProb(token="3", logprob=-17.61),
],
)
log_probs_content_yes = ChatCompletionLogProbsContent(
token="yes",
logprob=-0.0013,
# 5 logprobs requested for scoring, skipping bytes for conciseness
top_logprobs=[
ChatCompletionLogProb(token="yes", logprob=-0.0013),
ChatCompletionLogProb(token="0", logprob=-6.61),
ChatCompletionLogProb(token="1", logprob=-16.90),
ChatCompletionLogProb(token="2", logprob=-17.39),
ChatCompletionLogProb(token="3", logprob=-17.61),
],
)
choice_0 = ChatCompletionResponseChoice(
index=0,
message=ChatMessage(
role="assistant",
content="no",
),
logprobs=ChatCompletionLogProbs(content=[log_probs_content_no]),
)
choice_1 = ChatCompletionResponseChoice(
index=1,
message=ChatMessage(
role="assistant",
content="yes",
),
logprobs=ChatCompletionLogProbs(content=[log_probs_content_yes]),
)
yield ChatCompletionResponse(
model=MODEL_NAME,
choices=[choice_0, choice_1],
usage=UsageInfo(prompt_tokens=200, total_tokens=206, completion_tokens=6),
)


### Tests #####################################################################


@@ -97,3 +161,38 @@ def test_async_serving_detection_completion_init(detection_base):
output_template = detection_completion.output_template
assert type(output_template) == jinja2.environment.Template
assert output_template.render(({"text": "moose"})) == "bye moose"


def test_content_analysis_success(detection_base, completion_response):
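"""content_analysis should map the mocked (response, scores) result to per-content detections spanning the full text."""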
base_instance = asyncio.run(detection_base)

content_request = ContentsDetectionRequest(
contents=["Where do I find geese?", "You could go to Canada"]
)

scores = [0.9, 0.1, 0.21, 0.54, 0.33]
response = (completion_response, scores, "risk")
with patch(
"vllm_detector_adapter.generative_detectors.base.ChatCompletionDetectionBase.process_chat_completion_with_scores",
return_value=response,
):
result = asyncio.run(base_instance.content_analysis(content_request))
assert isinstance(result, ContentsDetectionResponse)
detections = result.model_dump()
assert len(detections) == 2
# For first content
assert detections[0][0]["detection"] == "no"
assert detections[0][0]["score"] == 0.9
assert detections[0][0]["start"] == 0
assert detections[0][0]["end"] == len(content_request.contents[0])
# 2nd choice as 2nd label
assert detections[0][1]["detection"] == "yes"
assert detections[0][1]["score"] == 0.1
assert detections[0][1]["start"] == 0
assert detections[0][1]["end"] == len(content_request.contents[0])
# For the 2nd content, we only check the 1st detection for simplicity
# Note: the detection is the same because of how the mock works.
assert detections[1][0]["detection"] == "no"
assert detections[1][0]["score"] == 0.9
assert detections[1][0]["start"] == 0
assert detections[1][0]["end"] == len(content_request.contents[1])
88 changes: 88 additions & 0 deletions tests/generative_detectors/test_llama_guard.py
@@ -25,6 +25,8 @@
from vllm_detector_adapter.generative_detectors.llama_guard import LlamaGuard
from vllm_detector_adapter.protocol import (
ChatDetectionRequest,
ContentsDetectionRequest,
ContentsDetectionResponse,
ContextAnalysisRequest,
DetectionChatMessageParam,
DetectionResponse,
@@ -213,3 +215,89 @@ def test_context_analyze(llama_guard_detection):
)
assert type(response) == ErrorResponse
assert response.code == HTTPStatus.NOT_IMPLEMENTED


def test_post_process_content_splits_unsafe_categories(llama_guard_detection):
unsafe_message = "\n\nunsafe\nS2,S3"
responses = ChatCompletionResponse(
model="foo",
usage=UsageInfo(prompt_tokens=1, total_tokens=1),
choices=[
ChatCompletionResponseChoice(
index=1,
message=ChatMessage(
content=unsafe_message,
role=" assistant",
),
)
],
)
unsafe_score = 0.99
llama_guard_detection_instance = asyncio.run(llama_guard_detection)
# NOTE: we are testing a private function here
(
responses,
scores,
_,
) = llama_guard_detection_instance._LlamaGuard__post_process_result(
responses, [unsafe_score], "risk"
)
assert isinstance(responses, ChatCompletionResponse)
assert responses.choices[0].message.content == "unsafe"
assert scores[0] == unsafe_score
assert len(responses.choices) == 1


def test_post_process_content_works_for_safe(llama_guard_detection):
safe_message = "safe"
responses = ChatCompletionResponse(
model="foo",
usage=UsageInfo(prompt_tokens=1, total_tokens=1),
choices=[
ChatCompletionResponseChoice(
index=1,
message=ChatMessage(
content=safe_message,
role=" assistant",
),
)
],
)
safe_score = 0.99
llama_guard_detection_instance = asyncio.run(llama_guard_detection)
# NOTE: we are testing a private function here
(
responses,
scores,
_,
) = llama_guard_detection_instance._LlamaGuard__post_process_result(
responses, [safe_message], "risk"
)
assert isinstance(responses, ChatCompletionResponse)
assert len(responses.choices) == 1
assert responses.choices[0].message.content == "safe"
assert scores[0] == safe_score


def test_content_detection_with_llama_guard(
llama_guard_detection, llama_guard_completion_response
):
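"""content_analysis with a mocked chat completion should yield one detection list per content in the request."""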
llama_guard_detection_instance = asyncio.run(llama_guard_detection)
content_request = ContentsDetectionRequest(
contents=["Where do I find geese?", "You could go to Canada"]
)
with patch(
"vllm_detector_adapter.generative_detectors.llama_guard.LlamaGuard.create_chat_completion",
return_value=llama_guard_completion_response,
):
detection_response = asyncio.run(
llama_guard_detection_instance.content_analysis(content_request)
)
assert type(detection_response) == ContentsDetectionResponse
detections = detection_response.model_dump()
assert len(detections) == 2 # 2 contents in the request
assert len(detections[0]) == 2 # 2 choices
detection_0 = detections[0][0] # for 1st text in request
assert detection_0["detection"] == "safe"
assert detection_0["detection_type"] == "risk"
assert pytest.approx(detection_0["score"]) == 0.001346767
150 changes: 150 additions & 0 deletions tests/test_protocol.py
@@ -14,6 +14,8 @@
# Local
from vllm_detector_adapter.protocol import (
ChatDetectionRequest,
ContentsDetectionResponse,
ContentsDetectionResponseObject,
DetectionChatMessageParam,
DetectionResponse,
)
@@ -129,3 +131,151 @@ def test_response_from_completion_response_missing_content():
in detection_response.message
)
assert detection_response.code == HTTPStatus.BAD_REQUEST.value


def test_response_from_single_content_detection_response():
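"""A single-choice chat response should map to one detection covering the full content span."""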
choice = ChatCompletionResponseChoice(
index=0,
message=ChatMessage(
role="assistant",
content=" moose",
),
)
chat_response = ChatCompletionResponse(
model=MODEL_NAME,
choices=[choice],
usage=UsageInfo(prompt_tokens=136, total_tokens=140, completion_tokens=4),
)
contents = ["sample sentence"]
scores = [0.9]
detection_type = "risk"

expected_response = ContentsDetectionResponse(
root=[
[
ContentsDetectionResponseObject(
start=0,
end=len(contents[0]),
score=scores[0],
text=contents[0],
detection="moose",
detection_type=detection_type,
)
]
]
)
detection_response = ContentsDetectionResponse.from_chat_completion_response(
[(chat_response, scores, detection_type)], contents
)
assert isinstance(detection_response, ContentsDetectionResponse)
assert detection_response == expected_response


def test_response_from_multi_contents_detection_response():
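"""Each content's chat response should produce its own list of detection objects."""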
choice_content_0 = ChatCompletionResponseChoice(
index=0,
message=ChatMessage(
role="assistant",
content=" moose",
),
)
choice_content_1 = ChatCompletionResponseChoice(
index=0,
message=ChatMessage(
role="assistant",
content=" goose",
),
)
chat_response_0 = ChatCompletionResponse(
model=MODEL_NAME,
choices=[choice_content_0],
usage=UsageInfo(prompt_tokens=136, total_tokens=140, completion_tokens=4),
)
chat_response_1 = ChatCompletionResponse(
model=MODEL_NAME,
choices=[choice_content_1],
usage=UsageInfo(prompt_tokens=136, total_tokens=140, completion_tokens=4),
)

contents = ["sample sentence 1", "sample sentence 2"]
# Scores for each content are a list of scores (for multi-label)
scores = [[0.9], [0.6]]
detection_type = "risk"

content_response_0 = [
ContentsDetectionResponseObject(
start=0,
end=len(contents[0]),
score=scores[0][0],
text=contents[0],
detection="moose",
detection_type=detection_type,
)
]
content_response_1 = [
ContentsDetectionResponseObject(
start=0,
end=len(contents[1]),
score=scores[1][0],
text=contents[1],
detection="goose",
detection_type=detection_type,
)
]
expected_response = ContentsDetectionResponse(
root=[content_response_0, content_response_1]
)
detection_response = ContentsDetectionResponse.from_chat_completion_response(
[
(chat_response_0, scores[0], detection_type),
(chat_response_1, scores[1], detection_type),
],
contents,
)
assert isinstance(detection_response, ContentsDetectionResponse)
assert detection_response == expected_response


def test_response_from_single_content_detection_missing_content():
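"""A choice with no message content should produce an ErrorResponse with a BAD_REQUEST code."""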
choice_content_0 = ChatCompletionResponseChoice(
index=0,
message=ChatMessage(
role="assistant",
),
)
choice_content_1 = ChatCompletionResponseChoice(
index=0,
message=ChatMessage(
role="assistant",
content=" goose",
),
)
chat_response_0 = ChatCompletionResponse(
model=MODEL_NAME,
choices=[choice_content_0],
usage=UsageInfo(prompt_tokens=136, total_tokens=140, completion_tokens=4),
)
chat_response_1 = ChatCompletionResponse(
model=MODEL_NAME,
choices=[choice_content_1],
usage=UsageInfo(prompt_tokens=136, total_tokens=140, completion_tokens=4),
)

contents = ["sample sentence 1", "sample sentence 2"]
# Scores for each content are a list of scores (for multi-label)
scores = [[0.9], [0.6]]
detection_type = "risk"

detection_response = ContentsDetectionResponse.from_chat_completion_response(
[
(chat_response_0, scores[0], detection_type),
(chat_response_1, scores[1], detection_type),
],
contents,
)
assert type(detection_response) == ErrorResponse
assert (
"Choice 0 from chat completion does not have content"
in detection_response.message
)
assert detection_response.code == HTTPStatus.BAD_REQUEST.value