-
Notifications
You must be signed in to change notification settings - Fork 7
Detector dispatcher decorator #18
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
feb64d7
30fcacf
4181b80
c104779
d4ae90e
2dcffef
13f8a4c
820374a
ae50dd8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -13,7 +13,8 @@ classifiers = [ | |
| ] | ||
|
|
||
| dependencies = [ | ||
| "vllm>=0.7.0" | ||
| "vllm @ git+https://github.com/vllm-project/[email protected] ; sys_platform == 'darwin'", | ||
| "vllm>=0.7.1 ; sys_platform != 'darwin'", | ||
| ] | ||
|
|
||
| [project.optional-dependencies] | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,81 @@ | ||
# Standard
import functools

# Global registry of all decorated functions, keyed by qualified name.
# Shape: {"fn_qualname": {type1: function, type2: function, ...}}
# NOTE: the inner values are the undecorated function objects.
global_fn_list = dict()


def detector_dispatcher(types=None):
    """Decorator to dispatch to processing function based on type of the detector.

    This decorator allows us to reuse the same function name for different types of
    detectors. For example, the same function name can be used for text chat and
    context analysis detectors. The decorated functions for these detectors will have
    different arguments and implementations, but they share the same function name.

    NOTE: At the time of invoking these decorated functions, the user needs to specify
    the type of the detector using the fn_type keyword argument.

    CAUTION: Since this decorator allows re-use of the name, one must take care to use
    different types for different functions sharing a name.

    Args:
        types (list): Types of the detector this function applies to.

    Raises:
        ValueError: If types is empty/None, or if a function with the same qualified
            name is already registered under any of the given types.

    Examples
    --------

        @detector_dispatcher(types=["foo"])
        def f(x):
            pass

        # Decorator can take multiple types as well
        @detector_dispatcher(types=["bar", "baz"])
        def f(x):
            pass

    When calling these functions, one can specify the type of the detector as follows:
        f(x, fn_type="foo")
        f(x, fn_type="bar")
    """
    if not types:
        raise ValueError("Must specify types.")

    def decorator(func):
        fn_name = func.__qualname__

        # Fetch (or create) this name's registry of type -> function.
        # (No `global` statement needed: the dict is mutated, never rebound.)
        registered = global_fn_list.setdefault(fn_name, {})

        # Error out if a function with the same qualified name is already
        # registered under any of the requested types. set(types) makes the
        # intersection explicit rather than relying on dict-view set ops
        # accepting an arbitrary iterable on the left-hand side.
        if set(types) & registered.keys():
            raise ValueError("Function already registered with the same types.")

        # Register this implementation under each of its declared types.
        registered.update({t: func for t in types})

        @functools.wraps(func)
        def wrapper(*args, fn_type=None, **kwargs):
            # The caller must say which registered variant to run.
            if not fn_type:
                raise ValueError("Must specify fn_type.")

            if fn_type not in global_fn_list[fn_name]:
                raise ValueError("Invalid fn_type.")

            # Grab the function using its fully qualified name and the
            # specified type, and then call it.
            return global_fn_list[fn_name][fn_type](*args, **kwargs)

        return wrapper

    return decorator
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -8,13 +8,15 @@ | |
| from vllm.entrypoints.openai.protocol import ChatCompletionRequest, ErrorResponse | ||
|
|
||
| # Local | ||
| from vllm_detector_adapter.detector_dispatcher import detector_dispatcher | ||
| from vllm_detector_adapter.generative_detectors.base import ChatCompletionDetectionBase | ||
| from vllm_detector_adapter.logging import init_logger | ||
| from vllm_detector_adapter.protocol import ( | ||
| ChatDetectionRequest, | ||
| ContextAnalysisRequest, | ||
| DetectionResponse, | ||
| ) | ||
| from vllm_detector_adapter.utils import DetectorType | ||
|
|
||
| logger = init_logger(__name__) | ||
|
|
||
|
|
@@ -33,7 +35,9 @@ class GraniteGuardian(ChatCompletionDetectionBase): | |
| PROMPT_CONTEXT_ANALYSIS_RISKS = ["context_relevance"] | ||
| RESPONSE_CONTEXT_ANALYSIS_RISKS = ["groundedness"] | ||
|
|
||
| def preprocess( | ||
| ##### Private / Internal functions ################################################### | ||
|
|
||
| def __preprocess( | ||
| self, request: Union[ChatDetectionRequest, ContextAnalysisRequest] | ||
| ) -> Union[ChatDetectionRequest, ContextAnalysisRequest, ErrorResponse]: | ||
| """Granite guardian specific parameter updates for risk name and risk definition""" | ||
|
|
@@ -59,13 +63,10 @@ def preprocess( | |
|
|
||
| return request | ||
|
|
||
| def preprocess_chat_request( | ||
| self, request: ChatDetectionRequest | ||
| ) -> Union[ChatDetectionRequest, ErrorResponse]: | ||
| """Granite guardian chat request preprocess is just detector parameter updates""" | ||
| return self.preprocess(request) | ||
|
|
||
| def request_to_chat_completion_request( | ||
| # Decorating this function to make it cleaner for future iterations of this function | ||
| # to support other types of detectors | ||
| @detector_dispatcher(types=[DetectorType.TEXT_CONTEXT_DOC]) | ||
| def _request_to_chat_completion_request( | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: if we're keeping this "private" since this only applies to granite guardian, maybe this should go in the "Private functions" section above?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I kept these "internal", since there were unit tests for this particular function, so testing them after making them private would get interesting.. But might still make sense to move them to private block and change name of that block
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ah I misread the number of underscores, it might be good to have an "internal functions" block then |
||
| self, request: ContextAnalysisRequest, model_name: str | ||
| ) -> Union[ChatCompletionRequest, ErrorResponse]: | ||
| NO_RISK_NAME_MESSAGE = "No risk_name for context analysis" | ||
|
|
@@ -141,6 +142,17 @@ def request_to_chat_completion_request( | |
| code=HTTPStatus.BAD_REQUEST.value, | ||
| ) | ||
|
|
||
| ##### General request / response processing functions ################## | ||
|
|
||
    # Used detector_dispatcher decorator to allow for the same function to be called
    # for different types of detectors with different request types etc.
    @detector_dispatcher(types=[DetectorType.TEXT_CHAT])
    def preprocess_request(
        self, request: ChatDetectionRequest
    ) -> Union[ChatDetectionRequest, ErrorResponse]:
        """Granite guardian chat request preprocess is just detector parameter updates"""
        # Delegates to the private __preprocess, which applies Granite Guardian
        # specific parameter updates (risk name / risk definition).
        # Callers must pass fn_type=DetectorType.TEXT_CHAT to the dispatcher wrapper.
        return self.__preprocess(request)
|
|
||
| async def context_analyze( | ||
| self, | ||
| request: ContextAnalysisRequest, | ||
|
|
@@ -152,13 +164,13 @@ async def context_analyze( | |
|
|
||
| # Task template not applied for context analysis at this time | ||
| # Make model-dependent adjustments for the request | ||
| request = self.preprocess(request) | ||
| request = self.__preprocess(request) | ||
|
|
||
| # Since particular chat messages are dependent on Granite Guardian risk definitions, | ||
| # the processing is done here rather than in a separate, general to_chat_completion_request | ||
| # for all context analysis requests. | ||
| chat_completion_request = self.request_to_chat_completion_request( | ||
| request, model_name | ||
| chat_completion_request = self._request_to_chat_completion_request( | ||
| request, model_name, fn_type=DetectorType.TEXT_CONTEXT_DOC | ||
| ) | ||
| if isinstance(chat_completion_request, ErrorResponse): | ||
| # Propagate any request problems | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,11 @@ | ||
# Standard
from enum import Enum, auto


class DetectorType(Enum):
    """Enum to represent different types of detectors.

    Used as the `types`/`fn_type` key for the detector_dispatcher decorator
    to select which registered implementation of a shared function name runs.
    """

    TEXT_CONTENT = auto()
    TEXT_GENERATION = auto()
    # Used to dispatch chat detection request preprocessing
    TEXT_CHAT = auto()
    # Used to dispatch context analysis request -> chat completion conversion
    TEXT_CONTEXT_DOC = auto()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think this one strictly needed to change since it should be backwards compatible, but likely not an issue for current usage