3 changes: 2 additions & 1 deletion pyproject.toml
@@ -13,7 +13,8 @@ classifiers = [
]

dependencies = [
"vllm>=0.7.0"
"vllm @ git+https://github.com/vllm-project/[email protected] ; sys_platform == 'darwin'",
"vllm>=0.7.1 ; sys_platform != 'darwin'",
Collaborator: I don't think this one strictly needed to change since it should be backwards compatible, but likely not an issue for current usage

]

[project.optional-dependencies]
25 changes: 13 additions & 12 deletions tests/generative_detectors/test_granite_guardian.py
@@ -30,6 +30,7 @@
DetectionChatMessageParam,
DetectionResponse,
)
from vllm_detector_adapter.utils import DetectorType

MODEL_NAME = "ibm-granite/granite-guardian" # Example granite-guardian model
CHAT_TEMPLATE = "Dummy chat template for testing {}"
@@ -177,8 +178,8 @@ def test_preprocess_chat_request_with_detector_params(granite_guardian_detection
],
detector_params=detector_params,
)
processed_request = granite_guardian_detection_instance.preprocess_chat_request(
initial_request
processed_request = granite_guardian_detection_instance.preprocess_request(
initial_request, fn_type=DetectorType.TEXT_CHAT
)
assert type(processed_request) == ChatDetectionRequest
# Processed request should not have these extra params
@@ -214,8 +215,8 @@ def test_request_to_chat_completion_request_prompt_analysis(granite_guardian_det
},
)
chat_request = (
granite_guardian_detection_instance.request_to_chat_completion_request(
context_request, MODEL_NAME
granite_guardian_detection_instance._request_to_chat_completion_request(
context_request, MODEL_NAME, fn_type=DetectorType.TEXT_CONTEXT_DOC
)
)
assert type(chat_request) == ChatCompletionRequest
@@ -247,8 +248,8 @@ def test_request_to_chat_completion_request_reponse_analysis(
},
)
chat_request = (
granite_guardian_detection_instance.request_to_chat_completion_request(
context_request, MODEL_NAME
granite_guardian_detection_instance._request_to_chat_completion_request(
context_request, MODEL_NAME, fn_type=DetectorType.TEXT_CONTEXT_DOC
)
)
assert type(chat_request) == ChatCompletionRequest
@@ -274,8 +275,8 @@ def test_request_to_chat_completion_request_empty_kwargs(granite_guardian_detect
detector_params={"n": 2, "chat_template_kwargs": {}}, # no guardian config
)
chat_request = (
granite_guardian_detection_instance.request_to_chat_completion_request(
context_request, MODEL_NAME
granite_guardian_detection_instance._request_to_chat_completion_request(
context_request, MODEL_NAME, fn_type=DetectorType.TEXT_CONTEXT_DOC
)
)
assert type(chat_request) == ErrorResponse
@@ -294,8 +295,8 @@ def test_request_to_chat_completion_request_empty_guardian_config(
detector_params={"n": 2, "chat_template_kwargs": {"guardian_config": {}}},
)
chat_request = (
granite_guardian_detection_instance.request_to_chat_completion_request(
context_request, MODEL_NAME
granite_guardian_detection_instance._request_to_chat_completion_request(
context_request, MODEL_NAME, fn_type=DetectorType.TEXT_CONTEXT_DOC
)
)
assert type(chat_request) == ErrorResponse
@@ -317,8 +318,8 @@ def test_request_to_chat_completion_request_unsupported_risk_name(
},
)
chat_request = (
granite_guardian_detection_instance.request_to_chat_completion_request(
context_request, MODEL_NAME
granite_guardian_detection_instance._request_to_chat_completion_request(
context_request, MODEL_NAME, fn_type=DetectorType.TEXT_CONTEXT_DOC
)
)
assert type(chat_request) == ErrorResponse
81 changes: 81 additions & 0 deletions vllm_detector_adapter/detector_dispatcher.py
@@ -0,0 +1,81 @@
# Standard
import functools

# global list to store all the registered functions with
# their types and qualified name
global_fn_list = dict()


def detector_dispatcher(types=None):
"""Decorator to dispatch to processing function based on type of the detector.

This decorator allows us to reuse same function name for different types of detectors.
For example, the same function name can be used for text chat and context analysis
detectors. These decorated functions for these detectors will have different arguments
and implementation but they share the same function name.

NOTE: At the time of invoking these decorated function, the user needs to specify the type
of the detector using fn_type argument.

CAUTION: Since this decorator allow re-use of the name, one must take care of using different types
for testing different functions.

Args:
types (list): Type of the detector this function applies to.
args: Positional arguments passed to the processing function.
kwargs: Keyword arguments passed to the processing function.

Examples
--------

@detector_dispatcher(types=["foo"])
def f(x):
pass

# Decorator can take multiple types as well
@detector_dispatcher(types=["bar", "baz"])
def f(x):
pass

When calling these functions, one can specify the type of the detector as follows:
f(x, fn_type="foo")
f(x, fn_type="bar")
"""
global global_fn_list

if not types:
raise ValueError("Must specify types.")

def decorator(func):
fn_name = func.__qualname__

if fn_name not in global_fn_list:
# Associate each function with its types to create a dictionary of the form:
# {"fn_name": {type1: function, type2: function}}
# NOTE: "function" values here are references to the decorated function objects
global_fn_list[fn_name] = {t: func for t in types}
elif fn_name in global_fn_list and (types & global_fn_list[fn_name].keys()):
# Error out if a function with the same name is already registered
# for any of the requested types
raise ValueError("Function already registered with the same types.")
else:
# Add the function to the global list with corresponding type
global_fn_list[fn_name] |= {t: func for t in types}

@functools.wraps(func)
def wrapper(*args, fn_type=None, **kwargs):
fn_name = func.__qualname__

if not fn_type:
raise ValueError("Must specify fn_type.")

if fn_type not in global_fn_list[fn_name].keys():
raise ValueError("Invalid fn_type.")

# Grab the function using its fully qualified name and the specified type
# and then call it
return global_fn_list[fn_name][fn_type](*args, **kwargs)

return wrapper

return decorator
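A minimal usage sketch of the dispatcher (not part of this diff; ExampleDetector and its method bodies are illustrative only), showing how two functions registered under the same qualified name are selected by fn_type at call time:

```python
# Illustrative sketch only: ExampleDetector is a hypothetical class, not part of this PR.
from vllm_detector_adapter.detector_dispatcher import detector_dispatcher
from vllm_detector_adapter.utils import DetectorType


class ExampleDetector:
    @detector_dispatcher(types=[DetectorType.TEXT_CHAT])
    def preprocess_request(self, request):
        # Chat-specific preprocessing
        return f"chat: {request}"

    @detector_dispatcher(types=[DetectorType.TEXT_CONTEXT_DOC])
    def preprocess_request(self, request):  # noqa: F811 (same name, different type)
        # Context-document-specific preprocessing
        return f"context: {request}"


detector = ExampleDetector()
# The caller selects the implementation with fn_type; omitting fn_type or passing
# an unregistered type raises ValueError.
detector.preprocess_request("hi", fn_type=DetectorType.TEXT_CHAT)  # -> "chat: hi"
detector.preprocess_request("doc", fn_type=DetectorType.TEXT_CONTEXT_DOC)  # -> "context: doc"
```

Both definitions share the qualified name ExampleDetector.preprocess_request, so the registry keeps one entry per type and the wrapper looks up the right implementation at call time.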
14 changes: 10 additions & 4 deletions vllm_detector_adapter/generative_detectors/base.py
@@ -13,12 +13,14 @@
import torch

# Local
from vllm_detector_adapter.detector_dispatcher import detector_dispatcher
from vllm_detector_adapter.logging import init_logger
from vllm_detector_adapter.protocol import (
ChatDetectionRequest,
ContextAnalysisRequest,
DetectionResponse,
)
from vllm_detector_adapter.utils import DetectorType

logger = init_logger(__name__)

@@ -80,13 +82,17 @@ def apply_output_template(

##### Chat request processing functions ####################################

def apply_task_template_to_chat(
# Usage of detector_dispatcher allows the same function name to be used for different types of
# detectors with different arguments and implementations.
@detector_dispatcher(types=[DetectorType.TEXT_CHAT])
def apply_task_template(
self, request: ChatDetectionRequest
) -> Union[ChatDetectionRequest, ErrorResponse]:
"""Apply task template on the chat request"""
return request

def preprocess_chat_request(
@detector_dispatcher(types=[DetectorType.TEXT_CHAT])
def preprocess_request(
self, request: ChatDetectionRequest
) -> Union[ChatDetectionRequest, ErrorResponse]:
"""Preprocess chat request"""
@@ -185,14 +191,14 @@ async def chat(

# Apply task template if it exists
if self.task_template:
request = self.apply_task_template_to_chat(request)
request = self.apply_task_template(request, fn_type=DetectorType.TEXT_CHAT)
if isinstance(request, ErrorResponse):
# Propagate any request problems that will not allow
# task template to be applied
return request

# Optionally make model-dependent adjustments for the request
request = self.preprocess_chat_request(request)
request = self.preprocess_request(request, fn_type=DetectorType.TEXT_CHAT)

chat_completion_request = request.to_chat_completion_request(model_name)
if isinstance(chat_completion_request, ErrorResponse):
34 changes: 23 additions & 11 deletions vllm_detector_adapter/generative_detectors/granite_guardian.py
@@ -8,13 +8,15 @@
from vllm.entrypoints.openai.protocol import ChatCompletionRequest, ErrorResponse

# Local
from vllm_detector_adapter.detector_dispatcher import detector_dispatcher
from vllm_detector_adapter.generative_detectors.base import ChatCompletionDetectionBase
from vllm_detector_adapter.logging import init_logger
from vllm_detector_adapter.protocol import (
ChatDetectionRequest,
ContextAnalysisRequest,
DetectionResponse,
)
from vllm_detector_adapter.utils import DetectorType

logger = init_logger(__name__)

@@ -33,7 +35,9 @@ class GraniteGuardian(ChatCompletionDetectionBase):
PROMPT_CONTEXT_ANALYSIS_RISKS = ["context_relevance"]
RESPONSE_CONTEXT_ANALYSIS_RISKS = ["groundedness"]

def preprocess(
##### Private / Internal functions ###################################################

def __preprocess(
self, request: Union[ChatDetectionRequest, ContextAnalysisRequest]
) -> Union[ChatDetectionRequest, ContextAnalysisRequest, ErrorResponse]:
"""Granite guardian specific parameter updates for risk name and risk definition"""
Expand All @@ -59,13 +63,10 @@ def preprocess(

return request

def preprocess_chat_request(
self, request: ChatDetectionRequest
) -> Union[ChatDetectionRequest, ErrorResponse]:
"""Granite guardian chat request preprocess is just detector parameter updates"""
return self.preprocess(request)

def request_to_chat_completion_request(
# Decorated with detector_dispatcher so that future iterations of this function
# can support other types of detectors more cleanly
@detector_dispatcher(types=[DetectorType.TEXT_CONTEXT_DOC])
def _request_to_chat_completion_request(
Collaborator: nit: if we're keeping this "private" since this only applies to granite guardian, maybe this should go in the "Private functions" section above?

Collaborator (author) @gkumbhat, Jan 31, 2025: I kept these "internal", since there were unit tests for this particular function, so testing them after making them private would get interesting. But it might still make sense to move them to the private block and change the name of that block.

Collaborator: ah, I misread the number of underscores; it might be good to have an "internal functions" block then
self, request: ContextAnalysisRequest, model_name: str
) -> Union[ChatCompletionRequest, ErrorResponse]:
NO_RISK_NAME_MESSAGE = "No risk_name for context analysis"
@@ -141,6 +142,17 @@ def request_to_chat_completion_request(
code=HTTPStatus.BAD_REQUEST.value,
)

##### General request / response processing functions ##################

# The detector_dispatcher decorator allows the same function name to be used
# for different types of detectors with different request types, etc.
@detector_dispatcher(types=[DetectorType.TEXT_CHAT])
def preprocess_request(
self, request: ChatDetectionRequest
) -> Union[ChatDetectionRequest, ErrorResponse]:
"""Granite guardian chat request preprocess is just detector parameter updates"""
return self.__preprocess(request)

async def context_analyze(
self,
request: ContextAnalysisRequest,
@@ -152,13 +164,13 @@ async def context_analyze(

# Task template not applied for context analysis at this time
# Make model-dependent adjustments for the request
request = self.preprocess(request)
request = self.__preprocess(request)

# Since particular chat messages are dependent on Granite Guardian risk definitions,
# the processing is done here rather than in a separate, general to_chat_completion_request
# for all context analysis requests.
chat_completion_request = self.request_to_chat_completion_request(
request, model_name
chat_completion_request = self._request_to_chat_completion_request(
request, model_name, fn_type=DetectorType.TEXT_CONTEXT_DOC
)
if isinstance(chat_completion_request, ErrorResponse):
# Propagate any request problems
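To illustrate the "future iterations" note in the comments above, here is a hypothetical sketch (not part of this PR) of how another variant of _request_to_chat_completion_request could be registered for a different detector type. The TEXT_CONTENT handling, the request's content attribute, and the message construction are assumptions, not the adapter's actual behavior:

```python
# Hypothetical sketch only; this variant and its message construction are assumptions.
from typing import Union

from vllm.entrypoints.openai.protocol import ChatCompletionRequest, ErrorResponse

from vllm_detector_adapter.detector_dispatcher import detector_dispatcher
from vllm_detector_adapter.utils import DetectorType


class GraniteGuardianSketch:
    @detector_dispatcher(types=[DetectorType.TEXT_CONTENT])
    def _request_to_chat_completion_request(
        self, request, model_name: str
    ) -> Union[ChatCompletionRequest, ErrorResponse]:
        # A content analysis request could be mapped to a single user message;
        # real handling would also carry over detector_params, guardian config, etc.
        return ChatCompletionRequest(
            messages=[{"role": "user", "content": request.content}],
            model=model_name,
        )
```

In the real class such a variant would simply sit alongside the TEXT_CONTEXT_DOC one; the dispatcher keys registrations by qualified name and detector type, so callers pick the implementation via fn_type.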
11 changes: 11 additions & 0 deletions vllm_detector_adapter/utils.py
@@ -0,0 +1,11 @@
# Standard
from enum import Enum, auto


class DetectorType(Enum):
"""Enum to represent different types of detectors"""

TEXT_CONTENT = auto()
TEXT_GENERATION = auto()
TEXT_CHAT = auto()
TEXT_CONTEXT_DOC = auto()