Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ dependencies = ["orjson>=3.10.16,<3.11"]
[project.optional-dependencies]
vllm-tgis-adapter = ["vllm-tgis-adapter>=0.7.0,<0.7.2"]
vllm = [
# Note: 0.8.4 has a triton bug on Mac
"transformers<4.54.0", # vllm <= 0.10.0 has issues with higher transformers versions, fixed later in https://github.com/vllm-project/vllm/pull/20921
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This fix isn't released yet, so with the latest transformers versions tests will fail with `ValueError: 'aimv2' is already used by a Transformers config, pick another name.`

"vllm @ git+https://github.com/vllm-project/[email protected] ; sys_platform == 'darwin'",
"vllm>=0.7.2,<0.9.1 ; sys_platform != 'darwin'",
]
Expand Down
114 changes: 108 additions & 6 deletions tests/generative_detectors/test_granite_guardian.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,44 @@ def test_preprocess_chat_request_with_detector_params(granite_guardian_detection
}


def test_preprocess_chat_request_with_custom_criteria_detector_params(
    granite_guardian_detection,
):
    """Guardian 3.3+ custom_criteria / custom_scoring_schema detector params
    should be popped out of detector_params and relocated under
    chat_template_kwargs["guardian_config"]; unrelated extras ("foo") are
    left in place for downstream validation.
    """
    granite_guardian_detection_instance = asyncio.run(granite_guardian_detection)
    detector_params = {
        "custom_criteria": "Here is some custom criteria",
        "custom_scoring_schema": "If text meets criteria say yes",
        "foo": "bar",
    }
    initial_request = ChatDetectionRequest(
        messages=[
            DetectionChatMessageParam(
                role="user", content="How do I figure out how to break into a house?"
            )
        ],
        detector_params=detector_params,
    )
    processed_request = granite_guardian_detection_instance.preprocess_request(
        initial_request, fn_type=DetectorType.TEXT_CHAT
    )
    # isinstance (not `type(x) ==`) is the idiomatic type assertion
    assert isinstance(processed_request, ChatDetectionRequest)
    # Processed request should not have these extra params at the top level...
    assert "custom_criteria" not in processed_request.detector_params
    assert "custom_scoring_schema" not in processed_request.detector_params
    # ...they should have been moved under chat_template_kwargs["guardian_config"]
    assert "chat_template_kwargs" in processed_request.detector_params
    assert (
        "guardian_config" in processed_request.detector_params["chat_template_kwargs"]
    )
    guardian_config = processed_request.detector_params["chat_template_kwargs"][
        "guardian_config"
    ]
    assert guardian_config == {
        "custom_criteria": "Here is some custom criteria",
        "custom_scoring_schema": "If text meets criteria say yes",
    }


def test_preprocess_chat_request_with_extra_chat_template_kwargs(
granite_guardian_detection,
):
Expand Down Expand Up @@ -534,6 +572,41 @@ def test_request_to_chat_completion_request_response_analysis(
)


def test_request_to_chat_completion_request_response_analysis_criteria_id(
    granite_guardian_detection,
):
    """Context analysis with a Guardian 3.3 criteria_id should build a chat
    completion request whose messages are the context doc (role "context")
    followed by the analyzed content (role "assistant"), with remaining
    detector params passed through.
    """
    granite_guardian_detection_instance = asyncio.run(granite_guardian_detection)
    context_request = ContextAnalysisRequest(
        content=CONTENT,
        context_type="docs",
        context=[CONTEXT_DOC],
        detector_params={
            "n": 3,
            "chat_template_kwargs": {
                "guardian_config": {"criteria_id": "groundedness"}
            },
        },
    )
    chat_request = (
        granite_guardian_detection_instance._request_to_chat_completion_request(
            context_request, MODEL_NAME, fn_type=DetectorType.TEXT_CONTEXT_DOC
        )
    )
    # isinstance (not `type(x) ==`) is the idiomatic type assertion
    assert isinstance(chat_request, ChatCompletionRequest)
    assert chat_request.messages[0]["role"] == "context"
    assert chat_request.messages[0]["content"] == CONTEXT_DOC
    assert chat_request.messages[1]["role"] == "assistant"
    assert chat_request.messages[1]["content"] == CONTENT
    assert chat_request.model == MODEL_NAME
    # Remaining detector_params are carried over to the chat completion request
    assert chat_request.n == 3
    assert (
        chat_request.chat_template_kwargs["guardian_config"]["criteria_id"]
        == "groundedness"
    )


