39 changes: 39 additions & 0 deletions tests/generative_detectors/test_granite_guardian.py
@@ -6,6 +6,7 @@
import asyncio

# Third Party
from jinja2.exceptions import TemplateError, UndefinedError
from vllm.config import MultiModalConfig
from vllm.entrypoints.openai.protocol import (
    ChatCompletionLogProb,
@@ -415,3 +416,41 @@ def test_chat_detection_errors_on_stream(granite_guardian_detection):
    assert type(detection_response) == ErrorResponse
    assert detection_response.code == HTTPStatus.BAD_REQUEST.value
    assert "streaming is not supported" in detection_response.message


def test_chat_detection_errors_on_jinja_template_error(granite_guardian_detection):
    granite_guardian_detection_instance = asyncio.run(granite_guardian_detection)
    chat_request = ChatDetectionRequest(
        messages=[
            DetectionChatMessageParam(role="user", content="How do I pick a lock?")
        ],
    )
    with patch(
        "vllm_detector_adapter.generative_detectors.granite_guardian.GraniteGuardian.create_chat_completion",
        side_effect=TemplateError(),
    ):
        detection_response = asyncio.run(
            granite_guardian_detection_instance.chat(chat_request)
        )
    assert type(detection_response) == ErrorResponse
    assert detection_response.code == HTTPStatus.BAD_REQUEST.value
    assert "Template error" in detection_response.message


def test_chat_detection_errors_on_undefined_jinja_error(granite_guardian_detection):
    granite_guardian_detection_instance = asyncio.run(granite_guardian_detection)
    chat_request = ChatDetectionRequest(
        messages=[
            DetectionChatMessageParam(role="user", content="How do I pick a lock?")
        ],
    )
    with patch(
        "vllm_detector_adapter.generative_detectors.granite_guardian.GraniteGuardian.create_chat_completion",
        side_effect=UndefinedError(),  # subclass of TemplateError
    ):
        detection_response = asyncio.run(
            granite_guardian_detection_instance.chat(chat_request)
        )
    assert type(detection_response) == ErrorResponse
    assert detection_response.code == HTTPStatus.BAD_REQUEST.value
    assert "Template error" in detection_response.message
22 changes: 19 additions & 3 deletions vllm_detector_adapter/generative_detectors/base.py
@@ -7,6 +7,7 @@

# Third Party
from fastapi import Request
from jinja2.exceptions import TemplateError
from vllm.entrypoints.openai.protocol import ChatCompletionResponse, ErrorResponse
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
import jinja2
@@ -157,9 +158,24 @@ async def process_chat_completion_with_scores(
logger.debug("Request to chat completion: %s", chat_completion_request)

# Call chat completion
chat_response = await self.create_chat_completion(
chat_completion_request, raw_request
)
try:
chat_response = await self.create_chat_completion(
chat_completion_request, raw_request
)
except TemplateError as e:
# Propagate template errors including those from raise_exception in the chat_template.
# UndefinedError, a subclass of TemplateError, can happen due to a variety of reasons -
# e.g. for Granite Guardian it is not limited but including the following
# for a particular risk definition: unexpected number of messages, unexpected
# ordering of messages, unexpected roles used for particular messages.
# Users _may_ be able to correct some of these errors by changing the input
# but the error message may not be directly user-comprehensible
chat_response = ErrorResponse(
message=e.message or "Template error",
type="BadRequestError",
code=HTTPStatus.BAD_REQUEST.value,
)

logger.debug("Raw chat completion response: %s", chat_response)
if isinstance(chat_response, ErrorResponse):
# Propagate chat completion errors directly
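
For reviewers, a small sketch (an illustration, not code from this PR) of why the single `except TemplateError` clause covers both cases, and why the `e.message or "Template error"` fallback matters: both tests construct the exception with no arguments, so `.message` is `None`:

```python
from jinja2.exceptions import TemplateError, UndefinedError

# UndefinedError subclasses TemplateError, so one except clause catches both
assert issubclass(UndefinedError, TemplateError)

# Exceptions raised without a message fall back to the generic string
for exc in (TemplateError(), UndefinedError()):
    print(exc.message or "Template error")  # -> Template error

# When the template supplies a message, it is surfaced to the user instead
print(UndefinedError("'messages' is undefined").message or "Template error")
```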