3 changes: 2 additions & 1 deletion pyproject.toml
@@ -13,7 +13,8 @@ classifiers = [
]

dependencies = [
"vllm>=0.7.0"
"vllm @ git+https://github.com/vllm-project/[email protected] ; sys_platform == 'darwin'",
"vllm>=0.7.0 ; sys_platform != 'darwin'",
]

[project.optional-dependencies]
181 changes: 174 additions & 7 deletions tests/generative_detectors/test_granite_guardian.py
@@ -11,6 +11,7 @@
ChatCompletionLogProb,
ChatCompletionLogProbs,
ChatCompletionLogProbsContent,
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatMessage,
@@ -25,14 +26,18 @@
from vllm_detector_adapter.generative_detectors.granite_guardian import GraniteGuardian
from vllm_detector_adapter.protocol import (
ChatDetectionRequest,
ChatDetectionResponse,
ContextAnalysisRequest,
DetectionChatMessageParam,
DetectionResponse,
)

MODEL_NAME = "ibm-granite/granite-guardian" # Example granite-guardian model
CHAT_TEMPLATE = "Dummy chat template for testing {}"
BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]

CONTENT = "Where do I find geese?"
CONTEXT_DOC = "Geese can be found in lakes, ponds, and rivers"


@dataclass
class MockTokenizer:
@@ -155,8 +160,8 @@ def granite_guardian_completion_response():
### Tests #####################################################################


def test_preprocess_with_detector_params(granite_guardian_detection):
llama_guard_detection_instance = asyncio.run(granite_guardian_detection)
def test_preprocess_chat_request_with_detector_params(granite_guardian_detection):
granite_guardian_detection_instance = asyncio.run(granite_guardian_detection)
# Make sure with addition of allowed params like risk_name and risk_definition,
# extra params do not get added to guardian_config
detector_params = {
@@ -172,7 +177,9 @@ def test_preprocess_with_detector_params(granite_guardian_detection):
],
detector_params=detector_params,
)
processed_request = llama_guard_detection_instance.preprocess(initial_request)
processed_request = granite_guardian_detection_instance.preprocess_chat_request(
initial_request
)
assert type(processed_request) == ChatDetectionRequest
# Processed request should not have these extra params
assert "risk_name" not in processed_request.detector_params
@@ -190,15 +197,175 @@ def test_preprocess_with_detector_params(granite_guardian_detection):
}


def test_request_to_chat_completion_request_prompt_analysis(granite_guardian_detection):
granite_guardian_detection_instance = asyncio.run(granite_guardian_detection)
context_request = ContextAnalysisRequest(
content=CONTENT,
context_type="docs",
context=[
"extra!",
CONTEXT_DOC,
], # additionally test that only last context is used
detector_params={
"n": 2,
"chat_template_kwargs": {
"guardian_config": {"risk_name": "context_relevance"}
},
},
)
chat_request = (
granite_guardian_detection_instance.request_to_chat_completion_request(
context_request, MODEL_NAME
)
)
assert type(chat_request) == ChatCompletionRequest
assert len(chat_request.messages) == 2
assert chat_request.messages[0]["role"] == "user"
assert chat_request.messages[0]["content"] == CONTENT
assert chat_request.messages[1]["role"] == "context"
assert chat_request.messages[1]["content"] == CONTEXT_DOC
assert chat_request.model == MODEL_NAME
    # detector_params
assert chat_request.n == 2
assert (
chat_request.chat_template_kwargs["guardian_config"]["risk_name"]
== "context_relevance"
)


def test_request_to_chat_completion_request_response_analysis(
granite_guardian_detection,
):
granite_guardian_detection_instance = asyncio.run(granite_guardian_detection)
context_request = ContextAnalysisRequest(
content=CONTENT,
context_type="docs",
context=[CONTEXT_DOC],
detector_params={
"n": 3,
"chat_template_kwargs": {"guardian_config": {"risk_name": "groundedness"}},
},
)
chat_request = (
granite_guardian_detection_instance.request_to_chat_completion_request(
context_request, MODEL_NAME
)
)
assert type(chat_request) == ChatCompletionRequest
assert chat_request.messages[0]["role"] == "context"
assert chat_request.messages[0]["content"] == CONTEXT_DOC
assert chat_request.messages[1]["role"] == "assistant"
assert chat_request.messages[1]["content"] == CONTENT
assert chat_request.model == MODEL_NAME
    # detector_params
