From 04eb73363f8ca3950cf689228ca505e0b4aab2c9 Mon Sep 17 00:00:00 2001 From: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> Date: Fri, 24 Jan 2025 15:08:23 -0700 Subject: [PATCH 01/12] :sparkles::white_check_mark: Context analysis request and test Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> --- tests/test_protocol.py | 30 +++++++++++++++-- vllm_detector_adapter/protocol.py | 53 +++++++++++++++++++++++++++++-- 2 files changed, 78 insertions(+), 5 deletions(-) diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 8480691..3065fbc 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -15,6 +15,7 @@ from vllm_detector_adapter.protocol import ( ChatDetectionRequest, ChatDetectionResponse, + ContextAnalysisRequest, DetectionChatMessageParam, ) @@ -22,8 +23,9 @@ ### Tests ##################################################################### +#### Chat detection request tests -def test_detection_to_completion_request(): +def test_chat_detection_to_completion_request(): chat_request = ChatDetectionRequest( messages=[ DetectionChatMessageParam( @@ -46,7 +48,7 @@ def test_detection_to_completion_request(): assert request.n == 3 -def test_detection_to_completion_request_unknown_params(): +def test_chat_detection_to_completion_request_unknown_params(): chat_request = ChatDetectionRequest( messages=[ DetectionChatMessageParam(role="user", content="How do I search for moose?") @@ -58,6 +60,30 @@ def test_detection_to_completion_request_unknown_params(): assert type(request) == ChatCompletionRequest +#### Context analysis detection request tests + +def test_context_detection_to_completion_request(): + content = "Where do I find geese?" + context_doc = "Geese can be found in lakes, ponds, and rivers" + context_request = ContextAnalysisRequest( + content=content, + context_type="docs", + context=[context_doc], + detector_params={"n": 2, "temperature": 0.3}, + ) + request = context_request.to_chat_completion_request(MODEL_NAME) + assert type(request) == ChatCompletionRequest + assert request.messages[0]["role"] == "context" + assert request.messages[0]["content"] == context_doc + assert request.messages[1]["role"] == "assistant" + assert request.messages[1]["content"] == content + assert request.model == MODEL_NAME + assert request.temperature == 0.3 + assert request.n == 2 + + +#### General response tests + def test_response_from_completion_response(): # Simplified response without logprobs since not needed for this method choice_0 = ChatCompletionResponseChoice( diff --git a/vllm_detector_adapter/protocol.py b/vllm_detector_adapter/protocol.py index 2c6bb92..20b5221 100644 --- a/vllm_detector_adapter/protocol.py +++ b/vllm_detector_adapter/protocol.py @@ -12,9 +12,9 @@ ) ##### [FMS] Detection API types -# NOTE: This currently works with the /chat detection endpoint +# Endpoints are as documented https://foundation-model-stack.github.io/fms-guardrails-orchestrator/?urls.primaryName=Detector+API#/ -######## Contents Detection types +######## Contents Detection types (currently unused) for the /text/contents detection endpoint class ContentsDetectionRequest(BaseModel): @@ -35,7 +35,7 @@ class ContentsDetectionResponseObject(BaseModel): score: float = Field(examples=[0.5]) -######## Chat Detection types +######## Chat Detection types for the /text/chat detection endpoint class DetectionChatMessageParam(TypedDict): @@ -85,6 +85,53 @@ def to_chat_completion_request(self, model_name: str): ) +######## Context Analysis Detection types for the 
/text/context/docs detection endpoint + + +class ContextAnalysisRequest(BaseModel): + # Content to run detection on + content: str = Field(examples=["What is a moose?"]) + # Type of context - url or docs (for text documents) + context_type: str = Field(examples=["docs", "url"]) + # Context of type context_type to run detection on + context: List[str] = Field( + examples=[ + "https://en.wikipedia.org/wiki/Moose", + "https://www.nwf.org/Educational-Resources/Wildlife-Guide/Mammals/Moose", + ] + ) + # Parameters to pass through to chat completions, optional + detector_params: Optional[Dict] = {} + + def to_chat_completion_request(self, model_name: str): + """Function to convert context analysis request to openai chat completion request""" + # Can only process one context currently - TODO: validate + # For now, context_type is ignored but is required for the detection endpoint + # TODO: 'context' is not a generally supported 'role' in the openAI API + messages = [ + {"role": "context", "content": self.context[0]}, + {"role": "assistant", "content": self.content}, + ] + + # Try to pass all detector_params through as additional parameters to chat completions + # without additional validation or parameter changes as in ChatDetectionRequest above + try: + return ChatCompletionRequest( + messages=messages, + model=model_name, + **self.detector_params, + ) + except ValidationError as e: + return ErrorResponse( + message=repr(e.errors()[0]), + type="BadRequestError", + code=HTTPStatus.BAD_REQUEST.value, + ) + + +######## General modified response(s) for chat completions + + class ChatDetectionResponseObject(BaseModel): detection: str = Field(examples=["positive"]) detection_type: str = Field(examples=["simple_example"]) From 7304e897ff642b3da93851858943de1be5a205c5 Mon Sep 17 00:00:00 2001 From: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> Date: Mon, 27 Jan 2025 08:57:07 -0700 Subject: [PATCH 02/12] :sparkles: Add context doc endpoint Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> --- tests/test_protocol.py | 3 ++ vllm_detector_adapter/api_server.py | 28 ++++++++++++++++++- .../generative_detectors/base.py | 14 +++++++++- 3 files changed, 43 insertions(+), 2 deletions(-) diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 3065fbc..c05335c 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -25,6 +25,7 @@ #### Chat detection request tests + def test_chat_detection_to_completion_request(): chat_request = ChatDetectionRequest( messages=[ @@ -62,6 +63,7 @@ def test_chat_detection_to_completion_request_unknown_params(): #### Context analysis detection request tests + def test_context_detection_to_completion_request(): content = "Where do I find geese?" 
context_doc = "Geese can be found in lakes, ponds, and rivers" @@ -84,6 +86,7 @@ def test_context_detection_to_completion_request(): #### General response tests + def test_response_from_completion_response(): # Simplified response without logprobs since not needed for this method choice_0 = ChatCompletionResponseChoice( diff --git a/vllm_detector_adapter/api_server.py b/vllm_detector_adapter/api_server.py index 1153418..7f1f7a1 100644 --- a/vllm_detector_adapter/api_server.py +++ b/vllm_detector_adapter/api_server.py @@ -25,7 +25,11 @@ # Local from vllm_detector_adapter import generative_detectors from vllm_detector_adapter.logging import init_logger -from vllm_detector_adapter.protocol import ChatDetectionRequest, ChatDetectionResponse +from vllm_detector_adapter.protocol import ( + ChatDetectionRequest, + ChatDetectionResponse, + ContextAnalysisRequest, +) TIMEOUT_KEEP_ALIVE = 5 # seconds @@ -162,6 +166,28 @@ async def create_chat_detection(request: ChatDetectionRequest, raw_request: Requ return JSONResponse({}) +@router.post("/api/v1/text/context/doc") +async def create_context_doc_detection( + request: ContextAnalysisRequest, raw_request: Request +): + """Support context analysis endpoint""" + + detector_response = await chat_detection(raw_request).context_analyze( + request, raw_request + ) + + if isinstance(detector_response, ErrorResponse): + # ErrorResponse includes code and message, corresponding to errors for the detectorAPI + return JSONResponse( + content=detector_response.model_dump(), status_code=detector_response.code + ) + + elif isinstance(detector_response, ChatDetectionResponse): + return JSONResponse(content=detector_response.model_dump()) + + return JSONResponse({}) + + def add_chat_detection_params(parser): parser.add_argument( "--task-template", diff --git a/vllm_detector_adapter/generative_detectors/base.py b/vllm_detector_adapter/generative_detectors/base.py index dde25bf..d45d393 100644 --- a/vllm_detector_adapter/generative_detectors/base.py +++ b/vllm_detector_adapter/generative_detectors/base.py @@ -14,7 +14,11 @@ # Local from vllm_detector_adapter.logging import init_logger -from vllm_detector_adapter.protocol import ChatDetectionRequest, ChatDetectionResponse +from vllm_detector_adapter.protocol import ( + ChatDetectionRequest, + ChatDetectionResponse, + ContextAnalysisRequest, +) logger = init_logger(__name__) @@ -185,3 +189,11 @@ async def chat( return ChatDetectionResponse.from_chat_completion_response( chat_response, scores, self.DETECTION_TYPE ) + + async def context_analyze( + self, + request: ContextAnalysisRequest, + raw_request: Optional[Request] = None, + ) -> Union[ChatDetectionResponse, ErrorResponse]: + """Function used to call chat detection and provide a /context/doc response""" + pass From f4563c8a48c50016c101c9ffefdc6f7d5d45c73e Mon Sep 17 00:00:00 2001 From: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> Date: Tue, 28 Jan 2025 12:59:14 -0700 Subject: [PATCH 03/12] :recycle: Update chat functions Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> --- .../test_granite_guardian.py | 4 +- tests/test_protocol.py | 42 ++++++------- .../generative_detectors/base.py | 22 ++++--- .../generative_detectors/granite_guardian.py | 62 +++++++++++++++++-- vllm_detector_adapter/protocol.py | 25 -------- 5 files changed, 96 insertions(+), 59 deletions(-) diff --git a/tests/generative_detectors/test_granite_guardian.py b/tests/generative_detectors/test_granite_guardian.py index 0e6343b..c26db89 100644 --- 
a/tests/generative_detectors/test_granite_guardian.py +++ b/tests/generative_detectors/test_granite_guardian.py @@ -172,7 +172,9 @@ def test_preprocess_with_detector_params(granite_guardian_detection): ], detector_params=detector_params, ) - processed_request = llama_guard_detection_instance.preprocess(initial_request) + processed_request = llama_guard_detection_instance.preprocess_chat_request( + initial_request + ) assert type(processed_request) == ChatDetectionRequest # Processed request should not have these extra params assert "risk_name" not in processed_request.detector_params diff --git a/tests/test_protocol.py b/tests/test_protocol.py index c05335c..e944fe4 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -61,27 +61,27 @@ def test_chat_detection_to_completion_request_unknown_params(): assert type(request) == ChatCompletionRequest -#### Context analysis detection request tests - - -def test_context_detection_to_completion_request(): - content = "Where do I find geese?" - context_doc = "Geese can be found in lakes, ponds, and rivers" - context_request = ContextAnalysisRequest( - content=content, - context_type="docs", - context=[context_doc], - detector_params={"n": 2, "temperature": 0.3}, - ) - request = context_request.to_chat_completion_request(MODEL_NAME) - assert type(request) == ChatCompletionRequest - assert request.messages[0]["role"] == "context" - assert request.messages[0]["content"] == context_doc - assert request.messages[1]["role"] == "assistant" - assert request.messages[1]["content"] == content - assert request.model == MODEL_NAME - assert request.temperature == 0.3 - assert request.n == 2 +# #### Context analysis detection request tests + + +# def test_context_detection_to_completion_request(): +# content = "Where do I find geese?" 
+# context_doc = "Geese can be found in lakes, ponds, and rivers" +# context_request = ContextAnalysisRequest( +# content=content, +# context_type="docs", +# context=[context_doc], +# detector_params={"n": 2, "temperature": 0.3}, +# ) +# request = context_request.to_chat_completion_request(MODEL_NAME) +# assert type(request) == ChatCompletionRequest +# assert request.messages[0]["role"] == "context" +# assert request.messages[0]["content"] == context_doc +# assert request.messages[1]["role"] == "assistant" +# assert request.messages[1]["content"] == content +# assert request.model == MODEL_NAME +# assert request.temperature == 0.3 +# assert request.n == 2 #### General response tests diff --git a/vllm_detector_adapter/generative_detectors/base.py b/vllm_detector_adapter/generative_detectors/base.py index d45d393..ea68aff 100644 --- a/vllm_detector_adapter/generative_detectors/base.py +++ b/vllm_detector_adapter/generative_detectors/base.py @@ -70,16 +70,16 @@ def load_template(self, template_path: Optional[Union[Path, str]]) -> str: logger.info("Using supplied template:\n%s", resolved_template) return self.jinja_env.from_string(resolved_template) - def apply_task_template( + def apply_task_template_to_chat( self, request: ChatDetectionRequest ) -> Union[ChatDetectionRequest, ErrorResponse]: - """Apply task template on the request""" + """Apply task template on the chat request""" return request - def preprocess( + def preprocess_chat_request( self, request: ChatDetectionRequest ) -> Union[ChatDetectionRequest, ErrorResponse]: - """Preprocess request""" + """Preprocess chat request""" return request def apply_output_template( @@ -135,18 +135,18 @@ async def chat( # Apply task template if it exists if self.task_template: - request = self.apply_task_template(request) + request = self.apply_task_template_to_chat(request) if isinstance(request, ErrorResponse): # Propagate any request problems that will not allow # task template to be applied return request # Optionally make model-dependent adjustments for the request - request = self.preprocess(request) + request = self.preprocess_chat_request(request) chat_completion_request = request.to_chat_completion_request(model_name) if isinstance(chat_completion_request, ErrorResponse): - # Propagate any request problems like extra unallowed parameters + # Propagate any request problems return chat_completion_request # Return an error for streaming for now. 
Since the detector API is unary, @@ -196,4 +196,10 @@ async def context_analyze( raw_request: Optional[Request] = None, ) -> Union[ChatDetectionResponse, ErrorResponse]: """Function used to call chat detection and provide a /context/doc response""" - pass + # Return "not implemented" here since context analysis may not + # generally apply to all models at this time + return ErrorResponse( + message="context analysis is not supported for the detector", + type="NotImplementedError", + code=HTTPStatus.NOT_IMPLEMENTED.value, + ) diff --git a/vllm_detector_adapter/generative_detectors/granite_guardian.py b/vllm_detector_adapter/generative_detectors/granite_guardian.py index 8ad8c06..b67bfed 100644 --- a/vllm_detector_adapter/generative_detectors/granite_guardian.py +++ b/vllm_detector_adapter/generative_detectors/granite_guardian.py @@ -1,13 +1,20 @@ # Standard -from typing import Union +from http import HTTPStatus +from typing import Optional, Union # Third Party -from vllm.entrypoints.openai.protocol import ErrorResponse +from fastapi import Request +from pydantic import ValidationError +from vllm.entrypoints.openai.protocol import ChatCompletionRequest, ErrorResponse # Local from vllm_detector_adapter.generative_detectors.base import ChatCompletionDetectionBase from vllm_detector_adapter.logging import init_logger -from vllm_detector_adapter.protocol import ChatDetectionRequest +from vllm_detector_adapter.protocol import ( + ChatDetectionRequest, + ChatDetectionResponse, + ContextAnalysisRequest, +) logger = init_logger(__name__) @@ -22,7 +29,7 @@ class GraniteGuardian(ChatCompletionDetectionBase): SAFE_TOKEN = "No" UNSAFE_TOKEN = "Yes" - def preprocess( + def preprocess_chat_request( self, request: ChatDetectionRequest ) -> Union[ChatDetectionRequest, ErrorResponse]: """Granite guardian specific parameter updates for risk name and risk definition""" @@ -47,3 +54,50 @@ def preprocess( } return request + + async def context_analyze( + self, + request: ContextAnalysisRequest, + raw_request: Optional[Request] = None, + ) -> Union[ChatDetectionResponse, ErrorResponse]: + """Function used to call chat detection and provide a /context/doc response""" + # # Fetch model name from super class: OpenAIServing + # model_name = self.models.base_model_paths[0].name + + # # Apply task template if it exists + # if self.task_template: + # request = self.apply_task_template(request) + # if isinstance(request, ErrorResponse): + # # Propagate any request problems that will not allow + # # task template to be applied + # return request + + # # Much of this chat completions processing is similar to the + # # .chat case and could be refactored if reused further in the future + pass + + def to_chat_completion_request(self, model_name: str): + """Function to convert context analysis request to openai chat completion request""" + # Can only process one context currently - TODO: validate + # For now, context_type is ignored but is required for the detection endpoint + # TODO: 'context' is not a generally supported 'role' in the openAI API + # TODO: messages are much more specific to risk type + messages = [ + {"role": "context", "content": self.context[0]}, + {"role": "assistant", "content": self.content}, + ] + + # Try to pass all detector_params through as additional parameters to chat completions + # without additional validation or parameter changes as in ChatDetectionRequest above + try: + return ChatCompletionRequest( + messages=messages, + model=model_name, + **self.detector_params, + ) + except ValidationError 
as e: + return ErrorResponse( + message=repr(e.errors()[0]), + type="BadRequestError", + code=HTTPStatus.BAD_REQUEST.value, + ) diff --git a/vllm_detector_adapter/protocol.py b/vllm_detector_adapter/protocol.py index 20b5221..4f88da6 100644 --- a/vllm_detector_adapter/protocol.py +++ b/vllm_detector_adapter/protocol.py @@ -103,31 +103,6 @@ class ContextAnalysisRequest(BaseModel): # Parameters to pass through to chat completions, optional detector_params: Optional[Dict] = {} - def to_chat_completion_request(self, model_name: str): - """Function to convert context analysis request to openai chat completion request""" - # Can only process one context currently - TODO: validate - # For now, context_type is ignored but is required for the detection endpoint - # TODO: 'context' is not a generally supported 'role' in the openAI API - messages = [ - {"role": "context", "content": self.context[0]}, - {"role": "assistant", "content": self.content}, - ] - - # Try to pass all detector_params through as additional parameters to chat completions - # without additional validation or parameter changes as in ChatDetectionRequest above - try: - return ChatCompletionRequest( - messages=messages, - model=model_name, - **self.detector_params, - ) - except ValidationError as e: - return ErrorResponse( - message=repr(e.errors()[0]), - type="BadRequestError", - code=HTTPStatus.BAD_REQUEST.value, - ) - ######## General modified response(s) for chat completions From df1560bbc082887b0fb24bb5bcc3350426df4826 Mon Sep 17 00:00:00 2001 From: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> Date: Tue, 28 Jan 2025 14:07:59 -0700 Subject: [PATCH 04/12] :recycle::white_check_mark: Unimplemented context analysis test Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> --- .../generative_detectors/test_llama_guard.py | 19 +++++++++++++++ tests/test_protocol.py | 24 ------------------- 2 files changed, 19 insertions(+), 24 deletions(-) diff --git a/tests/generative_detectors/test_llama_guard.py b/tests/generative_detectors/test_llama_guard.py index 1ba98db..986c8f6 100644 --- a/tests/generative_detectors/test_llama_guard.py +++ b/tests/generative_detectors/test_llama_guard.py @@ -14,6 +14,7 @@ ChatCompletionResponse, ChatCompletionResponseChoice, ChatMessage, + ErrorResponse, UsageInfo, ) from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels @@ -25,6 +26,7 @@ from vllm_detector_adapter.protocol import ( ChatDetectionRequest, ChatDetectionResponse, + ContextAnalysisRequest, DetectionChatMessageParam, ) @@ -194,3 +196,20 @@ def test_chat_detection(llama_guard_detection, llama_guard_completion_response): assert detection_0["detection"] == "safe" assert detection_0["detection_type"] == "risk" assert pytest.approx(detection_0["score"]) == 0.001346767 + + +def test_context_analyze(llama_guard_detection): + llama_guard_detection_instance = asyncio.run(llama_guard_detection) + content = "Where do I find geese?" 
+ context_doc = "Geese can be found in lakes, ponds, and rivers" + context_request = ContextAnalysisRequest( + content=content, + context_type="docs", + context=[context_doc], + detector_params={"n": 2, "temperature": 0.3}, + ) + response = asyncio.run( + llama_guard_detection_instance.context_analyze(context_request) + ) + assert type(response) == ErrorResponse + assert response.code == HTTPStatus.NOT_IMPLEMENTED diff --git a/tests/test_protocol.py b/tests/test_protocol.py index e944fe4..8746dd7 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -15,7 +15,6 @@ from vllm_detector_adapter.protocol import ( ChatDetectionRequest, ChatDetectionResponse, - ContextAnalysisRequest, DetectionChatMessageParam, ) @@ -61,29 +60,6 @@ def test_chat_detection_to_completion_request_unknown_params(): assert type(request) == ChatCompletionRequest -# #### Context analysis detection request tests - - -# def test_context_detection_to_completion_request(): -# content = "Where do I find geese?" -# context_doc = "Geese can be found in lakes, ponds, and rivers" -# context_request = ContextAnalysisRequest( -# content=content, -# context_type="docs", -# context=[context_doc], -# detector_params={"n": 2, "temperature": 0.3}, -# ) -# request = context_request.to_chat_completion_request(MODEL_NAME) -# assert type(request) == ChatCompletionRequest -# assert request.messages[0]["role"] == "context" -# assert request.messages[0]["content"] == context_doc -# assert request.messages[1]["role"] == "assistant" -# assert request.messages[1]["content"] == content -# assert request.model == MODEL_NAME -# assert request.temperature == 0.3 -# assert request.n == 2 - - #### General response tests From 8cbf7c4eec2a785ea98747fbe83080e1dbf46348 Mon Sep 17 00:00:00 2001 From: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> Date: Tue, 28 Jan 2025 14:54:58 -0700 Subject: [PATCH 05/12] :sparkles: Initial context analysis for granite guardian Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> --- .../generative_detectors/granite_guardian.py | 168 ++++++++++++++---- vllm_detector_adapter/protocol.py | 5 + 2 files changed, 138 insertions(+), 35 deletions(-) diff --git a/vllm_detector_adapter/generative_detectors/granite_guardian.py b/vllm_detector_adapter/generative_detectors/granite_guardian.py index b67bfed..f27e297 100644 --- a/vllm_detector_adapter/generative_detectors/granite_guardian.py +++ b/vllm_detector_adapter/generative_detectors/granite_guardian.py @@ -29,9 +29,13 @@ class GraniteGuardian(ChatCompletionDetectionBase): SAFE_TOKEN = "No" UNSAFE_TOKEN = "Yes" - def preprocess_chat_request( - self, request: ChatDetectionRequest - ) -> Union[ChatDetectionRequest, ErrorResponse]: + # Risks associated with context analysis + PROMPT_CONTEXT_ANALYSIS_RISKS = ["relevance"] + RESPONSE_CONTEXT_ANALYSIS_RISKS = ["groundedness"] + + def preprocess( + self, request: Union[ChatDetectionRequest, ContextAnalysisRequest] + ) -> Union[ChatDetectionRequest, ContextAnalysisRequest, ErrorResponse]: """Granite guardian specific parameter updates for risk name and risk definition""" # Validation that one of the 'defined' risks is requested will be # done through the chat template on each request. 
Errors will @@ -55,40 +59,60 @@ def preprocess_chat_request( return request - async def context_analyze( - self, - request: ContextAnalysisRequest, - raw_request: Optional[Request] = None, - ) -> Union[ChatDetectionResponse, ErrorResponse]: - """Function used to call chat detection and provide a /context/doc response""" - # # Fetch model name from super class: OpenAIServing - # model_name = self.models.base_model_paths[0].name - - # # Apply task template if it exists - # if self.task_template: - # request = self.apply_task_template(request) - # if isinstance(request, ErrorResponse): - # # Propagate any request problems that will not allow - # # task template to be applied - # return request - - # # Much of this chat completions processing is similar to the - # # .chat case and could be refactored if reused further in the future - pass - - def to_chat_completion_request(self, model_name: str): - """Function to convert context analysis request to openai chat completion request""" - # Can only process one context currently - TODO: validate - # For now, context_type is ignored but is required for the detection endpoint - # TODO: 'context' is not a generally supported 'role' in the openAI API - # TODO: messages are much more specific to risk type - messages = [ - {"role": "context", "content": self.context[0]}, - {"role": "assistant", "content": self.content}, - ] + def preprocess_chat_request( + self, request: ChatDetectionRequest + ) -> Union[ChatDetectionRequest, ErrorResponse]: + """Granite guardian chat request preprocess is just detector parameter updates""" + return self.preprocess(request) + + def request_to_chat_completion_request( + self, request: ContextAnalysisRequest, model_name: str + ) -> Union[ChatCompletionRequest, ErrorResponse]: + + risk_name = None + # Use risk name to determine message format + if guardian_config := request.detector_params["chat_template_kwargs"][ + "guardian_config" + ]: + risk_name = guardian_config["risk_name"] + else: + return ErrorResponse( + message="No risk_name for context analysis", + type="BadRequestError", + code=HTTPStatus.BAD_REQUEST.value, + ) + + if len(request.context) > 1: + # The API supports more than one context text but currently chat completions + # will only take one context + logger.warning("More than one context provided. 
Only the last will be used") + context_text = request.context[-1] + content = request.content + # The "context" role is not an officially support OpenAI role + if risk_name in self.RESPONSE_CONTEXT_ANALYSIS_RISKS: + # Response analysis + messages = [ + {"role": "context", "content": context_text}, + {"role": "assistant", "content": content}, + ] + elif risk_name in self.PROMPT_CONTEXT_ANALYSIS_RISKS: + # Prompt analysis + messages = [ + {"role": "user", "content": content}, + {"role": "context", "content": context_text}, + ] + else: + # Return error if risks are not appropriate [or could default to one of the above analyses] + return ErrorResponse( + message="risk_name {} is not compatible with context analysis".format( + risk_name + ), + type="BadRequestError", + code=HTTPStatus.BAD_REQUEST.value, + ) # Try to pass all detector_params through as additional parameters to chat completions - # without additional validation or parameter changes as in ChatDetectionRequest above + # without additional validation or parameter changes, similar to ChatDetectionRequest processing try: return ChatCompletionRequest( messages=messages, @@ -101,3 +125,77 @@ def to_chat_completion_request(self, model_name: str): type="BadRequestError", code=HTTPStatus.BAD_REQUEST.value, ) + + async def context_analyze( + self, + request: ContextAnalysisRequest, + raw_request: Optional[Request] = None, + ) -> Union[ChatDetectionResponse, ErrorResponse]: + """Function used to call chat detection and provide a /context/doc response""" + # Fetch model name from super class: OpenAIServing + model_name = self.models.base_model_paths[0].name + + # Apply task template if it exists + if self.task_template: + request = self.apply_task_template(request) + if isinstance(request, ErrorResponse): + # Propagate any request problems that will not allow + # task template to be applied + return request + + # Make model-dependent adjustments for the request + request = self.preprocess(request) + + # Since particular chat messages are dependent on Granite Guardian risk definitions, + # the processing is done here rather than in a separate, general to_chat_completion_request + # for all context analysis requests. + chat_completion_request = self.request_to_chat_completion_request( + request, model_name + ) + if isinstance(chat_completion_request, ErrorResponse): + # Propagate any request problems + return chat_completion_request + + # Much of this chat completions processing is similar to the + # .chat case and could be refactored if reused further in the future + + # Return an error for streaming for now. Since the detector API is unary, + # results would not be streamed back anyway. The chat completion response + # object would look different, and content would have to be aggregated. + if chat_completion_request.stream: + return ErrorResponse( + message="streaming is not supported for the detector", + type="BadRequestError", + code=HTTPStatus.BAD_REQUEST.value, + ) + + # Manually set logprobs to True to calculate score later on + # NOTE: this is supposed to override if user has set logprobs to False + # or left logprobs as the default False + chat_completion_request.logprobs = True + # NOTE: We need top_logprobs to be enabled to calculate score appropriately + # We override this and not allow configuration at this point. In future, we may + # want to expose this configurable to certain range. 
+ chat_completion_request.top_logprobs = 5 + + logger.debug("Request to chat completion: %s", chat_completion_request) + + # Call chat completion + chat_response = await self.create_chat_completion( + chat_completion_request, raw_request + ) + logger.debug("Raw chat completion response: %s", chat_response) + if isinstance(chat_response, ErrorResponse): + # Propagate chat completion errors directly + return chat_response + + # Apply output template if it exists + if self.output_template: + chat_response = self.apply_output_template(chat_response) + + # Calculate scores + scores = self.calculate_scores(chat_response) + + return ChatDetectionResponse.from_chat_completion_response( + chat_response, scores, self.DETECTION_TYPE + ) diff --git a/vllm_detector_adapter/protocol.py b/vllm_detector_adapter/protocol.py index 4f88da6..e412195 100644 --- a/vllm_detector_adapter/protocol.py +++ b/vllm_detector_adapter/protocol.py @@ -103,6 +103,11 @@ class ContextAnalysisRequest(BaseModel): # Parameters to pass through to chat completions, optional detector_params: Optional[Dict] = {} + # NOTE: currently there is no general to_chat_completion_request + # since the chat completion roles and messages are fairly tied + # to particular models' risk definitions. If a general strategy + # is identified, it can be implemented here. + ######## General modified response(s) for chat completions From 07aaff7b02704abead7d2e5bf07b5433d76d4fae Mon Sep 17 00:00:00 2001 From: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> Date: Tue, 28 Jan 2025 14:58:59 -0700 Subject: [PATCH 06/12] :bug: Request detector params Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> --- vllm_detector_adapter/generative_detectors/granite_guardian.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_detector_adapter/generative_detectors/granite_guardian.py b/vllm_detector_adapter/generative_detectors/granite_guardian.py index f27e297..0030321 100644 --- a/vllm_detector_adapter/generative_detectors/granite_guardian.py +++ b/vllm_detector_adapter/generative_detectors/granite_guardian.py @@ -117,7 +117,7 @@ def request_to_chat_completion_request( return ChatCompletionRequest( messages=messages, model=model_name, - **self.detector_params, + **request.detector_params, ) except ValidationError as e: return ErrorResponse( From 18921e9d5320b133f187ad134bc779447e239669 Mon Sep 17 00:00:00 2001 From: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> Date: Wed, 29 Jan 2025 09:13:33 -0700 Subject: [PATCH 07/12] :goal_net::white_check_mark: Context analysis request to chat completion request error handling and tests Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> --- .../test_granite_guardian.py | 142 +++++++++++++++++- .../generative_detectors/granite_guardian.py | 16 +- 2 files changed, 151 insertions(+), 7 deletions(-) diff --git a/tests/generative_detectors/test_granite_guardian.py b/tests/generative_detectors/test_granite_guardian.py index c26db89..9fe8865 100644 --- a/tests/generative_detectors/test_granite_guardian.py +++ b/tests/generative_detectors/test_granite_guardian.py @@ -11,6 +11,7 @@ ChatCompletionLogProb, ChatCompletionLogProbs, ChatCompletionLogProbsContent, + ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseChoice, ChatMessage, @@ -26,6 +27,7 @@ from vllm_detector_adapter.protocol import ( ChatDetectionRequest, ChatDetectionResponse, + ContextAnalysisRequest, DetectionChatMessageParam, ) @@ -33,6 +35,9 @@ CHAT_TEMPLATE = "Dummy chat 
template for testing {}" BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)] +CONTENT = "Where do I find geese?" +CONTEXT_DOC = "Geese can be found in lakes, ponds, and rivers" + @dataclass class MockTokenizer: @@ -155,8 +160,8 @@ def granite_guardian_completion_response(): ### Tests ##################################################################### -def test_preprocess_with_detector_params(granite_guardian_detection): - llama_guard_detection_instance = asyncio.run(granite_guardian_detection) +def test_preprocess_chat_request_with_detector_params(granite_guardian_detection): + granite_guardian_detection_instance = asyncio.run(granite_guardian_detection) # Make sure with addition of allowed params like risk_name and risk_definition, # extra params do not get added to guardian_config detector_params = { @@ -172,7 +177,7 @@ def test_preprocess_with_detector_params(granite_guardian_detection): ], detector_params=detector_params, ) - processed_request = llama_guard_detection_instance.preprocess_chat_request( + processed_request = granite_guardian_detection_instance.preprocess_chat_request( initial_request ) assert type(processed_request) == ChatDetectionRequest @@ -192,6 +197,133 @@ def test_preprocess_with_detector_params(granite_guardian_detection): } +def test_request_to_chat_completion_request_prompt_analysis(granite_guardian_detection): + granite_guardian_detection_instance = asyncio.run(granite_guardian_detection) + context_request = ContextAnalysisRequest( + content=CONTENT, + context_type="docs", + context=[CONTEXT_DOC], + detector_params={ + "n": 2, + "chat_template_kwargs": { + "guardian_config": {"risk_name": "context_relevance"} + }, + }, + ) + chat_request = ( + granite_guardian_detection_instance.request_to_chat_completion_request( + context_request, MODEL_NAME + ) + ) + assert type(chat_request) == ChatCompletionRequest + assert chat_request.messages[0]["role"] == "user" + assert chat_request.messages[0]["content"] == CONTENT + assert chat_request.messages[1]["role"] == "context" + assert chat_request.messages[1]["content"] == CONTEXT_DOC + assert chat_request.model == MODEL_NAME + # detector_paramas + assert chat_request.n == 2 + assert ( + chat_request.chat_template_kwargs["guardian_config"]["risk_name"] + == "context_relevance" + ) + + +def test_request_to_chat_completion_request_reponse_analysis( + granite_guardian_detection, +): + granite_guardian_detection_instance = asyncio.run(granite_guardian_detection) + context_request = ContextAnalysisRequest( + content=CONTENT, + context_type="docs", + context=[CONTEXT_DOC], + detector_params={ + "n": 3, + "chat_template_kwargs": {"guardian_config": {"risk_name": "groundedness"}}, + }, + ) + chat_request = ( + granite_guardian_detection_instance.request_to_chat_completion_request( + context_request, MODEL_NAME + ) + ) + assert type(chat_request) == ChatCompletionRequest + assert chat_request.messages[0]["role"] == "context" + assert chat_request.messages[0]["content"] == CONTEXT_DOC + assert chat_request.messages[1]["role"] == "assistant" + assert chat_request.messages[1]["content"] == CONTENT + assert chat_request.model == MODEL_NAME + # detector_paramas + assert chat_request.n == 3 + assert ( + chat_request.chat_template_kwargs["guardian_config"]["risk_name"] + == "groundedness" + ) + + +def test_request_to_chat_completion_request_empty_kwargs(granite_guardian_detection): + granite_guardian_detection_instance = asyncio.run(granite_guardian_detection) + context_request = ContextAnalysisRequest( + 
content=CONTENT, + context_type="docs", + context=[CONTEXT_DOC], + detector_params={"n": 2, "chat_template_kwargs": {}}, # no guardian config + ) + chat_request = ( + granite_guardian_detection_instance.request_to_chat_completion_request( + context_request, MODEL_NAME + ) + ) + assert type(chat_request) == ErrorResponse + assert chat_request.code == HTTPStatus.BAD_REQUEST + assert "No risk_name for context analysis" in chat_request.message + + +def test_request_to_chat_completion_request_empty_guardian_config( + granite_guardian_detection, +): + granite_guardian_detection_instance = asyncio.run(granite_guardian_detection) + context_request = ContextAnalysisRequest( + content=CONTENT, + context_type="docs", + context=[CONTEXT_DOC], + detector_params={"n": 2, "chat_template_kwargs": {"guardian_config": {}}}, + ) + chat_request = ( + granite_guardian_detection_instance.request_to_chat_completion_request( + context_request, MODEL_NAME + ) + ) + assert type(chat_request) == ErrorResponse + assert chat_request.code == HTTPStatus.BAD_REQUEST + assert "No risk_name for context analysis" in chat_request.message + + +def test_request_to_chat_completion_request_unsupported_risk_name( + granite_guardian_detection, +): + granite_guardian_detection_instance = asyncio.run(granite_guardian_detection) + context_request = ContextAnalysisRequest( + content=CONTENT, + context_type="docs", + context=[CONTEXT_DOC], + detector_params={ + "n": 2, + "chat_template_kwargs": {"guardian_config": {"risk_name": "foo"}}, + }, + ) + chat_request = ( + granite_guardian_detection_instance.request_to_chat_completion_request( + context_request, MODEL_NAME + ) + ) + assert type(chat_request) == ErrorResponse + assert chat_request.code == HTTPStatus.BAD_REQUEST + assert ( + "risk_name foo is not compatible with context analysis" in chat_request.message + ) + + # NOTE: currently these functions are basically just the base implementations, # where safe/unsafe tokens are defined in the granite guardian class @@ -199,8 +331,8 @@ def test_preprocess_with_detector_params(granite_guardian_detection): def test_calculate_scores( granite_guardian_detection, granite_guardian_completion_response ): - llama_guard_detection_instance = asyncio.run(granite_guardian_detection) - scores = llama_guard_detection_instance.calculate_scores( + granite_guardian_detection_instance = asyncio.run(granite_guardian_detection) + scores = granite_guardian_detection_instance.calculate_scores( granite_guardian_completion_response ) assert len(scores) == 2 # 2 choices diff --git a/vllm_detector_adapter/generative_detectors/granite_guardian.py b/vllm_detector_adapter/generative_detectors/granite_guardian.py index 0030321..a7abb74 100644 --- a/vllm_detector_adapter/generative_detectors/granite_guardian.py +++ b/vllm_detector_adapter/generative_detectors/granite_guardian.py @@ -30,7 +30,7 @@ class GraniteGuardian(ChatCompletionDetectionBase): UNSAFE_TOKEN = "Yes" # Risks associated with context analysis - PROMPT_CONTEXT_ANALYSIS_RISKS = ["relevance"] + PROMPT_CONTEXT_ANALYSIS_RISKS = ["context_relevance"] RESPONSE_CONTEXT_ANALYSIS_RISKS = ["groundedness"] def preprocess( @@ -68,16 +68,27 @@ def preprocess_chat_request( def request_to_chat_completion_request( self, request: ContextAnalysisRequest, model_name: str ) -> Union[ChatCompletionRequest, ErrorResponse]: + NO_RISK_NAME_MESSAGE = "No risk_name for context analysis" risk_name = None + if ( + "chat_template_kwargs" not in request.detector_params + or "guardian_config" not in 
request.detector_params["chat_template_kwargs"] + ): + return ErrorResponse( + message=NO_RISK_NAME_MESSAGE, + type="BadRequestError", + code=HTTPStatus.BAD_REQUEST.value, + ) # Use risk name to determine message format if guardian_config := request.detector_params["chat_template_kwargs"][ "guardian_config" ]: risk_name = guardian_config["risk_name"] else: + # Leaving off risk name can lead to model/template errors return ErrorResponse( - message="No risk_name for context analysis", + message=NO_RISK_NAME_MESSAGE, type="BadRequestError", code=HTTPStatus.BAD_REQUEST.value, ) @@ -89,6 +100,7 @@ def request_to_chat_completion_request( context_text = request.context[-1] content = request.content # The "context" role is not an officially support OpenAI role + # Messages must be in precise ordering, or model/template errors may occur if risk_name in self.RESPONSE_CONTEXT_ANALYSIS_RISKS: # Response analysis messages = [ From 90f7d13bef19c133fa3ab80e3b03833f94f1b93d Mon Sep 17 00:00:00 2001 From: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> Date: Wed, 29 Jan 2025 09:29:35 -0700 Subject: [PATCH 08/12] :bulb::white_check_mark: Granite Guardian context analyze tests Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> --- .../test_granite_guardian.py | 35 ++++++++++++++++++- .../generative_detectors/granite_guardian.py | 7 ++-- 2 files changed, 38 insertions(+), 4 deletions(-) diff --git a/tests/generative_detectors/test_granite_guardian.py b/tests/generative_detectors/test_granite_guardian.py index 9fe8865..154f373 100644 --- a/tests/generative_detectors/test_granite_guardian.py +++ b/tests/generative_detectors/test_granite_guardian.py @@ -202,7 +202,10 @@ def test_request_to_chat_completion_request_prompt_analysis(granite_guardian_det context_request = ContextAnalysisRequest( content=CONTENT, context_type="docs", - context=[CONTEXT_DOC], + context=[ + "extra!", + CONTEXT_DOC, + ], # additionally test that only last context is used detector_params={ "n": 2, "chat_template_kwargs": { @@ -216,6 +219,7 @@ def test_request_to_chat_completion_request_prompt_analysis(granite_guardian_det ) ) assert type(chat_request) == ChatCompletionRequest + assert len(chat_request.messages) == 2 assert chat_request.messages[0]["role"] == "user" assert chat_request.messages[0]["content"] == CONTENT assert chat_request.messages[1]["role"] == "context" @@ -324,6 +328,35 @@ def test_request_to_chat_completion_request_unsupported_risk_name( ) +def test_context_analyze( + granite_guardian_detection, granite_guardian_completion_response +): + granite_guardian_detection_instance = asyncio.run(granite_guardian_detection) + context_request = ContextAnalysisRequest( + content=CONTENT, + context_type="docs", + context=[CONTEXT_DOC], + detector_params={ + "n": 2, + "chat_template_kwargs": {"guardian_config": {"risk_name": "groundedness"}}, + }, + ) + with patch( + "vllm_detector_adapter.generative_detectors.granite_guardian.GraniteGuardian.create_chat_completion", + return_value=granite_guardian_completion_response, + ): + detection_response = asyncio.run( + granite_guardian_detection_instance.context_analyze(context_request) + ) + assert type(detection_response) == ChatDetectionResponse + detections = detection_response.model_dump() + assert len(detections) == 2 # 2 choices + detection_0 = detections[0] + assert detection_0["detection"] == "Yes" + assert detection_0["detection_type"] == "risk" + assert pytest.approx(detection_0["score"]) == 1.0 + + # NOTE: currently these functions are basically 
just the base implementations, # where safe/unsafe tokens are defined in the granite guardian class diff --git a/vllm_detector_adapter/generative_detectors/granite_guardian.py b/vllm_detector_adapter/generative_detectors/granite_guardian.py index a7abb74..94235dd 100644 --- a/vllm_detector_adapter/generative_detectors/granite_guardian.py +++ b/vllm_detector_adapter/generative_detectors/granite_guardian.py @@ -99,8 +99,9 @@ def request_to_chat_completion_request( logger.warning("More than one context provided. Only the last will be used") context_text = request.context[-1] content = request.content - # The "context" role is not an officially support OpenAI role - # Messages must be in precise ordering, or model/template errors may occur + # The "context" role is not an officially support OpenAI role, so this is specific + # to Granite Guardian. Messages must also be in precise ordering, or model/template + # errors may occur. if risk_name in self.RESPONSE_CONTEXT_ANALYSIS_RISKS: # Response analysis messages = [ @@ -114,7 +115,7 @@ def request_to_chat_completion_request( {"role": "context", "content": context_text}, ] else: - # Return error if risks are not appropriate [or could default to one of the above analyses] + # Return error if risk names are not expected ones return ErrorResponse( message="risk_name {} is not compatible with context analysis".format( risk_name From 3350fe42284a7314b6f28a3e66f583021bb21f64 Mon Sep 17 00:00:00 2001 From: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> Date: Wed, 29 Jan 2025 14:48:20 -0700 Subject: [PATCH 09/12] :pencil2: Update comments Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> --- .../generative_detectors/granite_guardian.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm_detector_adapter/generative_detectors/granite_guardian.py b/vllm_detector_adapter/generative_detectors/granite_guardian.py index 94235dd..6c48051 100644 --- a/vllm_detector_adapter/generative_detectors/granite_guardian.py +++ b/vllm_detector_adapter/generative_detectors/granite_guardian.py @@ -94,12 +94,12 @@ def request_to_chat_completion_request( ) if len(request.context) > 1: - # The API supports more than one context text but currently chat completions - # will only take one context + # The detector API for context docs detection supports more than one context text + # but currently chat completions will only take one context. logger.warning("More than one context provided. Only the last will be used") context_text = request.context[-1] content = request.content - # The "context" role is not an officially support OpenAI role, so this is specific + # The "context" role is not an officially supported OpenAI role, so this is specific # to Granite Guardian. Messages must also be in precise ordering, or model/template # errors may occur. 
if risk_name in self.RESPONSE_CONTEXT_ANALYSIS_RISKS: From 0c56c7d4dcba6063fae7f500b547195326cddaea Mon Sep 17 00:00:00 2001 From: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> Date: Thu, 30 Jan 2025 09:50:34 -0700 Subject: [PATCH 10/12] :bulb::recycle: Section comments Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> --- .../generative_detectors/base.py | 16 +++++++++++----- vllm_detector_adapter/protocol.py | 8 ++++---- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/vllm_detector_adapter/generative_detectors/base.py b/vllm_detector_adapter/generative_detectors/base.py index ea68aff..82879fb 100644 --- a/vllm_detector_adapter/generative_detectors/base.py +++ b/vllm_detector_adapter/generative_detectors/base.py @@ -36,6 +36,8 @@ def __init__(self, task_template: str, output_template: str, *args, **kwargs): self.output_template = self.load_template(output_template) + ##### Template functions ################################################### + def load_template(self, template_path: Optional[Union[Path, str]]) -> str: """Function to load template Note: this function currently is largely taken from the chat template method @@ -70,6 +72,14 @@ def load_template(self, template_path: Optional[Union[Path, str]]) -> str: logger.info("Using supplied template:\n%s", resolved_template) return self.jinja_env.from_string(resolved_template) + def apply_output_template( + self, response: ChatCompletionResponse + ) -> Union[ChatCompletionResponse, ErrorResponse]: + """Apply output parsing template for the response""" + return response + + ##### Chat request processing functions #################################### + def apply_task_template_to_chat( self, request: ChatDetectionRequest ) -> Union[ChatDetectionRequest, ErrorResponse]: @@ -82,11 +92,7 @@ def preprocess_chat_request( """Preprocess chat request""" return request - def apply_output_template( - self, response: ChatCompletionResponse - ) -> Union[ChatCompletionResponse, ErrorResponse]: - """Apply output parsing template for the response""" - return response + ##### General chat completion output processing functions ################## def calculate_scores(self, response: ChatCompletionResponse) -> List[float]: """Extract scores from logprobs of the raw chat response""" diff --git a/vllm_detector_adapter/protocol.py b/vllm_detector_adapter/protocol.py index e412195..f7f730c 100644 --- a/vllm_detector_adapter/protocol.py +++ b/vllm_detector_adapter/protocol.py @@ -11,7 +11,7 @@ ErrorResponse, ) -##### [FMS] Detection API types +##### [FMS] Detection API types ##### # Endpoints are as documented https://foundation-model-stack.github.io/fms-guardrails-orchestrator/?urls.primaryName=Detector+API#/ ######## Contents Detection types (currently unused) for the /text/contents detection endpoint @@ -35,7 +35,7 @@ class ContentsDetectionResponseObject(BaseModel): score: float = Field(examples=[0.5]) -######## Chat Detection types for the /text/chat detection endpoint +##### Chat Detection types for the /text/chat detection endpoint ############### class DetectionChatMessageParam(TypedDict): @@ -85,7 +85,7 @@ def to_chat_completion_request(self, model_name: str): ) -######## Context Analysis Detection types for the /text/context/docs detection endpoint +##### Context Analysis Detection types for the /text/context/docs detection endpoint class ContextAnalysisRequest(BaseModel): @@ -109,7 +109,7 @@ class ContextAnalysisRequest(BaseModel): # is identified, it can be implemented here. 
-######## General modified response(s) for chat completions +##### General modified response(s) for chat completions ######################## class ChatDetectionResponseObject(BaseModel): From 993c989fb2d0dc2ef94d70f7120288b12a7bbc1f Mon Sep 17 00:00:00 2001 From: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> Date: Thu, 30 Jan 2025 10:22:48 -0700 Subject: [PATCH 11/12] :recycle::white_check_mark: Consolidate chat completion response call and processing Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> --- .../test_granite_guardian.py | 6 +- .../generative_detectors/test_llama_guard.py | 4 +- tests/test_protocol.py | 8 +-- vllm_detector_adapter/api_server.py | 6 +- .../generative_detectors/base.py | 71 ++++++++++--------- .../generative_detectors/granite_guardian.py | 48 ++----------- vllm_detector_adapter/protocol.py | 12 ++-- 7 files changed, 62 insertions(+), 93 deletions(-) diff --git a/tests/generative_detectors/test_granite_guardian.py b/tests/generative_detectors/test_granite_guardian.py index 154f373..e063c65 100644 --- a/tests/generative_detectors/test_granite_guardian.py +++ b/tests/generative_detectors/test_granite_guardian.py @@ -26,9 +26,9 @@ from vllm_detector_adapter.generative_detectors.granite_guardian import GraniteGuardian from vllm_detector_adapter.protocol import ( ChatDetectionRequest, - ChatDetectionResponse, ContextAnalysisRequest, DetectionChatMessageParam, + DetectionResponse, ) MODEL_NAME = "ibm-granite/granite-guardian" # Example granite-guardian model @@ -348,7 +348,7 @@ def test_context_analyze( detection_response = asyncio.run( granite_guardian_detection_instance.context_analyze(context_request) ) - assert type(detection_response) == ChatDetectionResponse + assert type(detection_response) == DetectionResponse detections = detection_response.model_dump() assert len(detections) == 2 # 2 choices detection_0 = detections[0] @@ -391,7 +391,7 @@ def test_chat_detection( detection_response = asyncio.run( granite_guardian_detection_instance.chat(chat_request) ) - assert type(detection_response) == ChatDetectionResponse + assert type(detection_response) == DetectionResponse detections = detection_response.model_dump() assert len(detections) == 2 # 2 choices detection_0 = detections[0] diff --git a/tests/generative_detectors/test_llama_guard.py b/tests/generative_detectors/test_llama_guard.py index 986c8f6..0d68366 100644 --- a/tests/generative_detectors/test_llama_guard.py +++ b/tests/generative_detectors/test_llama_guard.py @@ -25,9 +25,9 @@ from vllm_detector_adapter.generative_detectors.llama_guard import LlamaGuard from vllm_detector_adapter.protocol import ( ChatDetectionRequest, - ChatDetectionResponse, ContextAnalysisRequest, DetectionChatMessageParam, + DetectionResponse, ) MODEL_NAME = "meta-llama/Llama-Guard-3-8B" # Example llama guard model @@ -189,7 +189,7 @@ def test_chat_detection(llama_guard_detection, llama_guard_completion_response): detection_response = asyncio.run( llama_guard_detection_instance.chat(chat_request) ) - assert type(detection_response) == ChatDetectionResponse + assert type(detection_response) == DetectionResponse detections = detection_response.model_dump() assert len(detections) == 2 # 2 choices detection_0 = detections[0] diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 8746dd7..9178ae6 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -14,8 +14,8 @@ # Local from vllm_detector_adapter.protocol import ( ChatDetectionRequest, - ChatDetectionResponse, 
DetectionChatMessageParam, + DetectionResponse, ) MODEL_NAME = "org/model-name" @@ -86,10 +86,10 @@ def test_response_from_completion_response(): ) scores = [0.3, 0.7] detection_type = "type" - detection_response = ChatDetectionResponse.from_chat_completion_response( + detection_response = DetectionResponse.from_chat_completion_response( response, scores, detection_type ) - assert type(detection_response) == ChatDetectionResponse + assert type(detection_response) == DetectionResponse detections = detection_response.model_dump() assert len(detections) == 2 # 2 choices detection_0 = detections[0] @@ -120,7 +120,7 @@ def test_response_from_completion_response_missing_content(): ) scores = [0.3, 0.7] detection_type = "type" - detection_response = ChatDetectionResponse.from_chat_completion_response( + detection_response = DetectionResponse.from_chat_completion_response( response, scores, detection_type ) assert type(detection_response) == ErrorResponse diff --git a/vllm_detector_adapter/api_server.py b/vllm_detector_adapter/api_server.py index 7f1f7a1..6489e05 100644 --- a/vllm_detector_adapter/api_server.py +++ b/vllm_detector_adapter/api_server.py @@ -27,8 +27,8 @@ from vllm_detector_adapter.logging import init_logger from vllm_detector_adapter.protocol import ( ChatDetectionRequest, - ChatDetectionResponse, ContextAnalysisRequest, + DetectionResponse, ) TIMEOUT_KEEP_ALIVE = 5 # seconds @@ -160,7 +160,7 @@ async def create_chat_detection(request: ChatDetectionRequest, raw_request: Requ content=detector_response.model_dump(), status_code=detector_response.code ) - elif isinstance(detector_response, ChatDetectionResponse): + elif isinstance(detector_response, DetectionResponse): return JSONResponse(content=detector_response.model_dump()) return JSONResponse({}) @@ -182,7 +182,7 @@ async def create_context_doc_detection( content=detector_response.model_dump(), status_code=detector_response.code ) - elif isinstance(detector_response, ChatDetectionResponse): + elif isinstance(detector_response, DetectionResponse): return JSONResponse(content=detector_response.model_dump()) return JSONResponse({}) diff --git a/vllm_detector_adapter/generative_detectors/base.py b/vllm_detector_adapter/generative_detectors/base.py index 82879fb..60449f5 100644 --- a/vllm_detector_adapter/generative_detectors/base.py +++ b/vllm_detector_adapter/generative_detectors/base.py @@ -16,8 +16,8 @@ from vllm_detector_adapter.logging import init_logger from vllm_detector_adapter.protocol import ( ChatDetectionRequest, - ChatDetectionResponse, ContextAnalysisRequest, + DetectionResponse, ) logger = init_logger(__name__) @@ -126,35 +126,9 @@ def calculate_scores(self, response: ChatCompletionResponse) -> List[float]: return choice_scores - ##### Detection methods #################################################### - # Base implementation of other detection endpoints like content can go here - - async def chat( - self, - request: ChatDetectionRequest, - raw_request: Optional[Request] = None, - ) -> Union[ChatDetectionResponse, ErrorResponse]: - """Function used to call chat detection and provide a /chat response""" - - # Fetch model name from super class: OpenAIServing - model_name = self.models.base_model_paths[0].name - - # Apply task template if it exists - if self.task_template: - request = self.apply_task_template_to_chat(request) - if isinstance(request, ErrorResponse): - # Propagate any request problems that will not allow - # task template to be applied - return request - - # Optionally make model-dependent adjustments 
for the request - request = self.preprocess_chat_request(request) - - chat_completion_request = request.to_chat_completion_request(model_name) - if isinstance(chat_completion_request, ErrorResponse): - # Propagate any request problems - return chat_completion_request - + async def process_chat_completion_with_scores( + self, chat_completion_request, raw_request + ) -> Union[DetectionResponse, ErrorResponse]: # Return an error for streaming for now. Since the detector API is unary, # results would not be streamed back anyway. The chat completion response # object would look different, and content would have to be aggregated. @@ -192,15 +166,48 @@ async def chat( # Calculate scores scores = self.calculate_scores(chat_response) - return ChatDetectionResponse.from_chat_completion_response( + return DetectionResponse.from_chat_completion_response( chat_response, scores, self.DETECTION_TYPE ) + ##### Detection methods #################################################### + # Base implementation of other detection endpoints like content can go here + + async def chat( + self, + request: ChatDetectionRequest, + raw_request: Optional[Request] = None, + ) -> Union[DetectionResponse, ErrorResponse]: + """Function used to call chat detection and provide a /chat response""" + + # Fetch model name from super class: OpenAIServing + model_name = self.models.base_model_paths[0].name + + # Apply task template if it exists + if self.task_template: + request = self.apply_task_template_to_chat(request) + if isinstance(request, ErrorResponse): + # Propagate any request problems that will not allow + # task template to be applied + return request + + # Optionally make model-dependent adjustments for the request + request = self.preprocess_chat_request(request) + + chat_completion_request = request.to_chat_completion_request(model_name) + if isinstance(chat_completion_request, ErrorResponse): + # Propagate any request problems + return chat_completion_request + + return await self.process_chat_completion_with_scores( + chat_completion_request, raw_request + ) + async def context_analyze( self, request: ContextAnalysisRequest, raw_request: Optional[Request] = None, - ) -> Union[ChatDetectionResponse, ErrorResponse]: + ) -> Union[DetectionResponse, ErrorResponse]: """Function used to call chat detection and provide a /context/doc response""" # Return "not implemented" here since context analysis may not # generally apply to all models at this time diff --git a/vllm_detector_adapter/generative_detectors/granite_guardian.py b/vllm_detector_adapter/generative_detectors/granite_guardian.py index 6c48051..6fa6502 100644 --- a/vllm_detector_adapter/generative_detectors/granite_guardian.py +++ b/vllm_detector_adapter/generative_detectors/granite_guardian.py @@ -12,8 +12,8 @@ from vllm_detector_adapter.logging import init_logger from vllm_detector_adapter.protocol import ( ChatDetectionRequest, - ChatDetectionResponse, ContextAnalysisRequest, + DetectionResponse, ) logger = init_logger(__name__) @@ -143,7 +143,7 @@ async def context_analyze( self, request: ContextAnalysisRequest, raw_request: Optional[Request] = None, - ) -> Union[ChatDetectionResponse, ErrorResponse]: + ) -> Union[DetectionResponse, ErrorResponse]: """Function used to call chat detection and provide a /context/doc response""" # Fetch model name from super class: OpenAIServing model_name = self.models.base_model_paths[0].name @@ -169,46 +169,8 @@ async def context_analyze( # Propagate any request problems return chat_completion_request - # Much of this chat 
completions processing is similar to the - # .chat case and could be refactored if reused further in the future - - # Return an error for streaming for now. Since the detector API is unary, - # results would not be streamed back anyway. The chat completion response - # object would look different, and content would have to be aggregated. - if chat_completion_request.stream: - return ErrorResponse( - message="streaming is not supported for the detector", - type="BadRequestError", - code=HTTPStatus.BAD_REQUEST.value, - ) - - # Manually set logprobs to True to calculate score later on - # NOTE: this is supposed to override if user has set logprobs to False - # or left logprobs as the default False - chat_completion_request.logprobs = True - # NOTE: We need top_logprobs to be enabled to calculate score appropriately - # We override this and not allow configuration at this point. In future, we may - # want to expose this configurable to certain range. - chat_completion_request.top_logprobs = 5 - - logger.debug("Request to chat completion: %s", chat_completion_request) - - # Call chat completion - chat_response = await self.create_chat_completion( + # Calling chat completion and processing of scores is currently + # the same as for the /chat case + return await self.process_chat_completion_with_scores( chat_completion_request, raw_request ) - logger.debug("Raw chat completion response: %s", chat_response) - if isinstance(chat_response, ErrorResponse): - # Propagate chat completion errors directly - return chat_response - - # Apply output template if it exists - if self.output_template: - chat_response = self.apply_output_template(chat_response) - - # Calculate scores - scores = self.calculate_scores(chat_response) - - return ChatDetectionResponse.from_chat_completion_response( - chat_response, scores, self.DETECTION_TYPE - ) diff --git a/vllm_detector_adapter/protocol.py b/vllm_detector_adapter/protocol.py index f7f730c..dbb054f 100644 --- a/vllm_detector_adapter/protocol.py +++ b/vllm_detector_adapter/protocol.py @@ -109,19 +109,19 @@ class ContextAnalysisRequest(BaseModel): # is identified, it can be implemented here. 
-##### General modified response(s) for chat completions ######################## +##### General detection response objects ####################################### -class ChatDetectionResponseObject(BaseModel): +class DetectionResponseObject(BaseModel): detection: str = Field(examples=["positive"]) detection_type: str = Field(examples=["simple_example"]) score: float = Field(examples=[0.5]) -class ChatDetectionResponse(RootModel): +class DetectionResponse(RootModel): # The root attribute is used here so that the response will appear # as a list instead of a list nested under a key - root: List[ChatDetectionResponseObject] + root: List[DetectionResponseObject] @staticmethod def from_chat_completion_response( @@ -132,7 +132,7 @@ def from_chat_completion_response( for i, choice in enumerate(response.choices): content = choice.message.content if content and isinstance(content, str): - response_object = ChatDetectionResponseObject( + response_object = DetectionResponseObject( detection_type=detection_type, detection=content.strip(), score=scores[i], @@ -150,4 +150,4 @@ def from_chat_completion_response( code=HTTPStatus.BAD_REQUEST.value, ) - return ChatDetectionResponse(root=detection_responses) + return DetectionResponse(root=detection_responses) From 3bf85d22d6c359404cd834681c0c7c6bf14a43fe Mon Sep 17 00:00:00 2001 From: Gaurav-Kumbhat Date: Thu, 30 Jan 2025 16:01:05 -0600 Subject: [PATCH 12/12] :pushpin: pin vllm dependencies to make them work on mac m1 Signed-off-by: Gaurav-Kumbhat --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 8daf8b2..87d43fe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,8 @@ classifiers = [ ] dependencies = [ - "vllm>=0.7.0" + "vllm @ git+https://github.com/vllm-project/vllm.git@v0.7.0 ; sys_platform == 'darwin'", + "vllm>=0.7.0 ; sys_platform != 'darwin'", ] [project.optional-dependencies]
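Note on PATCH 11/12: the consolidation renames ChatDetectionResponse to DetectionResponse so one response type backs both the /chat and /context/doc detection paths. The snippet below is a minimal, standalone sketch (not part of the patches; it mirrors only the class bodies visible in the diff, omits from_chat_completion_response, and uses made-up example values) showing why the pydantic RootModel is used: model_dump() yields a bare list, which is exactly what the updated tests index into.

from typing import List

from pydantic import BaseModel, Field, RootModel


class DetectionResponseObject(BaseModel):
    detection: str = Field(examples=["positive"])
    detection_type: str = Field(examples=["simple_example"])
    score: float = Field(examples=[0.5])


class DetectionResponse(RootModel):
    # The root attribute makes the serialized response a flat JSON list
    # instead of a list nested under a key.
    root: List[DetectionResponseObject]


# Illustrative values only.
resp = DetectionResponse(
    root=[DetectionResponseObject(detection="Yes", detection_type="risk", score=0.97)]
)
print(resp.model_dump())
# [{'detection': 'Yes', 'detection_type': 'risk', 'score': 0.97}]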
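Note on PATCH 12/12: the split dependency entries rely on PEP 508 environment markers, so the git-pinned vllm is installed only on macOS while every other platform keeps the plain "vllm>=0.7.0" requirement. A quick way to sanity-check the markers is sketched below (not part of the patch; it assumes the `packaging` library, which ships alongside modern pip/setuptools, is importable).

from packaging.markers import Marker

darwin_req = Marker("sys_platform == 'darwin'")
default_req = Marker("sys_platform != 'darwin'")

# Evaluate against explicit environment overrides rather than the local
# interpreter, so the check is reproducible on any machine.
print(darwin_req.evaluate({"sys_platform": "darwin"}))   # True  -> git-pinned vllm applies
print(default_req.evaluate({"sys_platform": "darwin"}))  # False -> "vllm>=0.7.0" skipped
print(darwin_req.evaluate({"sys_platform": "linux"}))    # False
print(default_req.evaluate({"sys_platform": "linux"}))   # True  -> "vllm>=0.7.0" applies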