diff --git a/pyproject.toml b/pyproject.toml
index 5096592..4fbc520 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -15,7 +15,7 @@ dependencies = ["orjson>=3.10.16,<3.11"]
 [project.optional-dependencies]
 vllm-tgis-adapter = ["vllm-tgis-adapter>=0.7.0,<0.7.2"]
 vllm = [
-    # Note: 0.8.4 has a triton bug on Mac
+    "transformers<4.54.0", # vllm <= 0.10.0 has issues with higher transformers versions, fixed later in https://github.com/vllm-project/vllm/pull/20921
     "vllm @ git+https://github.com/vllm-project/vllm.git@v0.9.0 ; sys_platform == 'darwin'",
     "vllm>=0.7.2,<0.9.1 ; sys_platform != 'darwin'",
 ]
diff --git a/tests/generative_detectors/test_granite_guardian.py b/tests/generative_detectors/test_granite_guardian.py
index cb3e83a..4b9dbe3 100644
--- a/tests/generative_detectors/test_granite_guardian.py
+++ b/tests/generative_detectors/test_granite_guardian.py
@@ -423,6 +423,44 @@ def test_preprocess_chat_request_with_detector_params(granite_guardian_detection
     }
 
 
+def test_preprocess_chat_request_with_custom_criteria_detector_params(
+    granite_guardian_detection,
+):
+    # Guardian 3.3+ parameters
+    granite_guardian_detection_instance = asyncio.run(granite_guardian_detection)
+    detector_params = {
+        "custom_criteria": "Here is some custom criteria",
+        "custom_scoring_schema": "If text meets criteria say yes",
+        "foo": "bar",
+    }
+    initial_request = ChatDetectionRequest(
+        messages=[
+            DetectionChatMessageParam(
+                role="user", content="How do I figure out how to break into a house?"
+            )
+        ],
+        detector_params=detector_params,
+    )
+    processed_request = granite_guardian_detection_instance.preprocess_request(
+        initial_request, fn_type=DetectorType.TEXT_CHAT
+    )
+    assert type(processed_request) == ChatDetectionRequest
+    # Processed request should not have these extra params
+    assert "custom_criteria" not in processed_request.detector_params
+    assert "custom_scoring_schema" not in processed_request.detector_params
+    assert "chat_template_kwargs" in processed_request.detector_params
+    assert (
+        "guardian_config" in processed_request.detector_params["chat_template_kwargs"]
+    )
+    guardian_config = processed_request.detector_params["chat_template_kwargs"][
+        "guardian_config"
+    ]
+    assert guardian_config == {
+        "custom_criteria": "Here is some custom criteria",
+        "custom_scoring_schema": "If text meets criteria say yes",
+    }
+
+
 def test_preprocess_chat_request_with_extra_chat_template_kwargs(
     granite_guardian_detection,
 ):
@@ -534,6 +572,41 @@ def test_request_to_chat_completion_request_response_analysis(
     )
 
 
+def test_request_to_chat_completion_request_response_analysis_criteria_id(
+    granite_guardian_detection,
+):
+    # Guardian 3.3 parameters
+    granite_guardian_detection_instance = asyncio.run(granite_guardian_detection)
+    context_request = ContextAnalysisRequest(
+        content=CONTENT,
+        context_type="docs",
+        context=[CONTEXT_DOC],
+        detector_params={
+            "n": 3,
+            "chat_template_kwargs": {
+                "guardian_config": {"criteria_id": "groundedness"}
+            },
+        },
+    )
+    chat_request = (
+        granite_guardian_detection_instance._request_to_chat_completion_request(
+            context_request, MODEL_NAME, fn_type=DetectorType.TEXT_CONTEXT_DOC
+        )
+    )
+    assert type(chat_request) == ChatCompletionRequest
+    assert chat_request.messages[0]["role"] == "context"
+    assert chat_request.messages[0]["content"] == CONTEXT_DOC
+    assert chat_request.messages[1]["role"] == "assistant"
+    assert chat_request.messages[1]["content"] == CONTENT
+    assert chat_request.model == MODEL_NAME
+    # detector_params
+    assert chat_request.n == 3
+    assert (
+        chat_request.chat_template_kwargs["guardian_config"]["criteria_id"]
+        == "groundedness"
+    )
+
+
 def test_request_to_chat_completion_request_empty_kwargs(granite_guardian_detection):
     granite_guardian_detection_instance = asyncio.run(granite_guardian_detection)
     context_request = ContextAnalysisRequest(
@@ -549,7 +622,7 @@ def test_request_to_chat_completion_request_empty_kwargs(granite_guardian_detect
     )
     assert type(chat_request) == ErrorResponse
     assert chat_request.code == HTTPStatus.BAD_REQUEST
-    assert "No risk_name for context analysis" in chat_request.message
+    assert "No risk_name or criteria_id for context analysis" in chat_request.message
 
 
 def test_request_to_chat_completion_request_empty_guardian_config(
@@ -569,10 +642,10 @@ def test_request_to_chat_completion_request_empty_guardian_config(
     )
     assert type(chat_request) == ErrorResponse
     assert chat_request.code == HTTPStatus.BAD_REQUEST
-    assert "No risk_name for context analysis" in chat_request.message
+    assert "No risk_name or criteria_id for context analysis" in chat_request.message
 
 
-def test_request_to_chat_completion_request_missing_risk_name(
+def test_request_to_chat_completion_request_missing_risk_name_and_criteria_id(
     granite_guardian_detection,
 ):
     granite_guardian_detection_instance = asyncio.run(granite_guardian_detection)
@@ -592,7 +665,7 @@ def test_request_to_chat_completion_request_missing_risk_name(
     )
     assert type(chat_request) == ErrorResponse
     assert chat_request.code == HTTPStatus.BAD_REQUEST
-    assert "No risk_name for context analysis" in chat_request.message
+    assert "No risk_name or criteria_id for context analysis" in chat_request.message
 
 
 def test_request_to_chat_completion_request_unsupported_risk_name(
@@ -616,7 +689,8 @@
     assert type(chat_request) == ErrorResponse
     assert chat_request.code == HTTPStatus.BAD_REQUEST
     assert (
-        "risk_name foo is not compatible with context analysis" in chat_request.message
+        "risk_name or criteria_id foo is not compatible with context analysis"
+        in chat_request.message
     )
 
 
@@ -816,7 +890,7 @@ def test_context_analyze_unsupported_risk(
     assert type(detection_response) == ErrorResponse
     assert detection_response.code == HTTPStatus.BAD_REQUEST
     assert (
-        "risk_name boo is not compatible with context analysis"
+        "risk_name or criteria_id boo is not compatible with context analysis"
         in detection_response.message
     )
 
@@ -970,6 +1044,34 @@ def test_chat_detection_with_tools(
     assert len(detections) == 2  # 2 choices
 
 
+def test_chat_detection_with_tools_criteria_id(
+    granite_guardian_detection, granite_guardian_completion_response
+):
+    # Guardian 3.3 parameters
+    granite_guardian_detection_instance = asyncio.run(granite_guardian_detection)
+    chat_request = ChatDetectionRequest(
+        messages=[
+            DetectionChatMessageParam(
+                role="user",
+                content=USER_CONTENT_TOOLS,
+            ),
+            DetectionChatMessageParam(role="assistant", tool_calls=[TOOL_CALL]),
+        ],
+        tools=[TOOL],
+        detector_params={"criteria_id": "function_call", "n": 2},
+    )
+    with patch(
+        "vllm_detector_adapter.generative_detectors.granite_guardian.GraniteGuardian.create_chat_completion",
+        return_value=granite_guardian_completion_response,
+    ):
+        detection_response = asyncio.run(
+            granite_guardian_detection_instance.chat(chat_request)
+        )
+    assert type(detection_response) == DetectionResponse
+    detections = detection_response.model_dump()
+    assert len(detections) == 2  # 2 choices
+
+
 def test_chat_detection_with_tools_wrong_risk(
     granite_guardian_detection, granite_guardian_completion_response
 ):
diff --git a/vllm_detector_adapter/generative_detectors/granite_guardian.py b/vllm_detector_adapter/generative_detectors/granite_guardian.py
index 3e998da..a27b7c8 100644
--- a/vllm_detector_adapter/generative_detectors/granite_guardian.py
+++ b/vllm_detector_adapter/generative_detectors/granite_guardian.py
@@ -40,7 +40,9 @@ class GraniteGuardianToolCallFunctionObject(TypedDict):
 
 
 class GraniteGuardian(ChatCompletionDetectionBase):
-
+    # Note: Earlier generations of Granite Guardian use 'risk' while Granite Guardian
+    # 3.3 refers to 'criteria' for generalization. For now, because the taxonomy is
+    # still characterized as a 'risk' taxonomy, the detection type remains.
     DETECTION_TYPE = "risk"
     # User text pattern in task template
     USER_TEXT_PATTERN = "user_text"
@@ -62,6 +64,7 @@ class GraniteGuardian(ChatCompletionDetectionBase):
     INDENT = orjson.OPT_INDENT_2
 
     # Risk Bank name defined in the chat template
+    # Not actively used ref. https://github.com/foundation-model-stack/vllm-detector-adapter/issues/64
     RISK_BANK_VAR_NAME = "risk_bank"
 
     # Attributes to be put in metadata
@@ -87,7 +90,7 @@ def __preprocess(
         GenerationDetectionRequest,
         ErrorResponse,
     ]:
-        """Granite guardian specific parameter updates for risk name and risk definition"""
+        """Granite guardian specific parameter updates for risks/criteria"""
         # Validation that one of the 'defined' risks is requested will be
         # done through the chat template on each request. Errors will
         # be propagated for chat completion separately
@@ -95,13 +98,23 @@ def __preprocess(
         if not request.detector_params:
             return request
 
+        # Guardian 3.2 and earlier
         if risk_name := request.detector_params.pop("risk_name", None):
             guardian_config["risk_name"] = risk_name
         if risk_definition := request.detector_params.pop("risk_definition", None):
             guardian_config["risk_definition"] = risk_definition
+        # Guardian 3.3+
+        if criteria_id := request.detector_params.pop("criteria_id", None):
+            guardian_config["criteria_id"] = criteria_id
+        if custom_criteria := request.detector_params.pop("custom_criteria", None):
+            guardian_config["custom_criteria"] = custom_criteria
+        if custom_scoring_schema := request.detector_params.pop(
+            "custom_scoring_schema", None
+        ):
+            guardian_config["custom_scoring_schema"] = custom_scoring_schema
         if guardian_config:
             logger.debug("guardian_config {} provided for request", guardian_config)
-            # Move the risk name and/or risk definition to chat_template_kwargs
+            # Move the parameters to chat_template_kwargs
             # to be propagated to tokenizer.apply_chat_template during
             # chat completion
             if "chat_template_kwargs" in request.detector_params:
@@ -134,15 +147,33 @@ def _make_tools_request(
 
         if (
             "risk_name" not in request.detector_params
-            or request.detector_params["risk_name"] not in self.TOOLS_RISKS
+            and "criteria_id" not in request.detector_params
+        ):
+            return ErrorResponse(
+                message="tools analysis is not supported without a given risk/criteria",
+                type="BadRequestError",
+                code=HTTPStatus.BAD_REQUEST.value,
+            )
+        # Granite 3.2 and earlier
+        if (
+            "risk_name" in request.detector_params
+            and request.detector_params["risk_name"] not in self.TOOLS_RISKS
         ):
-            # Provide error here, since otherwise follow-on tools message
-            # and assistant message flattening will not be applicable
             return ErrorResponse(
                 message="tools analysis is not supported with given risk",
                 type="BadRequestError",
                 code=HTTPStatus.BAD_REQUEST.value,
             )
+        # Granite 3.3+
+        elif (
+            "criteria_id" in request.detector_params
+            and request.detector_params["criteria_id"] not in self.TOOLS_RISKS
+        ):
+            return ErrorResponse(
+                message="tools analysis is not supported with given criteria",
+                type="BadRequestError",
+                code=HTTPStatus.BAD_REQUEST.value,
+            )
 
         # (1) 'Flatten' the assistant message, extracting the functions in the tool_calls
         # portion of the message
@@ -242,7 +273,7 @@ def _make_tools_request(
     def _request_to_chat_completion_request(
         self, request: ContextAnalysisRequest, model_name: str
     ) -> Union[ChatCompletionRequest, ErrorResponse]:
-        NO_RISK_NAME_MESSAGE = "No risk_name for context analysis"
+        NO_RISK_NAME_MESSAGE = "No risk_name or criteria_id for context analysis"
 
         risk_name = None
         if (
@@ -259,8 +290,10 @@ def _request_to_chat_completion_request(
                 "guardian_config"
             ]:
                 if isinstance(guardian_config, dict):
-                    risk_name = guardian_config.get("risk_name")
-                # Leaving off risk name can lead to model/template errors
+                    risk_name = guardian_config.get("risk_name") or guardian_config.get(
+                        "criteria_id"
+                    )
+                # Leaving off risk_name and criteria_id can lead to model/template errors
                 if not risk_name:
                     return ErrorResponse(
                         message=NO_RISK_NAME_MESSAGE,
@@ -292,9 +325,9 @@ def _request_to_chat_completion_request(
                 {"role": "user", "content": content},
             ]
         else:
-            # Return error if risk names are not expected ones
+            # Return error if risk names or criteria are not expected ones
             return ErrorResponse(
-                message="risk_name {} is not compatible with context analysis".format(
+                message="risk_name or criteria_id {} is not compatible with context analysis".format(
                     risk_name
                 ),
                 type="BadRequestError",
@@ -448,10 +481,18 @@ async def generation_analyze(
 
         # If risk_name is not specifically provided for this endpoint, we will add a
        # risk_name, since the user has already decided to use this particular endpoint
+        # Granite Guardian 3.2 and earlier
         if "risk_name" not in request.detector_params:
             request.detector_params[
                 "risk_name"
             ] = self.DEFAULT_GENERATION_DETECTION_RISK
+        # Granite Guardian 3.3+
+        # Generally the additional/repeated risk is not problematic
+        # This avoids having to verify Guardian version at this step
+        if "criteria_id" not in request.detector_params:
+            request.detector_params[
+                "criteria_id"
+            ] = self.DEFAULT_GENERATION_DETECTION_RISK
         # Task template not applied for generation analysis at this time
 
         # Make model-dependent adjustments for the request
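
Illustrative sketch (not part of the patch above): the __preprocess change moves the Guardian 3.3 parameters (criteria_id, custom_criteria, custom_scoring_schema), along with the older risk_name/risk_definition, out of detector_params and nests them under chat_template_kwargs["guardian_config"]. The self-contained Python below mirrors that relocation with plain dicts so the behavior can be tried outside the adapter; the helper name relocate_guardian_params is hypothetical, only the parameter keys come from the diff, and the merge behavior when chat_template_kwargs already exists is an approximation rather than the adapter's actual code.

# Hypothetical sketch of the detector_params -> chat_template_kwargs relocation.
GUARDIAN_KEYS = (
    "risk_name",  # Guardian 3.2 and earlier
    "risk_definition",  # Guardian 3.2 and earlier
    "criteria_id",  # Guardian 3.3+
    "custom_criteria",  # Guardian 3.3+
    "custom_scoring_schema",  # Guardian 3.3+
)


def relocate_guardian_params(detector_params: dict) -> dict:
    """Move Guardian config keys under chat_template_kwargs['guardian_config']."""
    guardian_config = {}
    for key in GUARDIAN_KEYS:
        # Mirror the walrus-style pops in the diff: falsy values are treated as absent
        if value := detector_params.pop(key, None):
            guardian_config[key] = value
    if guardian_config:
        # Merge into any caller-provided chat_template_kwargs instead of replacing it
        detector_params.setdefault("chat_template_kwargs", {})[
            "guardian_config"
        ] = guardian_config
    return detector_params


print(relocate_guardian_params({"criteria_id": "function_call", "n": 2}))
# -> {'n': 2, 'chat_template_kwargs': {'guardian_config': {'criteria_id': 'function_call'}}}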