diff --git a/tests/generative_detectors/test_granite_guardian.py b/tests/generative_detectors/test_granite_guardian.py
index e753e38..092bc5c 100644
--- a/tests/generative_detectors/test_granite_guardian.py
+++ b/tests/generative_detectors/test_granite_guardian.py
@@ -6,6 +6,7 @@ import asyncio
 
 # Third Party
+from jinja2.exceptions import TemplateError, UndefinedError
 from vllm.config import MultiModalConfig
 from vllm.entrypoints.openai.protocol import (
     ChatCompletionLogProb,
@@ -415,3 +416,41 @@ def test_chat_detection_errors_on_stream(granite_guardian_detection):
     assert type(detection_response) == ErrorResponse
     assert detection_response.code == HTTPStatus.BAD_REQUEST.value
     assert "streaming is not supported" in detection_response.message
+
+
+def test_chat_detection_errors_on_jinja_template_error(granite_guardian_detection):
+    granite_guardian_detection_instance = asyncio.run(granite_guardian_detection)
+    chat_request = ChatDetectionRequest(
+        messages=[
+            DetectionChatMessageParam(role="user", content="How do I pick a lock?")
+        ],
+    )
+    with patch(
+        "vllm_detector_adapter.generative_detectors.granite_guardian.GraniteGuardian.create_chat_completion",
+        side_effect=TemplateError(),
+    ):
+        detection_response = asyncio.run(
+            granite_guardian_detection_instance.chat(chat_request)
+        )
+    assert type(detection_response) == ErrorResponse
+    assert detection_response.code == HTTPStatus.BAD_REQUEST.value
+    assert "Template error" in detection_response.message
+
+
+def test_chat_detection_errors_on_undefined_jinja_error(granite_guardian_detection):
+    granite_guardian_detection_instance = asyncio.run(granite_guardian_detection)
+    chat_request = ChatDetectionRequest(
+        messages=[
+            DetectionChatMessageParam(role="user", content="How do I pick a lock?")
+        ],
+    )
+    with patch(
+        "vllm_detector_adapter.generative_detectors.granite_guardian.GraniteGuardian.create_chat_completion",
+        side_effect=UndefinedError(),  # subclass of TemplateError
+    ):
+        detection_response = asyncio.run(
+            granite_guardian_detection_instance.chat(chat_request)
+        )
+    assert type(detection_response) == ErrorResponse
+    assert detection_response.code == HTTPStatus.BAD_REQUEST.value
+    assert "Template error" in detection_response.message
diff --git a/vllm_detector_adapter/generative_detectors/base.py b/vllm_detector_adapter/generative_detectors/base.py
index 9a60717..ed6ef21 100644
--- a/vllm_detector_adapter/generative_detectors/base.py
+++ b/vllm_detector_adapter/generative_detectors/base.py
@@ -7,6 +7,7 @@
 
 # Third Party
 from fastapi import Request
+from jinja2.exceptions import TemplateError
 from vllm.entrypoints.openai.protocol import ChatCompletionResponse, ErrorResponse
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
 import jinja2
@@ -157,9 +158,24 @@ async def process_chat_completion_with_scores(
         logger.debug("Request to chat completion: %s", chat_completion_request)
 
         # Call chat completion
-        chat_response = await self.create_chat_completion(
-            chat_completion_request, raw_request
-        )
+        try:
+            chat_response = await self.create_chat_completion(
+                chat_completion_request, raw_request
+            )
+        except TemplateError as e:
+            # Propagate template errors, including those raised via raise_exception in
+            # the chat_template. UndefinedError, a subclass of TemplateError, can occur
+            # for various reasons; for Granite Guardian these include, but are not
+            # limited to, the following for a particular risk definition: an unexpected
+            # number of messages, unexpected message ordering, or unexpected roles for
+            # particular messages. Users _may_ be able to correct some of these errors
+            # by changing the input, but the error message may not be user-comprehensible.
+            chat_response = ErrorResponse(
+                message=e.message or "Template error",
+                type="BadRequestError",
+                code=HTTPStatus.BAD_REQUEST.value,
+            )
+
         logger.debug("Raw chat completion response: %s", chat_response)
         if isinstance(chat_response, ErrorResponse):
             # Propagate chat completion errors directly
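For context, a minimal sketch (not part of the patch) of how the two exception types handled above can surface while rendering a chat template. The raise_exception helper mirrors the transformers-style convention for chat templates, and the template snippets are hypothetical, not the actual Granite Guardian chat_template:

# Illustrative sketch; assumes a transformers-style raise_exception helper
# and hypothetical template snippets.
from jinja2 import Environment, StrictUndefined
from jinja2.exceptions import TemplateError, UndefinedError


def raise_exception(message):
    # Convention used by HF-style chat templates to signal bad input
    raise TemplateError(message)


env = Environment(undefined=StrictUndefined)
env.globals["raise_exception"] = raise_exception

# TemplateError raised explicitly by the template, e.g. on an unexpected role
guard = env.from_string(
    "{% if messages[0].role != 'user' %}"
    "{{ raise_exception('First message must come from the user') }}"
    "{% endif %}"
)
try:
    guard.render(messages=[{"role": "assistant", "content": "hi"}])
except TemplateError as e:
    print(e.message)  # -> First message must come from the user

# UndefinedError (subclass of TemplateError) from referencing an undefined name
try:
    env.from_string("{{ guardian_config.risk_name }}").render()
except UndefinedError as e:
    print(e.message)  # -> 'guardian_config' is undefined

In both cases e.message carries the template's own explanation, which is why the handler above falls back to "Template error" only when the message is empty.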