diff --git a/pyproject.toml b/pyproject.toml
index 8daf8b2..920f41b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -13,7 +13,8 @@ classifiers = [
 ]
 
 dependencies = [
-    "vllm>=0.7.0"
+    "vllm @ git+https://github.com/vllm-project/vllm.git@v0.7.1 ; sys_platform == 'darwin'",
+    "vllm>=0.7.1 ; sys_platform != 'darwin'",
 ]
 
 [project.optional-dependencies]
diff --git a/tests/generative_detectors/test_granite_guardian.py b/tests/generative_detectors/test_granite_guardian.py
index d7843df..e753e38 100644
--- a/tests/generative_detectors/test_granite_guardian.py
+++ b/tests/generative_detectors/test_granite_guardian.py
@@ -30,6 +30,7 @@
     DetectionChatMessageParam,
     DetectionResponse,
 )
+from vllm_detector_adapter.utils import DetectorType
 
 MODEL_NAME = "ibm-granite/granite-guardian"  # Example granite-guardian model
 CHAT_TEMPLATE = "Dummy chat template for testing {}"
@@ -177,8 +178,8 @@ def test_preprocess_chat_request_with_detector_params(granite_guardian_detection
         ],
         detector_params=detector_params,
     )
-    processed_request = granite_guardian_detection_instance.preprocess_chat_request(
-        initial_request
+    processed_request = granite_guardian_detection_instance.preprocess_request(
+        initial_request, fn_type=DetectorType.TEXT_CHAT
     )
     assert type(processed_request) == ChatDetectionRequest
     # Processed request should not have these extra params
@@ -214,8 +215,8 @@ def test_request_to_chat_completion_request_prompt_analysis(granite_guardian_det
         },
     )
     chat_request = (
-        granite_guardian_detection_instance.request_to_chat_completion_request(
-            context_request, MODEL_NAME
+        granite_guardian_detection_instance._request_to_chat_completion_request(
+            context_request, MODEL_NAME, fn_type=DetectorType.TEXT_CONTEXT_DOC
         )
     )
     assert type(chat_request) == ChatCompletionRequest
@@ -247,8 +248,8 @@ def test_request_to_chat_completion_request_reponse_analysis(
         },
     )
     chat_request = (
-        granite_guardian_detection_instance.request_to_chat_completion_request(
-            context_request, MODEL_NAME
+        granite_guardian_detection_instance._request_to_chat_completion_request(
+            context_request, MODEL_NAME, fn_type=DetectorType.TEXT_CONTEXT_DOC
         )
     )
     assert type(chat_request) == ChatCompletionRequest
@@ -274,8 +275,8 @@ def test_request_to_chat_completion_request_empty_kwargs(granite_guardian_detect
         detector_params={"n": 2, "chat_template_kwargs": {}},  # no guardian config
     )
     chat_request = (
-        granite_guardian_detection_instance.request_to_chat_completion_request(
-            context_request, MODEL_NAME
+        granite_guardian_detection_instance._request_to_chat_completion_request(
+            context_request, MODEL_NAME, fn_type=DetectorType.TEXT_CONTEXT_DOC
         )
     )
     assert type(chat_request) == ErrorResponse
@@ -294,8 +295,8 @@ def test_request_to_chat_completion_request_empty_guardian_config(
         detector_params={"n": 2, "chat_template_kwargs": {"guardian_config": {}}},
     )
     chat_request = (
-        granite_guardian_detection_instance.request_to_chat_completion_request(
-            context_request, MODEL_NAME
+        granite_guardian_detection_instance._request_to_chat_completion_request(
+            context_request, MODEL_NAME, fn_type=DetectorType.TEXT_CONTEXT_DOC
         )
     )
     assert type(chat_request) == ErrorResponse
@@ -317,8 +318,8 @@ def test_request_to_chat_completion_request_unsupported_risk_name(
         },
     )
     chat_request = (
-        granite_guardian_detection_instance.request_to_chat_completion_request(
-            context_request, MODEL_NAME
+        granite_guardian_detection_instance._request_to_chat_completion_request(
+            context_request, MODEL_NAME, fn_type=DetectorType.TEXT_CONTEXT_DOC
        )
    )
    assert type(chat_request) == ErrorResponse
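The dependency split above relies on PEP 508 environment markers. A minimal standalone sketch (not part of this diff; assumes the `packaging` library is available) of how those markers resolve on a given machine:

# Evaluate the PEP 508 markers used in pyproject.toml against the current
# interpreter's environment. On macOS the first marker is true, so pip would
# install vllm from the pinned git tag; elsewhere the plain ">=0.7.1"
# requirement applies.
from packaging.markers import Marker

for marker in ("sys_platform == 'darwin'", "sys_platform != 'darwin'"):
    print(marker, "->", Marker(marker).evaluate())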
diff --git a/vllm_detector_adapter/detector_dispatcher.py b/vllm_detector_adapter/detector_dispatcher.py
new file mode 100644
index 0000000..83074b0
--- /dev/null
+++ b/vllm_detector_adapter/detector_dispatcher.py
@@ -0,0 +1,81 @@
+# Standard
+import functools
+
+# Global registry mapping each function's qualified name to the
+# detector types it is registered for
+global_fn_list = dict()
+
+
+def detector_dispatcher(types=None):
+    """Decorator to dispatch to a processing function based on the type of the detector.
+
+    This decorator allows us to reuse the same function name for different types of
+    detectors. For example, the same function name can be used for text chat and
+    context analysis detectors. The decorated functions can have different arguments
+    and implementations while sharing the same name.
+
+    NOTE: When invoking a decorated function, the caller needs to specify the type
+    of the detector using the fn_type argument.
+
+    CAUTION: Since this decorator allows re-use of a name, one must take care to
+    register different types for different functions sharing that name.
+
+    Args:
+        types (list): Types of detectors this function applies to.
+
+    Positional and keyword arguments are forwarded to the dispatched function.
+
+    Examples
+    --------
+
+    @detector_dispatcher(types=["foo"])
+    def f(x):
+        pass
+
+    # The decorator can take multiple types as well
+    @detector_dispatcher(types=["bar", "baz"])
+    def f(x):
+        pass
+
+    When calling these functions, one specifies the type of the detector as follows:
+    f(x, fn_type="foo")
+    f(x, fn_type="bar")
+    """
+    global global_fn_list
+
+    if not types:
+        raise ValueError("Must specify types.")
+
+    def decorator(func):
+        fn_name = func.__qualname__
+
+        if fn_name not in global_fn_list:
+            # Associate each type with the function to create a dictionary of the form:
+            # {"fn_name": {type1: function, type2: function}}
+            # NOTE: the values are references to the decorated functions
+            global_fn_list[fn_name] = {t: func for t in types}
+        elif fn_name in global_fn_list and (types & global_fn_list[fn_name].keys()):
+            # Error out if a function with the same name is already registered
+            # for any of these types
+            raise ValueError("Function already registered with the same types.")
+        else:
+            # Add the function to the global list under the corresponding types
+            global_fn_list[fn_name] |= {t: func for t in types}
+
+        @functools.wraps(func)
+        def wrapper(*args, fn_type=None, **kwargs):
+            fn_name = func.__qualname__
+
+            if not fn_type:
+                raise ValueError("Must specify fn_type.")
+
+            if fn_type not in global_fn_list[fn_name].keys():
+                raise ValueError("Invalid fn_type.")
+
+            # Grab the function registered under this qualified name and the
+            # specified type, then call it
+            return global_fn_list[fn_name][fn_type](*args, **kwargs)
+
+        return wrapper
+
+    return decorator
diff --git a/vllm_detector_adapter/generative_detectors/base.py b/vllm_detector_adapter/generative_detectors/base.py
index 60449f5..9a60717 100644
--- a/vllm_detector_adapter/generative_detectors/base.py
+++ b/vllm_detector_adapter/generative_detectors/base.py
@@ -13,12 +13,14 @@
 import torch
 
 # Local
+from vllm_detector_adapter.detector_dispatcher import detector_dispatcher
 from vllm_detector_adapter.logging import init_logger
 from vllm_detector_adapter.protocol import (
     ChatDetectionRequest,
     ContextAnalysisRequest,
     DetectionResponse,
 )
+from vllm_detector_adapter.utils import DetectorType
 
 logger = init_logger(__name__)
 
@@ -80,13 +82,17 @@ def apply_output_template(
 
     ##### Chat request processing functions ####################################
 
-    def apply_task_template_to_chat(
+    # Usage of detector_dispatcher allows the same function name to be called for different
+    # types of detectors with different arguments and implementations.
+    @detector_dispatcher(types=[DetectorType.TEXT_CHAT])
+    def apply_task_template(
         self, request: ChatDetectionRequest
     ) -> Union[ChatDetectionRequest, ErrorResponse]:
         """Apply task template on the chat request"""
         return request
 
-    def preprocess_chat_request(
+    @detector_dispatcher(types=[DetectorType.TEXT_CHAT])
+    def preprocess_request(
         self, request: ChatDetectionRequest
     ) -> Union[ChatDetectionRequest, ErrorResponse]:
         """Preprocess chat request"""
@@ -185,14 +191,14 @@ async def chat(
 
         # Apply task template if it exists
         if self.task_template:
-            request = self.apply_task_template_to_chat(request)
+            request = self.apply_task_template(request, fn_type=DetectorType.TEXT_CHAT)
             if isinstance(request, ErrorResponse):
                 # Propagate any request problems that will not allow
                 # task template to be applied
                 return request
 
         # Optionally make model-dependent adjustments for the request
-        request = self.preprocess_chat_request(request)
+        request = self.preprocess_request(request, fn_type=DetectorType.TEXT_CHAT)
 
         chat_completion_request = request.to_chat_completion_request(model_name)
         if isinstance(chat_completion_request, ErrorResponse):
diff --git a/vllm_detector_adapter/generative_detectors/granite_guardian.py b/vllm_detector_adapter/generative_detectors/granite_guardian.py
index 1f296e5..ff99584 100644
--- a/vllm_detector_adapter/generative_detectors/granite_guardian.py
+++ b/vllm_detector_adapter/generative_detectors/granite_guardian.py
@@ -8,6 +8,7 @@
 from vllm.entrypoints.openai.protocol import ChatCompletionRequest, ErrorResponse
 
 # Local
+from vllm_detector_adapter.detector_dispatcher import detector_dispatcher
 from vllm_detector_adapter.generative_detectors.base import ChatCompletionDetectionBase
 from vllm_detector_adapter.logging import init_logger
 from vllm_detector_adapter.protocol import (
@@ -15,6 +16,7 @@
     ContextAnalysisRequest,
     DetectionResponse,
 )
+from vllm_detector_adapter.utils import DetectorType
 
 logger = init_logger(__name__)
 
@@ -33,7 +35,9 @@ class GraniteGuardian(ChatCompletionDetectionBase):
     PROMPT_CONTEXT_ANALYSIS_RISKS = ["context_relevance"]
     RESPONSE_CONTEXT_ANALYSIS_RISKS = ["groundedness"]
 
-    def preprocess(
+    ##### Private / Internal functions ###################################################
+
+    def __preprocess(
         self, request: Union[ChatDetectionRequest, ContextAnalysisRequest]
     ) -> Union[ChatDetectionRequest, ContextAnalysisRequest, ErrorResponse]:
         """Granite guardian specific parameter updates for risk name and risk definition"""
@@ -59,13 +63,10 @@ def preprocess(
 
         return request
 
-    def preprocess_chat_request(
-        self, request: ChatDetectionRequest
-    ) -> Union[ChatDetectionRequest, ErrorResponse]:
-        """Granite guardian chat request preprocess is just detector parameter updates"""
-        return self.preprocess(request)
-
-    def request_to_chat_completion_request(
+    # Decorated so that future iterations of this function can cleanly support
+    # other types of detectors
+    @detector_dispatcher(types=[DetectorType.TEXT_CONTEXT_DOC])
+    def _request_to_chat_completion_request(
         self, request: ContextAnalysisRequest, model_name: str
     ) -> Union[ChatCompletionRequest, ErrorResponse]:
         NO_RISK_NAME_MESSAGE = "No risk_name for context analysis"
@@ -141,6 +142,17 @@ def request_to_chat_completion_request(
                 code=HTTPStatus.BAD_REQUEST.value,
             )
 
+    ##### General request / response processing functions ##################
+
+    # The detector_dispatcher decorator allows the same function name to be reused
+    # for different types of detectors with different request types etc.
+    @detector_dispatcher(types=[DetectorType.TEXT_CHAT])
+    def preprocess_request(
+        self, request: ChatDetectionRequest
+    ) -> Union[ChatDetectionRequest, ErrorResponse]:
+        """Granite guardian chat request preprocess is just detector parameter updates"""
+        return self.__preprocess(request)
+
     async def context_analyze(
         self,
         request: ContextAnalysisRequest,
@@ -152,13 +164,13 @@ async def context_analyze(
         # Task template not applied for context analysis at this time
 
         # Make model-dependent adjustments for the request
-        request = self.preprocess(request)
+        request = self.__preprocess(request)
 
         # Since particular chat messages are dependent on Granite Guardian risk definitions,
         # the processing is done here rather than in a separate, general to_chat_completion_request
         # for all context analysis requests.
-        chat_completion_request = self.request_to_chat_completion_request(
-            request, model_name
+        chat_completion_request = self._request_to_chat_completion_request(
+            request, model_name, fn_type=DetectorType.TEXT_CONTEXT_DOC
         )
         if isinstance(chat_completion_request, ErrorResponse):
             # Propagate any request problems
diff --git a/vllm_detector_adapter/utils.py b/vllm_detector_adapter/utils.py
new file mode 100644
index 0000000..c966531
--- /dev/null
+++ b/vllm_detector_adapter/utils.py
@@ -0,0 +1,11 @@
+# Standard
+from enum import Enum, auto
+
+
+class DetectorType(Enum):
+    """Enum to represent different types of detectors"""
+
+    TEXT_CONTENT = auto()
+    TEXT_GENERATION = auto()
+    TEXT_CHAT = auto()
+    TEXT_CONTEXT_DOC = auto()
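To show the dispatcher and enum working together end to end, here is a minimal standalone sketch (not part of the diff); `ExampleDetector` and its method bodies are hypothetical, for illustration only:

# Registration is keyed on __qualname__ + detector type, so the same method
# name can carry one implementation per DetectorType.
from vllm_detector_adapter.detector_dispatcher import detector_dispatcher
from vllm_detector_adapter.utils import DetectorType


class ExampleDetector:  # hypothetical class, for illustration only
    @detector_dispatcher(types=[DetectorType.TEXT_CHAT])
    def preprocess_request(self, request):
        return f"chat: {request}"

    # Re-using the same name for a different detector type is the point of
    # the dispatcher; this registers a second implementation under the same
    # qualified name.
    @detector_dispatcher(types=[DetectorType.TEXT_CONTEXT_DOC])
    def preprocess_request(self, request):
        return f"context: {request}"


detector = ExampleDetector()
# The caller selects the implementation with fn_type
print(detector.preprocess_request("hi", fn_type=DetectorType.TEXT_CHAT))         # chat: hi
print(detector.preprocess_request("hi", fn_type=DetectorType.TEXT_CONTEXT_DOC))  # context: hi

Omitting fn_type, or passing a type that was never registered for the name, raises ValueError, which is why the call sites in base.py and granite_guardian.py always pass fn_type explicitly.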