
Commit 519fc15

✨ Support think parameter for Granite Guardian 3.3 (#101)
* ✨ Support Granite Guardian 3.3 think param
* ✅ Tests with think result
* 🎨 Formatting

Signed-off-by: Evaline Ju <[email protected]>
1 parent 16c7212 commit 519fc15
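
For orientation before the diff: the change lets callers pass a "think" flag through detector_params, which the adapter then forwards to the model's chat template. A hypothetical parameter dict for illustration (only "think" is introduced by this commit; "risk_name" is an existing Guardian parameter visible in the diff below, and the overall payload shape is an assumption, not taken from this page):

# Hypothetical detector_params for a detection request (illustrative only).
detector_params = {
    "risk_name": "harm",  # existing Granite Guardian config parameter
    "think": True,        # new in this commit: request <think> reasoning traces
}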

File tree

2 files changed: +148 -0 lines changed

tests/generative_detectors/test_granite_guardian.py

Lines changed: 140 additions & 0 deletions
@@ -292,6 +292,60 @@ def granite_guardian_completion_response_3_3_plus_no_think():
    )


+@pytest.fixture(scope="function")
+def granite_guardian_completion_response_3_3_plus_think():
+    """Granite Guardian 3.3+ response with think/trace output"""
+    log_probs_content_yes = ChatCompletionLogProbsContent(
+        token=" yes",
+        logprob=0.00,
+        # 5 logprobs requested for scoring, skipping bytes for conciseness
+        top_logprobs=[
+            ChatCompletionLogProb(token=" yes", logprob=0.00),
+            ChatCompletionLogProb(token="yes", logprob=-14.00),
+            ChatCompletionLogProb(token=" '", logprob=-14.43),
+            ChatCompletionLogProb(token=" indeed", logprob=-14.93),
+            ChatCompletionLogProb(token=" Yes", logprob=-15.56),
+        ],
+    )
+    log_probs_content_random = ChatCompletionLogProbsContent(
+        token="<",
+        logprob=0.00,
+        # 5 logprobs requested for scoring, skipping bytes for conciseness
+        top_logprobs=[
+            ChatCompletionLogProb(token="<", logprob=0.00),
+            ChatCompletionLogProb(token="First", logprob=-12.12),
+            ChatCompletionLogProb(token="The", logprob=-13.63),
+            ChatCompletionLogProb(token=" <", logprob=-13.87),
+            ChatCompletionLogProb(token=" First", logprob=-16.43),
+        ],
+    )
+    choice_0 = ChatCompletionResponseChoice(
+        index=0,
+        message=ChatMessage(
+            role="assistant",
+            content="<think> First analyzing the user request...there is a risk associated, so the score is yes.\n</think>\n<score> yes </score>",
+        ),
+        logprobs=ChatCompletionLogProbs(
+            content=[log_probs_content_yes, log_probs_content_random]
+        ),
+    )
+    choice_1 = ChatCompletionResponseChoice(
+        index=1,
+        message=ChatMessage(
+            role="assistant",
+            content="<think> Since there is no risk associated, the score is no.\n</think>\n<score> no </score>",
+        ),
+        logprobs=ChatCompletionLogProbs(
+            content=[log_probs_content_random, log_probs_content_yes]
+        ),
+    )
+    yield ChatCompletionResponse(
+        model=MODEL_NAME,
+        choices=[choice_0, choice_1],
+        usage=UsageInfo(prompt_tokens=122, total_tokens=910, completion_tokens=788),
+    )
+
+
 ### Tests #####################################################################

 #### Private tools request tests
@@ -456,6 +510,31 @@ def test__extract_tag_info_no_think_result_and_score(
    )


+def test__extract_tag_info_think_result_and_score(
+    granite_guardian_detection, granite_guardian_completion_response_3_3_plus_think
+):
+    # In Granite Guardian 3.3+, think and score tags are provided
+    granite_guardian_detection_instance = asyncio.run(granite_guardian_detection)
+    choice_index = 0
+    content = granite_guardian_completion_response_3_3_plus_think.choices[
+        choice_index
+    ].message.content
+    # NOTE: private function tested here
+    metadata = granite_guardian_detection_instance._extract_tag_info(
+        granite_guardian_completion_response_3_3_plus_think, choice_index, content
+    )
+    assert metadata == {
+        "think": "First analyzing the user request...there is a risk associated, so the score is yes."
+    }
+    # Content should be updated without the tags
+    assert (
+        granite_guardian_completion_response_3_3_plus_think.choices[
+            choice_index
+        ].message.content
+        == "yes"
+    )
+
+
 #### Helper function tests

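The private _extract_tag_info helper exercised by the test above is not part of this diff. As a rough stand-in for the behavior the assertions describe (assumed logic for illustration, not the adapter's actual implementation):

import re

def extract_think_and_score(content):
    """Pull <think>...</think> text into metadata and reduce the message
    content to the bare <score> value, mirroring the assertions above."""
    metadata = {}
    if think := re.search(r"<think>\s*(.*?)\s*</think>", content, re.DOTALL):
        metadata["think"] = think.group(1)
    score = re.search(r"<score>\s*(.*?)\s*</score>", content, re.DOTALL)
    new_content = score.group(1) if score else content
    return metadata, new_content

metadata, content = extract_think_and_score(
    "<think> Since there is no risk associated, the score is no.\n</think>\n"
    "<score> no </score>"
)
assert metadata == {"think": "Since there is no risk associated, the score is no."}
assert content == "no"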

@@ -869,6 +948,30 @@ def test_post_process_completion_with_no_think_result_and_score(
    assert chat_completion_response.choices[1].message.content == "yes"


+def test_post_process_completion_with_think_result_and_score(
+    granite_guardian_detection, granite_guardian_completion_response_3_3_plus_think
+):
+    # In Granite Guardian 3.3+, think and score tags are provided
+    granite_guardian_detection_instance = asyncio.run(granite_guardian_detection)
+    dummy_scores = [0.2, 0.2]
+    (chat_completion_response, _, _, metadata_list) = asyncio.run(
+        granite_guardian_detection_instance.post_process_completion_results(
+            granite_guardian_completion_response_3_3_plus_think, dummy_scores, "risk"
+        )
+    )
+    assert len(metadata_list) == 2  # 2 choices
+    # Think content
+    assert metadata_list[0] == {
+        "think": "First analyzing the user request...there is a risk associated, so the score is yes."
+    }
+    assert metadata_list[1] == {
+        "think": "Since there is no risk associated, the score is no."
+    }
+    # Chat completion response should be updated
+    assert chat_completion_response.choices[0].message.content == "yes"
+    assert chat_completion_response.choices[1].message.content == "no"
+
+
 #### Context analysis tests

@@ -1145,6 +1248,43 @@ def test_chat_detection_with_no_think_and_score_results(
    assert detection_1["metadata"] == {}


+def test_chat_detection_with_think_and_score_results(
+    granite_guardian_detection, granite_guardian_completion_response_3_3_plus_think
+):
+    granite_guardian_detection_instance = asyncio.run(granite_guardian_detection)
+    chat_request = ChatDetectionRequest(
+        messages=[
+            DetectionChatMessageParam(
+                role="user", content="How do I figure out how to break into a house?"
+            )
+        ]
+    )
+    with patch(
+        "vllm_detector_adapter.generative_detectors.granite_guardian.GraniteGuardian.create_chat_completion",
+        return_value=granite_guardian_completion_response_3_3_plus_think,
+    ):
+        detection_response = asyncio.run(
+            granite_guardian_detection_instance.chat(chat_request)
+        )
+        assert type(detection_response) == DetectionResponse
+        detections = detection_response.model_dump()
+        assert len(detections) == 2  # 2 choices
+        detection_0 = detections[0]
+        assert detection_0["detection"] == "yes"
+        assert detection_0["detection_type"] == "risk"
+        assert pytest.approx(detection_0["score"]) == 1.0
+        assert detection_0["metadata"] == {
+            "think": "First analyzing the user request...there is a risk associated, so the score is yes."
+        }
+        detection_1 = detections[1]
+        assert detection_1["detection"] == "no"
+        assert detection_1["detection_type"] == "risk"
+        assert pytest.approx(detection_1["score"]) == 1.0
+        assert detection_1["metadata"] == {
+            "think": "Since there is no risk associated, the score is no."
+        }
+
+
 def test_chat_detection_with_tools(
    granite_guardian_detection, granite_guardian_completion_response
 ):
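
For reference, the detection list these assertions imply, written out as a plain Python literal (values are copied from the fixture and expectations above; the exact serialized field order is illustrative):

# model_dump() of the DetectionResponse asserted in the test above.
expected_detections = [
    {
        "detection": "yes",
        "detection_type": "risk",
        "score": 1.0,
        "metadata": {
            "think": "First analyzing the user request...there is a risk associated, so the score is yes."
        },
    },
    {
        "detection": "no",
        "detection_type": "risk",
        "score": 1.0,
        "metadata": {"think": "Since there is no risk associated, the score is no."},
    },
]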

vllm_detector_adapter/generative_detectors/granite_guardian.py

Lines changed: 8 additions & 0 deletions
@@ -105,6 +105,7 @@ def __preprocess(
            guardian_config["risk_name"] = risk_name
        if risk_definition := request.detector_params.pop("risk_definition", None):
            guardian_config["risk_definition"] = risk_definition
+
        # Guardian 3.3+
        if criteria_id := request.detector_params.pop("criteria_id", None):
            guardian_config["criteria_id"] = criteria_id
@@ -114,6 +115,13 @@ def __preprocess(
            "custom_scoring_schema", None
        ):
            guardian_config["custom_scoring_schema"] = custom_scoring_schema
+        if think := request.detector_params.pop("think", None):
+            if "chat_template_kwargs" in request.detector_params:
+                # Avoid overwriting other existing chat_template_kwargs
+                request.detector_params["chat_template_kwargs"]["think"] = think
+            else:
+                request.detector_params["chat_template_kwargs"] = {"think": think}
+
        if guardian_config:
            logger.debug("guardian_config {} provided for request", guardian_config)
            # Move the parameters to chat_template_kwargs
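
A quick sketch of the merge behavior added above, using a plain dict in place of the adapter's request object (the dict stand-in and the pre-existing kwarg are assumptions for illustration; the branching mirrors the diff):

# Plain-dict stand-in for request.detector_params (hypothetical harness).
detector_params = {
    "think": True,
    "chat_template_kwargs": {"add_generation_prompt": True},  # pre-existing kwargs
}

if think := detector_params.pop("think", None):
    if "chat_template_kwargs" in detector_params:
        # Merge rather than overwrite, as in the change above
        detector_params["chat_template_kwargs"]["think"] = think
    else:
        detector_params["chat_template_kwargs"] = {"think": think}

assert detector_params == {
    "chat_template_kwargs": {"add_generation_prompt": True, "think": True}
}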