assert chat_request.n == 3
assert (
chat_request.chat_template_kwargs["guardian_config"]["risk_name"]
== "groundedness"
)


def test_request_to_chat_completion_request_empty_kwargs(granite_guardian_detection):
granite_guardian_detection_instance = asyncio.run(granite_guardian_detection)
context_request = ContextAnalysisRequest(
content=CONTENT,
context_type="docs",
context=[CONTEXT_DOC],
detector_params={"n": 2, "chat_template_kwargs": {}}, # no guardian config
)
chat_request = (
granite_guardian_detection_instance.request_to_chat_completion_request(
context_request, MODEL_NAME
)
)
assert type(chat_request) == ErrorResponse
assert chat_request.code == HTTPStatus.BAD_REQUEST
assert "No risk_name for context analysis" in chat_request.message


def test_request_to_chat_completion_request_empty_guardian_config(
granite_guardian_detection,
):
granite_guardian_detection_instance = asyncio.run(granite_guardian_detection)
context_request = ContextAnalysisRequest(
content=CONTENT,
context_type="docs",
context=[CONTEXT_DOC],
detector_params={"n": 2, "chat_template_kwargs": {"guardian_config": {}}},
)
chat_request = (
granite_guardian_detection_instance.request_to_chat_completion_request(
context_request, MODEL_NAME
)
)
assert type(chat_request) == ErrorResponse
assert chat_request.code == HTTPStatus.BAD_REQUEST
assert "No risk_name for context analysis" in chat_request.message


def test_request_to_chat_completion_request_unsupported_risk_name(
granite_guardian_detection,
):
granite_guardian_detection_instance = asyncio.run(granite_guardian_detection)
context_request = ContextAnalysisRequest(
content=CONTENT,
context_type="docs",
context=[CONTEXT_DOC],
detector_params={
"n": 2,
"chat_template_kwargs": {"guardian_config": {"risk_name": "foo"}},
},
)
chat_request = (
granite_guardian_detection_instance.request_to_chat_completion_request(
context_request, MODEL_NAME
)
)
assert type(chat_request) == ErrorResponse
assert chat_request.code == HTTPStatus.BAD_REQUEST
assert (
"risk_name foo is not compatible with context analysis" in chat_request.message
)


def test_context_analyze(
granite_guardian_detection, granite_guardian_completion_response
):
granite_guardian_detection_instance = asyncio.run(granite_guardian_detection)
context_request = ContextAnalysisRequest(
content=CONTENT,
context_type="docs",
context=[CONTEXT_DOC],
detector_params={
"n": 2,
"chat_template_kwargs": {"guardian_config": {"risk_name": "groundedness"}},
},
)
with patch(
"vllm_detector_adapter.generative_detectors.granite_guardian.GraniteGuardian.create_chat_completion",
return_value=granite_guardian_completion_response,
):
detection_response = asyncio.run(
granite_guardian_detection_instance.context_analyze(context_request)
)
assert type(detection_response) == DetectionResponse
detections = detection_response.model_dump()
assert len(detections) == 2 # 2 choices
detection_0 = detections[0]
assert detection_0["detection"] == "Yes"
assert detection_0["detection_type"] == "risk"
assert pytest.approx(detection_0["score"]) == 1.0


# NOTE: currently these functions are basically just the base implementations,
# where safe/unsafe tokens are defined in the granite guardian class