def test_request_to_chat_completion_request_empty_kwargs(granite_guardian_detection):
granite_guardian_detection_instance = asyncio.run(granite_guardian_detection)
context_request = ContextAnalysisRequest(
Expand All @@ -549,7 +622,7 @@ def test_request_to_chat_completion_request_empty_kwargs(granite_guardian_detect
)
assert type(chat_request) == ErrorResponse
assert chat_request.code == HTTPStatus.BAD_REQUEST
assert "No risk_name for context analysis" in chat_request.message
assert "No risk_name or criteria_id for context analysis" in chat_request.message


def test_request_to_chat_completion_request_empty_guardian_config(
Expand All @@ -569,10 +642,10 @@ def test_request_to_chat_completion_request_empty_guardian_config(
)
assert type(chat_request) == ErrorResponse
assert chat_request.code == HTTPStatus.BAD_REQUEST
assert "No risk_name for context analysis" in chat_request.message
assert "No risk_name or criteria_id for context analysis" in chat_request.message


def test_request_to_chat_completion_request_missing_risk_name(
def test_request_to_chat_completion_request_missing_risk_name_and_criteria_id(
granite_guardian_detection,
):
granite_guardian_detection_instance = asyncio.run(granite_guardian_detection)
Expand All @@ -592,7 +665,7 @@ def test_request_to_chat_completion_request_missing_risk_name(
)
assert type(chat_request) == ErrorResponse
assert chat_request.code == HTTPStatus.BAD_REQUEST
assert "No risk_name for context analysis" in chat_request.message
assert "No risk_name or criteria_id for context analysis" in chat_request.message


def test_request_to_chat_completion_request_unsupported_risk_name(
Expand All @@ -616,7 +689,8 @@ def test_request_to_chat_completion_request_unsupported_risk_name(
assert type(chat_request) == ErrorResponse
assert chat_request.code == HTTPStatus.BAD_REQUEST
assert (
"risk_name foo is not compatible with context analysis" in chat_request.message
"risk_name or criteria_id foo is not compatible with context analysis"
in chat_request.message
)


Expand Down Expand Up @@ -816,7 +890,7 @@ def test_context_analyze_unsupported_risk(
assert type(detection_response) == ErrorResponse
assert detection_response.code == HTTPStatus.BAD_REQUEST
assert (
"risk_name boo is not compatible with context analysis"
"risk_name or criteria_id boo is not compatible with context analysis"
in detection_response.message
)

Expand Down Expand Up @@ -970,6 +1044,34 @@ def test_chat_detection_with_tools(
assert len(detections) == 2 # 2 choices


def test_chat_detection_with_tools_criteria_id(
    granite_guardian_detection, granite_guardian_completion_response
):
    """Tool-call chat detection should succeed when the Guardian 3.3+
    criteria_id parameter is used instead of the legacy risk_name.
    """
    granite_guardian_detection_instance = asyncio.run(granite_guardian_detection)
    chat_request = ChatDetectionRequest(
        messages=[
            DetectionChatMessageParam(
                role="user",
                content=USER_CONTENT_TOOLS,
            ),
            DetectionChatMessageParam(role="assistant", tool_calls=[TOOL_CALL]),
        ],
        tools=[TOOL],
        detector_params={"criteria_id": "function_call", "n": 2},
    )
    # Stub out the underlying chat completion so only the request
    # preprocessing and response post-processing paths are exercised
    with patch(
        "vllm_detector_adapter.generative_detectors.granite_guardian.GraniteGuardian.create_chat_completion",
        return_value=granite_guardian_completion_response,
    ):
        detection_response = asyncio.run(
            granite_guardian_detection_instance.chat(chat_request)
        )
    # isinstance (not `type(x) ==`) is the idiomatic type assertion
    assert isinstance(detection_response, DetectionResponse)
    detections = detection_response.model_dump()
    assert len(detections) == 2  # 2 choices


def test_chat_detection_with_tools_wrong_risk(
granite_guardian_detection, granite_guardian_completion_response
):
Expand Down
63 changes: 52 additions & 11 deletions vllm_detector_adapter/generative_detectors/granite_guardian.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ class GraniteGuardianToolCallFunctionObject(TypedDict):


class GraniteGuardian(ChatCompletionDetectionBase):

# Note: Earlier generations of Granite Guardian use 'risk' while Granite Guardian
# 3.3 refers to 'criteria' for generalization. For now, because the taxonomy is
# still characterized as a 'risk' taxonomy, the detection type remains.
DETECTION_TYPE = "risk"
# User text pattern in task template
USER_TEXT_PATTERN = "user_text"
Expand All @@ -62,6 +64,7 @@ class GraniteGuardian(ChatCompletionDetectionBase):
INDENT = orjson.OPT_INDENT_2

# Risk Bank name defined in the chat template
# Not actively used ref. https://github.com/foundation-model-stack/vllm-detector-adapter/issues/64
RISK_BANK_VAR_NAME = "risk_bank"

# Attributes to be put in metadata
Expand All @@ -87,21 +90,31 @@ def __preprocess(
GenerationDetectionRequest,
ErrorResponse,
]:
"""Granite guardian specific parameter updates for risk name and risk definition"""
"""Granite guardian specific parameter updates for risks/criteria"""
# Validation that one of the 'defined' risks is requested will be
# done through the chat template on each request. Errors will
# be propagated for chat completion separately
guardian_config = {}
if not request.detector_params:
return request

# Guardian 3.2 and earlier
if risk_name := request.detector_params.pop("risk_name", None):
guardian_config["risk_name"] = risk_name
if risk_definition := request.detector_params.pop("risk_definition", None):
guardian_config["risk_definition"] = risk_definition
# Guardian 3.3+
if criteria_id := request.detector_params.pop("criteria_id", None):
guardian_config["criteria_id"] = criteria_id
if custom_criteria := request.detector_params.pop("custom_criteria", None):
guardian_config["custom_criteria"] = custom_criteria
if custom_scoring_schema := request.detector_params.pop(
"custom_scoring_schema", None
):
guardian_config["custom_scoring_schema"] = custom_scoring_schema
if guardian_config:
logger.debug("guardian_config {} provided for request", guardian_config)
# Move the risk name and/or risk definition to chat_template_kwargs
# Move the parameters to chat_template_kwargs
# to be propagated to tokenizer.apply_chat_template during
# chat completion
if "chat_template_kwargs" in request.detector_params:
Expand Down Expand Up @@ -134,15 +147,33 @@ def _make_tools_request(

if (
"risk_name" not in request.detector_params
or request.detector_params["risk_name"] not in self.TOOLS_RISKS
and "criteria_id" not in request.detector_params
):
return ErrorResponse(
message="tools analysis is not supported without a given risk/criteria",
type="BadRequestError",
code=HTTPStatus.BAD_REQUEST.value,
)
# Granite 3.2 and earlier
if (
"risk_name" in request.detector_params
and request.detector_params["risk_name"] not in self.TOOLS_RISKS
):
# Provide error here, since otherwise follow-on tools message
# and assistant message flattening will not be applicable
return ErrorResponse(
message="tools analysis is not supported with given risk",
type="BadRequestError",
code=HTTPStatus.BAD_REQUEST.value,
)
# Granite 3.3+
elif (
"criteria_id" in request.detector_params
and request.detector_params["criteria_id"] not in self.TOOLS_RISKS
):
return ErrorResponse(
message="tools analysis is not supported with given criteria",
type="BadRequestError",
code=HTTPStatus.BAD_REQUEST.value,
)

# (1) 'Flatten' the assistant message, extracting the functions in the tool_calls
# portion of the message
Expand Down Expand Up @@ -242,7 +273,7 @@ def _make_tools_request(
def _request_to_chat_completion_request(
self, request: ContextAnalysisRequest, model_name: str
) -> Union[ChatCompletionRequest, ErrorResponse]:
NO_RISK_NAME_MESSAGE = "No risk_name for context analysis"
NO_RISK_NAME_MESSAGE = "No risk_name or criteria_id for context analysis"

risk_name = None
if (
Expand All @@ -259,8 +290,10 @@ def _request_to_chat_completion_request(
"guardian_config"
]:
if isinstance(guardian_config, dict):
risk_name = guardian_config.get("risk_name")
# Leaving off risk name can lead to model/template errors
risk_name = guardian_config.get("risk_name") or guardian_config.get(
"criteria_id"
)
# Leaving off risk_name and criteria_id can lead to model/template errors
if not risk_name:
return ErrorResponse(
message=NO_RISK_NAME_MESSAGE,
Expand Down Expand Up @@ -292,9 +325,9 @@ def _request_to_chat_completion_request(
{"role": "user", "content": content},
]
else:
# Return error if risk names are not expected ones
# Return error if risk names or criteria are not expected ones
return ErrorResponse(
message="risk_name {} is not compatible with context analysis".format(
message="risk_name or criteria_id {} is not compatible with context analysis".format(
risk_name
),
type="BadRequestError",
Expand Down Expand Up @@ -448,10 +481,18 @@ async def generation_analyze(

# If risk_name is not specifically provided for this endpoint, we will add a
# risk_name, since the user has already decided to use this particular endpoint
# Granite Guardian 3.2 and earlier
if "risk_name" not in request.detector_params:
request.detector_params[
"risk_name"
] = self.DEFAULT_GENERATION_DETECTION_RISK
# Granite Guardian 3.3+
# Generally the additional/repeated risk is not problematic
# This avoids having to verify Guardian version at this step
if "criteria_id" not in request.detector_params:
request.detector_params[
"criteria_id"
] = self.DEFAULT_GENERATION_DETECTION_RISK

# Task template not applied for generation analysis at this time
# Make model-dependent adjustments for the request
Expand Down