diff --git a/pyproject.toml b/pyproject.toml
index 8daf8b2..87d43fe 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -13,7 +13,8 @@ classifiers = [
 ]
 dependencies = [
-    "vllm>=0.7.0"
+    "vllm @ git+https://github.com/vllm-project/vllm.git@v0.7.0 ; sys_platform == 'darwin'",
+    "vllm>=0.7.0 ; sys_platform != 'darwin'",
 ]

 [project.optional-dependencies]
diff --git a/tests/generative_detectors/test_granite_guardian.py b/tests/generative_detectors/test_granite_guardian.py
index 0e6343b..e063c65 100644
--- a/tests/generative_detectors/test_granite_guardian.py
+++ b/tests/generative_detectors/test_granite_guardian.py
@@ -11,6 +11,7 @@
     ChatCompletionLogProb,
     ChatCompletionLogProbs,
     ChatCompletionLogProbsContent,
+    ChatCompletionRequest,
     ChatCompletionResponse,
     ChatCompletionResponseChoice,
     ChatMessage,
@@ -25,14 +26,18 @@
 from vllm_detector_adapter.generative_detectors.granite_guardian import GraniteGuardian
 from vllm_detector_adapter.protocol import (
     ChatDetectionRequest,
-    ChatDetectionResponse,
+    ContextAnalysisRequest,
     DetectionChatMessageParam,
+    DetectionResponse,
 )

 MODEL_NAME = "ibm-granite/granite-guardian"  # Example granite-guardian model
 CHAT_TEMPLATE = "Dummy chat template for testing {}"
 BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]
+CONTENT = "Where do I find geese?"
+CONTEXT_DOC = "Geese can be found in lakes, ponds, and rivers"
+

 @dataclass
 class MockTokenizer:
@@ -155,8 +160,8 @@ def granite_guardian_completion_response():
 ### Tests #####################################################################


-def test_preprocess_with_detector_params(granite_guardian_detection):
-    llama_guard_detection_instance = asyncio.run(granite_guardian_detection)
+def test_preprocess_chat_request_with_detector_params(granite_guardian_detection):
+    granite_guardian_detection_instance = asyncio.run(granite_guardian_detection)
     # Make sure with addition of allowed params like risk_name and risk_definition,
     # extra params do not get added to guardian_config
     detector_params = {
@@ -172,7 +177,9 @@ def test_preprocess_with_detector_params(granite_guardian_detection):
         ],
         detector_params=detector_params,
     )
-    processed_request = llama_guard_detection_instance.preprocess(initial_request)
+    processed_request = granite_guardian_detection_instance.preprocess_chat_request(
+        initial_request
+    )
     assert type(processed_request) == ChatDetectionRequest
     # Processed request should not have these extra params
     assert "risk_name" not in processed_request.detector_params
@@ -190,6 +197,166 @@ def test_preprocess_with_detector_params(granite_guardian_detection):
     }


+def test_request_to_chat_completion_request_prompt_analysis(granite_guardian_detection):
+    granite_guardian_detection_instance = asyncio.run(granite_guardian_detection)
+    context_request = ContextAnalysisRequest(
+        content=CONTENT,
+        context_type="docs",
+        context=[
+            "extra!",
+            CONTEXT_DOC,
+        ],  # additionally test that only last context is used
+        detector_params={
+            "n": 2,
+            "chat_template_kwargs": {
+                "guardian_config": {"risk_name": "context_relevance"}
+            },
+        },
+    )
+    chat_request = (
+        granite_guardian_detection_instance.request_to_chat_completion_request(
+            context_request, MODEL_NAME
+        )
+    )
+    assert type(chat_request) == ChatCompletionRequest
+    assert len(chat_request.messages) == 2
+    assert chat_request.messages[0]["role"] == "user"
+    assert chat_request.messages[0]["content"] == CONTENT
+    assert chat_request.messages[1]["role"] == "context"
+    assert chat_request.messages[1]["content"] == CONTEXT_DOC
+    assert chat_request.model == MODEL_NAME
+    # detector_params
+    assert chat_request.n == 2
+    assert (
+        chat_request.chat_template_kwargs["guardian_config"]["risk_name"]
+        == "context_relevance"
+    )
+
+
+def test_request_to_chat_completion_request_response_analysis(
+    granite_guardian_detection,
+):
+    granite_guardian_detection_instance = asyncio.run(granite_guardian_detection)
+    context_request = ContextAnalysisRequest(
+        content=CONTENT,
+        context_type="docs",
+        context=[CONTEXT_DOC],
+        detector_params={
+            "n": 3,
+            "chat_template_kwargs": {"guardian_config": {"risk_name": "groundedness"}},
+        },
+    )
+    chat_request = (
+        granite_guardian_detection_instance.request_to_chat_completion_request(
+            context_request, MODEL_NAME
+        )
+    )
+    assert type(chat_request) == ChatCompletionRequest
+    assert chat_request.messages[0]["role"] == "context"
+    assert chat_request.messages[0]["content"] == CONTEXT_DOC
+    assert chat_request.messages[1]["role"] == "assistant"
+    assert chat_request.messages[1]["content"] == CONTENT
+    assert chat_request.model == MODEL_NAME
+    # detector_params
+    assert chat_request.n == 3
+    assert (
+        chat_request.chat_template_kwargs["guardian_config"]["risk_name"]
+        == "groundedness"
+    )
+
+
+def test_request_to_chat_completion_request_empty_kwargs(granite_guardian_detection):
+    granite_guardian_detection_instance = asyncio.run(granite_guardian_detection)
+    context_request = ContextAnalysisRequest(
+        content=CONTENT,
+        context_type="docs",
+        context=[CONTEXT_DOC],
+        detector_params={"n": 2, "chat_template_kwargs": {}},  # no guardian config
+    )
+    chat_request = (
+        granite_guardian_detection_instance.request_to_chat_completion_request(
+            context_request, MODEL_NAME
+        )
+    )
+    assert type(chat_request) == ErrorResponse
+    assert chat_request.code == HTTPStatus.BAD_REQUEST
+    assert "No risk_name for context analysis" in chat_request.message
+
+
+def test_request_to_chat_completion_request_empty_guardian_config(
+    granite_guardian_detection,
+):
+    granite_guardian_detection_instance = asyncio.run(granite_guardian_detection)
+    context_request = ContextAnalysisRequest(
+        content=CONTENT,
+        context_type="docs",
+        context=[CONTEXT_DOC],
+        detector_params={"n": 2, "chat_template_kwargs": {"guardian_config": {}}},
+    )
+    chat_request = (
+        granite_guardian_detection_instance.request_to_chat_completion_request(
+            context_request, MODEL_NAME
+        )
+    )
+    assert type(chat_request) == ErrorResponse
+    assert chat_request.code == HTTPStatus.BAD_REQUEST
+    assert "No risk_name for context analysis" in chat_request.message
+
+
+def test_request_to_chat_completion_request_unsupported_risk_name(
+    granite_guardian_detection,
+):
+    granite_guardian_detection_instance = asyncio.run(granite_guardian_detection)
+    context_request = ContextAnalysisRequest(
+        content=CONTENT,
+        context_type="docs",
+        context=[CONTEXT_DOC],
+        detector_params={
+            "n": 2,
+            "chat_template_kwargs": {"guardian_config": {"risk_name": "foo"}},
+        },
+    )
+    chat_request = (
+        granite_guardian_detection_instance.request_to_chat_completion_request(
+            context_request, MODEL_NAME
+        )
+    )
+    assert type(chat_request) == ErrorResponse
+    assert chat_request.code == HTTPStatus.BAD_REQUEST
+    assert (
+        "risk_name foo is not compatible with context analysis" in chat_request.message
+    )
+
+
+def test_context_analyze(
+    granite_guardian_detection, granite_guardian_completion_response
+):
+    granite_guardian_detection_instance = asyncio.run(granite_guardian_detection)
+    context_request = ContextAnalysisRequest(
+        content=CONTENT,
+        context_type="docs",
+        context=[CONTEXT_DOC],
+        detector_params={
+            "n": 2,
+            "chat_template_kwargs": {"guardian_config": {"risk_name": "groundedness"}},
+        },
+    )
+    with patch(
+        "vllm_detector_adapter.generative_detectors.granite_guardian.GraniteGuardian.create_chat_completion",
+        return_value=granite_guardian_completion_response,
+    ):
+        detection_response = asyncio.run(
+            granite_guardian_detection_instance.context_analyze(context_request)
+        )
+    assert type(detection_response) == DetectionResponse
+    detections = detection_response.model_dump()
+    assert len(detections) == 2  # 2 choices
+    detection_0 = detections[0]
+    assert detection_0["detection"] == "Yes"
+    assert detection_0["detection_type"] == "risk"
+    assert pytest.approx(detection_0["score"]) == 1.0
+
+
 # NOTE: currently these functions are basically just the base implementations,
 # where safe/unsafe tokens are defined in the granite guardian class

@@ -197,8 +364,8 @@ def test_preprocess_with_detector_params(granite_guardian_detection):
 def test_calculate_scores(
     granite_guardian_detection, granite_guardian_completion_response
 ):
-    llama_guard_detection_instance = asyncio.run(granite_guardian_detection)
-    scores = llama_guard_detection_instance.calculate_scores(
+    granite_guardian_detection_instance = asyncio.run(granite_guardian_detection)
+    scores = granite_guardian_detection_instance.calculate_scores(
         granite_guardian_completion_response
     )
     assert len(scores) == 2  # 2 choices
@@ -224,7 +391,7 @@ def test_chat_detection(
         detection_response = asyncio.run(
             granite_guardian_detection_instance.chat(chat_request)
         )
-    assert type(detection_response) == ChatDetectionResponse
+    assert type(detection_response) == DetectionResponse
     detections = detection_response.model_dump()
     assert len(detections) == 2  # 2 choices
     detection_0 = detections[0]
diff --git a/tests/generative_detectors/test_llama_guard.py b/tests/generative_detectors/test_llama_guard.py
index 1ba98db..0d68366 100644
--- a/tests/generative_detectors/test_llama_guard.py
+++ b/tests/generative_detectors/test_llama_guard.py
@@ -14,6 +14,7 @@
     ChatCompletionResponse,
     ChatCompletionResponseChoice,
     ChatMessage,
+    ErrorResponse,
     UsageInfo,
 )
 from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
@@ -24,8 +25,9 @@
 from vllm_detector_adapter.generative_detectors.llama_guard import LlamaGuard
 from vllm_detector_adapter.protocol import (
     ChatDetectionRequest,
-    ChatDetectionResponse,
+    ContextAnalysisRequest,
     DetectionChatMessageParam,
+    DetectionResponse,
 )

 MODEL_NAME = "meta-llama/Llama-Guard-3-8B"  # Example llama guard model
@@ -187,10 +189,27 @@ def test_chat_detection(llama_guard_detection, llama_guard_completion_response):
         detection_response = asyncio.run(
             llama_guard_detection_instance.chat(chat_request)
         )
-    assert type(detection_response) == ChatDetectionResponse
+    assert type(detection_response) == DetectionResponse
     detections = detection_response.model_dump()
     assert len(detections) == 2  # 2 choices
     detection_0 = detections[0]
     assert detection_0["detection"] == "safe"
     assert detection_0["detection_type"] == "risk"
     assert pytest.approx(detection_0["score"]) == 0.001346767
+
+
+def test_context_analyze(llama_guard_detection):
+    llama_guard_detection_instance = asyncio.run(llama_guard_detection)
+    content = "Where do I find geese?"
+ context_doc = "Geese can be found in lakes, ponds, and rivers" + context_request = ContextAnalysisRequest( + content=content, + context_type="docs", + context=[context_doc], + detector_params={"n": 2, "temperature": 0.3}, + ) + response = asyncio.run( + llama_guard_detection_instance.context_analyze(context_request) + ) + assert type(response) == ErrorResponse + assert response.code == HTTPStatus.NOT_IMPLEMENTED diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 8480691..9178ae6 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -14,16 +14,18 @@ # Local from vllm_detector_adapter.protocol import ( ChatDetectionRequest, - ChatDetectionResponse, DetectionChatMessageParam, + DetectionResponse, ) MODEL_NAME = "org/model-name" ### Tests ##################################################################### +#### Chat detection request tests -def test_detection_to_completion_request(): + +def test_chat_detection_to_completion_request(): chat_request = ChatDetectionRequest( messages=[ DetectionChatMessageParam( @@ -46,7 +48,7 @@ def test_detection_to_completion_request(): assert request.n == 3 -def test_detection_to_completion_request_unknown_params(): +def test_chat_detection_to_completion_request_unknown_params(): chat_request = ChatDetectionRequest( messages=[ DetectionChatMessageParam(role="user", content="How do I search for moose?") @@ -58,6 +60,9 @@ def test_detection_to_completion_request_unknown_params(): assert type(request) == ChatCompletionRequest +#### General response tests + + def test_response_from_completion_response(): # Simplified response without logprobs since not needed for this method choice_0 = ChatCompletionResponseChoice( @@ -81,10 +86,10 @@ def test_response_from_completion_response(): ) scores = [0.3, 0.7] detection_type = "type" - detection_response = ChatDetectionResponse.from_chat_completion_response( + detection_response = DetectionResponse.from_chat_completion_response( response, scores, detection_type ) - assert type(detection_response) == ChatDetectionResponse + assert type(detection_response) == DetectionResponse detections = detection_response.model_dump() assert len(detections) == 2 # 2 choices detection_0 = detections[0] @@ -115,7 +120,7 @@ def test_response_from_completion_response_missing_content(): ) scores = [0.3, 0.7] detection_type = "type" - detection_response = ChatDetectionResponse.from_chat_completion_response( + detection_response = DetectionResponse.from_chat_completion_response( response, scores, detection_type ) assert type(detection_response) == ErrorResponse diff --git a/vllm_detector_adapter/api_server.py b/vllm_detector_adapter/api_server.py index 1153418..6489e05 100644 --- a/vllm_detector_adapter/api_server.py +++ b/vllm_detector_adapter/api_server.py @@ -25,7 +25,11 @@ # Local from vllm_detector_adapter import generative_detectors from vllm_detector_adapter.logging import init_logger -from vllm_detector_adapter.protocol import ChatDetectionRequest, ChatDetectionResponse +from vllm_detector_adapter.protocol import ( + ChatDetectionRequest, + ContextAnalysisRequest, + DetectionResponse, +) TIMEOUT_KEEP_ALIVE = 5 # seconds @@ -156,7 +160,29 @@ async def create_chat_detection(request: ChatDetectionRequest, raw_request: Requ content=detector_response.model_dump(), status_code=detector_response.code ) - elif isinstance(detector_response, ChatDetectionResponse): + elif isinstance(detector_response, DetectionResponse): + return JSONResponse(content=detector_response.model_dump()) + + return JSONResponse({}) + + 
+@router.post("/api/v1/text/context/doc") +async def create_context_doc_detection( + request: ContextAnalysisRequest, raw_request: Request +): + """Support context analysis endpoint""" + + detector_response = await chat_detection(raw_request).context_analyze( + request, raw_request + ) + + if isinstance(detector_response, ErrorResponse): + # ErrorResponse includes code and message, corresponding to errors for the detectorAPI + return JSONResponse( + content=detector_response.model_dump(), status_code=detector_response.code + ) + + elif isinstance(detector_response, DetectionResponse): return JSONResponse(content=detector_response.model_dump()) return JSONResponse({}) diff --git a/vllm_detector_adapter/generative_detectors/base.py b/vllm_detector_adapter/generative_detectors/base.py index dde25bf..60449f5 100644 --- a/vllm_detector_adapter/generative_detectors/base.py +++ b/vllm_detector_adapter/generative_detectors/base.py @@ -14,7 +14,11 @@ # Local from vllm_detector_adapter.logging import init_logger -from vllm_detector_adapter.protocol import ChatDetectionRequest, ChatDetectionResponse +from vllm_detector_adapter.protocol import ( + ChatDetectionRequest, + ContextAnalysisRequest, + DetectionResponse, +) logger = init_logger(__name__) @@ -32,6 +36,8 @@ def __init__(self, task_template: str, output_template: str, *args, **kwargs): self.output_template = self.load_template(output_template) + ##### Template functions ################################################### + def load_template(self, template_path: Optional[Union[Path, str]]) -> str: """Function to load template Note: this function currently is largely taken from the chat template method @@ -66,23 +72,27 @@ def load_template(self, template_path: Optional[Union[Path, str]]) -> str: logger.info("Using supplied template:\n%s", resolved_template) return self.jinja_env.from_string(resolved_template) - def apply_task_template( + def apply_output_template( + self, response: ChatCompletionResponse + ) -> Union[ChatCompletionResponse, ErrorResponse]: + """Apply output parsing template for the response""" + return response + + ##### Chat request processing functions #################################### + + def apply_task_template_to_chat( self, request: ChatDetectionRequest ) -> Union[ChatDetectionRequest, ErrorResponse]: - """Apply task template on the request""" + """Apply task template on the chat request""" return request - def preprocess( + def preprocess_chat_request( self, request: ChatDetectionRequest ) -> Union[ChatDetectionRequest, ErrorResponse]: - """Preprocess request""" + """Preprocess chat request""" return request - def apply_output_template( - self, response: ChatCompletionResponse - ) -> Union[ChatCompletionResponse, ErrorResponse]: - """Apply output parsing template for the response""" - return response + ##### General chat completion output processing functions ################## def calculate_scores(self, response: ChatCompletionResponse) -> List[float]: """Extract scores from logprobs of the raw chat response""" @@ -116,35 +126,9 @@ def calculate_scores(self, response: ChatCompletionResponse) -> List[float]: return choice_scores - ##### Detection methods #################################################### - # Base implementation of other detection endpoints like content can go here - - async def chat( - self, - request: ChatDetectionRequest, - raw_request: Optional[Request] = None, - ) -> Union[ChatDetectionResponse, ErrorResponse]: - """Function used to call chat detection and provide a /chat response""" - - # Fetch 
-        # Fetch model name from super class: OpenAIServing
-        model_name = self.models.base_model_paths[0].name
-
-        # Apply task template if it exists
-        if self.task_template:
-            request = self.apply_task_template(request)
-            if isinstance(request, ErrorResponse):
-                # Propagate any request problems that will not allow
-                # task template to be applied
-                return request
-
-        # Optionally make model-dependent adjustments for the request
-        request = self.preprocess(request)
-
-        chat_completion_request = request.to_chat_completion_request(model_name)
-        if isinstance(chat_completion_request, ErrorResponse):
-            # Propagate any request problems like extra unallowed parameters
-            return chat_completion_request
-
+    async def process_chat_completion_with_scores(
+        self, chat_completion_request, raw_request
+    ) -> Union[DetectionResponse, ErrorResponse]:
         # Return an error for streaming for now. Since the detector API is unary,
         # results would not be streamed back anyway. The chat completion response
         # object would look different, and content would have to be aggregated.
@@ -182,6 +166,53 @@ async def chat(
         # Calculate scores
         scores = self.calculate_scores(chat_response)

-        return ChatDetectionResponse.from_chat_completion_response(
+        return DetectionResponse.from_chat_completion_response(
             chat_response, scores, self.DETECTION_TYPE
         )
+
+    ##### Detection methods ####################################################
+    # Base implementation of other detection endpoints like content can go here
+
+    async def chat(
+        self,
+        request: ChatDetectionRequest,
+        raw_request: Optional[Request] = None,
+    ) -> Union[DetectionResponse, ErrorResponse]:
+        """Function used to call chat detection and provide a /chat response"""
+
+        # Fetch model name from super class: OpenAIServing
+        model_name = self.models.base_model_paths[0].name
+
+        # Apply task template if it exists
+        if self.task_template:
+            request = self.apply_task_template_to_chat(request)
+            if isinstance(request, ErrorResponse):
+                # Propagate any request problems that will not allow
+                # task template to be applied
+                return request
+
+        # Optionally make model-dependent adjustments for the request
+        request = self.preprocess_chat_request(request)
+
+        chat_completion_request = request.to_chat_completion_request(model_name)
+        if isinstance(chat_completion_request, ErrorResponse):
+            # Propagate any request problems
+            return chat_completion_request
+
+        return await self.process_chat_completion_with_scores(
+            chat_completion_request, raw_request
+        )
+
+    async def context_analyze(
+        self,
+        request: ContextAnalysisRequest,
+        raw_request: Optional[Request] = None,
+    ) -> Union[DetectionResponse, ErrorResponse]:
+        """Function used to call context analysis and provide a /context/doc response"""
+        # Return "not implemented" here since context analysis may not
+        # generally apply to all models at this time
+        return ErrorResponse(
+            message="context analysis is not supported for the detector",
+            type="NotImplementedError",
+            code=HTTPStatus.NOT_IMPLEMENTED.value,
+        )
diff --git a/vllm_detector_adapter/generative_detectors/granite_guardian.py b/vllm_detector_adapter/generative_detectors/granite_guardian.py
index 8ad8c06..6fa6502 100644
--- a/vllm_detector_adapter/generative_detectors/granite_guardian.py
+++ b/vllm_detector_adapter/generative_detectors/granite_guardian.py
@@ -1,13 +1,20 @@
 # Standard
-from typing import Union
+from http import HTTPStatus
+from typing import Optional, Union

 # Third Party
-from vllm.entrypoints.openai.protocol import ErrorResponse
+from fastapi import Request
+from pydantic import ValidationError
+from vllm.entrypoints.openai.protocol import ChatCompletionRequest, ErrorResponse

 # Local
 from vllm_detector_adapter.generative_detectors.base import ChatCompletionDetectionBase
 from vllm_detector_adapter.logging import init_logger
-from vllm_detector_adapter.protocol import ChatDetectionRequest
+from vllm_detector_adapter.protocol import (
+    ChatDetectionRequest,
+    ContextAnalysisRequest,
+    DetectionResponse,
+)

 logger = init_logger(__name__)

@@ -22,9 +29,13 @@ class GraniteGuardian(ChatCompletionDetectionBase):
     SAFE_TOKEN = "No"
     UNSAFE_TOKEN = "Yes"

+    # Risks associated with context analysis
+    PROMPT_CONTEXT_ANALYSIS_RISKS = ["context_relevance"]
+    RESPONSE_CONTEXT_ANALYSIS_RISKS = ["groundedness"]
+
     def preprocess(
-        self, request: ChatDetectionRequest
-    ) -> Union[ChatDetectionRequest, ErrorResponse]:
+        self, request: Union[ChatDetectionRequest, ContextAnalysisRequest]
+    ) -> Union[ChatDetectionRequest, ContextAnalysisRequest, ErrorResponse]:
         """Granite guardian specific parameter updates for risk name and risk definition"""
         # Validation that one of the 'defined' risks is requested will be
         # done through the chat template on each request. Errors will
@@ -47,3 +58,119 @@ def preprocess(
         }

         return request
+
+    def preprocess_chat_request(
+        self, request: ChatDetectionRequest
+    ) -> Union[ChatDetectionRequest, ErrorResponse]:
+        """Granite guardian chat request preprocess is just detector parameter updates"""
+        return self.preprocess(request)
+
+    def request_to_chat_completion_request(
+        self, request: ContextAnalysisRequest, model_name: str
+    ) -> Union[ChatCompletionRequest, ErrorResponse]:
+        NO_RISK_NAME_MESSAGE = "No risk_name for context analysis"
+
+        risk_name = None
+        if (
+            "chat_template_kwargs" not in request.detector_params
+            or "guardian_config" not in request.detector_params["chat_template_kwargs"]
+        ):
+            return ErrorResponse(
+                message=NO_RISK_NAME_MESSAGE,
+                type="BadRequestError",
+                code=HTTPStatus.BAD_REQUEST.value,
+            )
+        # Use risk name to determine message format
+        if guardian_config := request.detector_params["chat_template_kwargs"][
+            "guardian_config"
+        ]:
+            risk_name = guardian_config["risk_name"]
+        else:
+            # Leaving off risk name can lead to model/template errors
+            return ErrorResponse(
+                message=NO_RISK_NAME_MESSAGE,
+                type="BadRequestError",
+                code=HTTPStatus.BAD_REQUEST.value,
+            )
+
+        if len(request.context) > 1:
+            # The detector API for context docs detection supports more than one context text
+            # but currently chat completions will only take one context.
+            logger.warning("More than one context provided. Only the last will be used")
+        context_text = request.context[-1]
+        content = request.content
+        # The "context" role is not an officially supported OpenAI role, so this is specific
+        # to Granite Guardian. Messages must also be in precise ordering, or model/template
+        # errors may occur.
+        if risk_name in self.RESPONSE_CONTEXT_ANALYSIS_RISKS:
+            # Response analysis
+            messages = [
+                {"role": "context", "content": context_text},
+                {"role": "assistant", "content": content},
+            ]
+        elif risk_name in self.PROMPT_CONTEXT_ANALYSIS_RISKS:
+            # Prompt analysis
+            messages = [
+                {"role": "user", "content": content},
+                {"role": "context", "content": context_text},
+            ]
+        else:
+            # Return error if risk names are not expected ones
+            return ErrorResponse(
+                message="risk_name {} is not compatible with context analysis".format(
+                    risk_name
+                ),
+                type="BadRequestError",
+                code=HTTPStatus.BAD_REQUEST.value,
+            )
+
+        # Try to pass all detector_params through as additional parameters to chat completions
+        # without additional validation or parameter changes, similar to ChatDetectionRequest processing
+        try:
+            return ChatCompletionRequest(
+                messages=messages,
+                model=model_name,
+                **request.detector_params,
+            )
+        except ValidationError as e:
+            return ErrorResponse(
+                message=repr(e.errors()[0]),
+                type="BadRequestError",
+                code=HTTPStatus.BAD_REQUEST.value,
+            )
+
+    async def context_analyze(
+        self,
+        request: ContextAnalysisRequest,
+        raw_request: Optional[Request] = None,
+    ) -> Union[DetectionResponse, ErrorResponse]:
+        """Function used to call context analysis and provide a /context/doc response"""
+        # Fetch model name from super class: OpenAIServing
+        model_name = self.models.base_model_paths[0].name
+
+        # Apply task template if it exists
+        if self.task_template:
+            request = self.apply_task_template_to_chat(request)
+            if isinstance(request, ErrorResponse):
+                # Propagate any request problems that will not allow
+                # task template to be applied
+                return request
+
+        # Make model-dependent adjustments for the request
+        request = self.preprocess(request)
+
+        # Since particular chat messages are dependent on Granite Guardian risk definitions,
+        # the processing is done here rather than in a separate, general to_chat_completion_request
+        # for all context analysis requests.
+        chat_completion_request = self.request_to_chat_completion_request(
+            request, model_name
+        )
+        if isinstance(chat_completion_request, ErrorResponse):
+            # Propagate any request problems
+            return chat_completion_request
+
+        # Calling chat completion and processing of scores is currently
+        # the same as for the /chat case
+        return await self.process_chat_completion_with_scores(
+            chat_completion_request, raw_request
+        )
diff --git a/vllm_detector_adapter/protocol.py b/vllm_detector_adapter/protocol.py
index 2c6bb92..dbb054f 100644
--- a/vllm_detector_adapter/protocol.py
+++ b/vllm_detector_adapter/protocol.py
@@ -11,10 +11,10 @@
     ErrorResponse,
 )

-##### [FMS] Detection API types
-# NOTE: This currently works with the /chat detection endpoint
+##### [FMS] Detection API types #####
+# Endpoints are as documented at https://foundation-model-stack.github.io/fms-guardrails-orchestrator/?urls.primaryName=Detector+API#/

-######## Contents Detection types
+######## Contents Detection types (currently unused) for the /text/contents detection endpoint


 class ContentsDetectionRequest(BaseModel):
@@ -35,7 +35,7 @@ class ContentsDetectionResponseObject(BaseModel):
     score: float = Field(examples=[0.5])


-######## Chat Detection types
+##### Chat Detection types for the /text/chat detection endpoint ###############


 class DetectionChatMessageParam(TypedDict):
@@ -85,16 +85,43 @@ def to_chat_completion_request(self, model_name: str):
         )


-class ChatDetectionResponseObject(BaseModel):
+##### Context Analysis Detection types for the /text/context/doc detection endpoint
+
+
+class ContextAnalysisRequest(BaseModel):
+    # Content to run detection on
+    content: str = Field(examples=["What is a moose?"])
+    # Type of context - url or docs (for text documents)
+    context_type: str = Field(examples=["docs", "url"])
+    # Context of type context_type to run detection on
+    context: List[str] = Field(
+        examples=[
+            "https://en.wikipedia.org/wiki/Moose",
+            "https://www.nwf.org/Educational-Resources/Wildlife-Guide/Mammals/Moose",
+        ]
+    )
+    # Parameters to pass through to chat completions, optional
+    detector_params: Optional[Dict] = {}
+
+    # NOTE: currently there is no general to_chat_completion_request
+    # since the chat completion roles and messages are fairly tied
+    # to particular models' risk definitions. If a general strategy
+    # is identified, it can be implemented here.
+
+
+##### General detection response objects #######################################
+
+
+class DetectionResponseObject(BaseModel):
     detection: str = Field(examples=["positive"])
     detection_type: str = Field(examples=["simple_example"])
     score: float = Field(examples=[0.5])


-class ChatDetectionResponse(RootModel):
+class DetectionResponse(RootModel):
     # The root attribute is used here so that the response will appear
     # as a list instead of a list nested under a key
-    root: List[ChatDetectionResponseObject]
+    root: List[DetectionResponseObject]

     @staticmethod
     def from_chat_completion_response(
@@ -105,7 +132,7 @@ def from_chat_completion_response(
         for i, choice in enumerate(response.choices):
             content = choice.message.content
             if content and isinstance(content, str):
-                response_object = ChatDetectionResponseObject(
+                response_object = DetectionResponseObject(
                     detection_type=detection_type,
                     detection=content.strip(),
                     score=scores[i],
@@ -123,4 +150,4 @@ def from_chat_completion_response(
                 code=HTTPStatus.BAD_REQUEST.value,
             )

-        return ChatDetectionResponse(root=detection_responses)
+        return DetectionResponse(root=detection_responses)
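
A usage sketch, not part of the patch: with the adapter from this change serving a Granite Guardian model, the new /api/v1/text/context/doc endpoint could be exercised roughly as below. The host/port and deployment are assumptions for illustration; the payload shape follows ContextAnalysisRequest, and for Granite Guardian the guardian_config risk_name in chat_template_kwargs must be one of the context-analysis risks ("groundedness" or "context_relevance").

# Hypothetical client call against a locally running adapter (assumed port 8000)
import requests

payload = {
    "content": "Where do I find geese?",
    "context_type": "docs",
    "context": ["Geese can be found in lakes, ponds, and rivers"],
    "detector_params": {
        "n": 2,  # number of chat completion choices, one detection per choice
        "chat_template_kwargs": {"guardian_config": {"risk_name": "groundedness"}},
    },
}
resp = requests.post(
    "http://localhost:8000/api/v1/text/context/doc",  # assumed local deployment
    json=payload,
    timeout=60,
)
resp.raise_for_status()
# On success the body is a JSON list of objects with "detection",
# "detection_type", and "score" (see DetectionResponse); errors come back
# in the ErrorResponse shape with a code and message.
print(resp.json())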