def test_calculate_scores(
granite_guardian_detection, granite_guardian_completion_response
):
llama_guard_detection_instance = asyncio.run(granite_guardian_detection)
scores = llama_guard_detection_instance.calculate_scores(
granite_guardian_detection_instance = asyncio.run(granite_guardian_detection)
scores = granite_guardian_detection_instance.calculate_scores(
granite_guardian_completion_response
)
assert len(scores) == 2 # 2 choices
Expand All @@ -224,7 +391,7 @@ def test_chat_detection(
detection_response = asyncio.run(
granite_guardian_detection_instance.chat(chat_request)
)
assert type(detection_response) == ChatDetectionResponse
assert type(detection_response) == DetectionResponse
detections = detection_response.model_dump()
assert len(detections) == 2 # 2 choices
detection_0 = detections[0]
23 changes: 21 additions & 2 deletions tests/generative_detectors/test_llama_guard.py
@@ -14,6 +14,7 @@
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatMessage,
ErrorResponse,
UsageInfo,
)
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
@@ -24,8 +25,9 @@
from vllm_detector_adapter.generative_detectors.llama_guard import LlamaGuard
from vllm_detector_adapter.protocol import (
ChatDetectionRequest,
ChatDetectionResponse,
ContextAnalysisRequest,
DetectionChatMessageParam,
DetectionResponse,
)

MODEL_NAME = "meta-llama/Llama-Guard-3-8B" # Example llama guard model
@@ -187,10 +189,27 @@ def test_chat_detection(llama_guard_detection, llama_guard_completion_response):
detection_response = asyncio.run(
llama_guard_detection_instance.chat(chat_request)
)
assert type(detection_response) == ChatDetectionResponse
assert type(detection_response) == DetectionResponse
detections = detection_response.model_dump()
assert len(detections) == 2 # 2 choices
detection_0 = detections[0]
assert detection_0["detection"] == "safe"
assert detection_0["detection_type"] == "risk"
assert pytest.approx(detection_0["score"]) == 0.001346767


def test_context_analyze(llama_guard_detection):
llama_guard_detection_instance = asyncio.run(llama_guard_detection)
content = "Where do I find geese?"
context_doc = "Geese can be found in lakes, ponds, and rivers"
context_request = ContextAnalysisRequest(
content=content,
context_type="docs",
context=[context_doc],
detector_params={"n": 2, "temperature": 0.3},
)
response = asyncio.run(
llama_guard_detection_instance.context_analyze(context_request)
)
assert type(response) == ErrorResponse
assert response.code == HTTPStatus.NOT_IMPLEMENTED
17 changes: 11 additions & 6 deletions tests/test_protocol.py
@@ -14,16 +14,18 @@
# Local
from vllm_detector_adapter.protocol import (
ChatDetectionRequest,
ChatDetectionResponse,
DetectionChatMessageParam,
DetectionResponse,
)

MODEL_NAME = "org/model-name"

### Tests #####################################################################

#### Chat detection request tests

def test_detection_to_completion_request():

def test_chat_detection_to_completion_request():
chat_request = ChatDetectionRequest(
messages=[
DetectionChatMessageParam(
@@ -46,7 +48,7 @@ def test_detection_to_completion_request():
assert request.n == 3


def test_detection_to_completion_request_unknown_params():
def test_chat_detection_to_completion_request_unknown_params():
chat_request = ChatDetectionRequest(
messages=[
DetectionChatMessageParam(role="user", content="How do I search for moose?")
@@ -58,6 +60,9 @@ def test_detection_to_completion_request_unknown_params():
assert type(request) == ChatCompletionRequest


#### General response tests


def test_response_from_completion_response():
# Simplified response without logprobs since not needed for this method
choice_0 = ChatCompletionResponseChoice(
@@ -81,10 +86,10 @@ def test_response_from_completion_response():
)
scores = [0.3, 0.7]
detection_type = "type"
detection_response = ChatDetectionResponse.from_chat_completion_response(
detection_response = DetectionResponse.from_chat_completion_response(
response, scores, detection_type
)
assert type(detection_response) == ChatDetectionResponse
assert type(detection_response) == DetectionResponse
detections = detection_response.model_dump()
assert len(detections) == 2 # 2 choices
detection_0 = detections[0]
@@ -115,7 +120,7 @@ def test_response_from_completion_response_missing_content():
)
scores = [0.3, 0.7]
detection_type = "type"
detection_response = ChatDetectionResponse.from_chat_completion_response(
detection_response = DetectionResponse.from_chat_completion_response(
response, scores, detection_type
)
assert type(detection_response) == ErrorResponse