
Commit 865f3e9

✨ Add llama-guard separate safety category in output
Signed-off-by: Gaurav-Kumbhat <Gaurav.Kumbhat@ibm.com>
1 parent 7e6d3f0 commit 865f3e9

File tree

3 files changed: +118 -14 lines changed

  vllm_detector_adapter/generative_detectors/base.py
  vllm_detector_adapter/generative_detectors/llama_guard.py
  vllm_detector_adapter/protocol.py

vllm_detector_adapter/generative_detectors/base.py

Lines changed: 2 additions & 2 deletions

@@ -102,7 +102,7 @@ def apply_task_template(
         return request

     @detector_dispatcher(types=[DetectorType.TEXT_CHAT])
-    def preprocess_request( # noqa: F811
+    def preprocess_request( # noqa: F811
         self, request: ChatDetectionRequest
     ) -> Union[ChatDetectionRequest, ErrorResponse]:
         """Preprocess chat request"""
@@ -112,7 +112,7 @@ def preprocess_request( # noqa: F811
     ##### Contents request processing functions ####################################

     @detector_dispatcher(types=[DetectorType.TEXT_CONTENT])
-    def preprocess_request( # noqa: F811
+    def preprocess_request( # noqa: F811
         self, request: ContentsDetectionRequest
     ) -> Union[ContentsDetectionRequest, ErrorResponse]:
         """Preprocess contents request and convert it into appropriate chat request"""

vllm_detector_adapter/generative_detectors/llama_guard.py

Lines changed: 111 additions & 2 deletions

@@ -1,6 +1,20 @@
+# Standard
+from typing import Optional
+import asyncio
+import copy
+
+# Third Party
+from fastapi import Request
+from vllm.entrypoints.openai.protocol import ErrorResponse
+
 # Local
 from vllm_detector_adapter.generative_detectors.base import ChatCompletionDetectionBase
 from vllm_detector_adapter.logging import init_logger
+from vllm_detector_adapter.protocol import (
+    ContentsDetectionRequest,
+    ContentsDetectionResponse,
+)
+from vllm_detector_adapter.utils import DetectorType

 logger = init_logger(__name__)

@@ -13,5 +27,100 @@ class LlamaGuard(ChatCompletionDetectionBase):
     SAFE_TOKEN = "safe"
     UNSAFE_TOKEN = "unsafe"

-    # NOTE: More intelligent template parsing can be done here, potentially
-    # as a regex template for safe vs. unsafe and the 'unsafe' category
+    def __post_process_results(self, results):
+        # NOTE: Llama Guard returns the specific safety categories on the last line of its
+        # response, in CSV format. This is guided by the model's prompt definition, so we
+        # expect Llama Guard to adhere to it, at least for Llama-Guard-3 (the latest at the
+        # time of writing).
+
+        # NOTE: The concept of a "choice" doesn't exist for the content-type detector API,
+        # so we essentially flatten the responses: each category within one choice
+        # also shows up as if it were another choice.
+
+        (responses, scores, detection_type) = results
+
+        new_choices = []
+        new_scores = []
+
+        for i, choice in enumerate(responses.choices):
+            content = choice.message.content
+            if self.UNSAFE_TOKEN in content:
+                # Create a result for each unsafe category,
+                # in addition to "unsafe" as a category itself
+                # NOTE: need to deepcopy, otherwise the choice would get overwritten
+                unsafe_choice = copy.deepcopy(choice)
+                unsafe_choice.message.content = self.UNSAFE_TOKEN
+
+                new_choices.append(unsafe_choice)
+                new_scores.append(scores[i])
+
+                # Fetch the categories from the last line of the response, which is in CSV format
+                for category in content.strip().split("\n")[-1].split(","):
+                    category_choice = copy.deepcopy(choice)
+                    category_choice.message.content = category
+                    new_choices.append(category_choice)
+                    # NOTE: currently using the same score as "unsafe";
+                    # we may revisit this to provide a better per-category score
+                    new_scores.append(scores[i])
+            else:
+                # "safe" case
+                new_choices.append(choice)
+                new_scores.append(scores[i])
+
+        responses.choices = new_choices
+        return (responses, new_scores, detection_type)
+
+    async def content_analysis(
+        self,
+        request: ContentsDetectionRequest,
+        raw_request: Optional[Request] = None,
+    ):
+        """Function used to call chat detection and provide a /text/contents response"""
+
+        # Apply task template if it exists
+        if self.task_template:
+            request = self.apply_task_template(
+                request, fn_type=DetectorType.TEXT_CONTENT
+            )
+            if isinstance(request, ErrorResponse):
+                # Propagate any request problems that will not allow
+                # task template to be applied
+                return request
+
+        # Since a separate batch processing function doesn't exist at the time of writing,
+        # we simply collect all the text from the contents request, fire off separate
+        # requests, and wait for them asynchronously.
+        # This mirrors how batching is handled by the run_batch function under
+        # entrypoints/openai/ in the vLLM codebase.
+        completion_requests = self.preprocess_request(
+            request, fn_type=DetectorType.TEXT_CONTENT
+        )
+
+        # Send all the completion requests asynchronously.
+        tasks = [
+            asyncio.create_task(
+                self.process_chat_completion_with_scores(
+                    completion_request, raw_request
+                )
+            )
+            for completion_request in completion_requests
+        ]
+
+        # Gather all the results
+        # NOTE: The results are guaranteed to be in the order of the requests
+        results = await asyncio.gather(*tasks)
+
+        # If there is any error, return it; otherwise, return the whole response,
+        # properly formatted.
+        categorized_results = []
+        for result in results:
+            # NOTE: we only return one of the error results,
+            # not every error and not a cumulative one
+            if isinstance(result, ErrorResponse):
+                return result
+            else:
+                # Process results to split out safety categories into separate objects
+                categorized_results.append(self.__post_process_results(result))
+
+        return ContentsDetectionResponse.from_chat_completion_response(
+            categorized_results, request.contents
+        )
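
To make the flattening above concrete, here is a minimal, standalone sketch of the category split performed by __post_process_results, assuming a raw Llama Guard completion such as "unsafe\nS1,S10". The helper name flatten_guard_output and the sample categories are hypothetical and for illustration only.

# Hypothetical helper mirroring the splitting logic above on a plain string:
# "unsafe" plus one entry per comma-separated category from the last line.
def flatten_guard_output(raw: str) -> list:
    unsafe_token = "unsafe"
    if unsafe_token not in raw:
        # "safe" case: keep the response as a single detection
        return [raw.strip()]
    categories = raw.strip().split("\n")[-1].split(",")
    return [unsafe_token] + [c.strip() for c in categories]


print(flatten_guard_output("unsafe\nS1,S10"))  # ['unsafe', 'S1', 'S10']
print(flatten_guard_output("safe"))            # ['safe']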

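The per-content fan-out in content_analysis relies only on asyncio.create_task plus asyncio.gather, which returns results in task order. Below is a minimal sketch of that pattern, with fake_completion standing in (hypothetically) for process_chat_completion_with_scores.

import asyncio


# Hypothetical stand-in for process_chat_completion_with_scores
async def fake_completion(text: str) -> str:
    await asyncio.sleep(0)  # pretend to call the model
    return f"safe: {text}"


async def main():
    contents = ["first text", "second text", "third text"]
    # One task per content item, as in the fan-out above
    tasks = [asyncio.create_task(fake_completion(c)) for c in contents]
    # gather preserves task order, so results line up with contents
    results = await asyncio.gather(*tasks)
    print(results)


asyncio.run(main())
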
vllm_detector_adapter/protocol.py

Lines changed: 5 additions & 10 deletions

@@ -48,10 +48,7 @@ class ContentsDetectionResponse(RootModel):
     root: List[List[ContentsDetectionResponseObject]]

     @staticmethod
-    def from_chat_completion_response(
-        results,
-        # responses: ChatCompletionResponse, scores: List[float], detection_type: str
-    ):
+    def from_chat_completion_response(results, contents: List[str]):
         """Function to convert openai chat completion response to [fms] contents detection response

         Args:
@@ -63,23 +60,23 @@ def from_chat_completion_response(
         """
         contents_detection_responses = []

-        for (responses, scores, detection_type) in results:
+        for content_idx, (responses, scores, detection_type) in enumerate(results):

             detection_responses = []
             for i, choice in enumerate(responses.choices):
                 content = choice.message.content
                 # NOTE: for providing spans, we currently consider the entire generated text as a span.
                 # This is because, at the time of writing, the generative guardrail models do not
-                # provide specific information about text, which can be used to deduce spans.
+                # provide specific information about the input text, which can be used to deduce spans.
                 start = 0
-                end = len(content)
+                end = len(contents[content_idx])
                 if content and isinstance(content, str):
                     response_object = ContentsDetectionResponseObject(
                         detection_type=detection_type,
                         detection=content.strip(),
                         start=start,
                         end=end,
-                        text=content,
+                        text=contents[content_idx],
                         score=scores[i],
                     ).model_dump()
                     detection_responses.append(response_object)
@@ -94,8 +91,6 @@ def from_chat_completion_response(
                         type="BadRequestError",
                         code=HTTPStatus.BAD_REQUEST.value,
                     )
-
-        # return ContentsDetectionResponse(root=detection_responses)
             contents_detection_responses.append(detection_responses)

         return ContentsDetectionResponse(root=contents_detection_responses)
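
Putting the two changes together, here is a hedged sketch of what a flattened /text/contents payload could look like for one content item that Llama Guard flags as unsafe with category S1. The category label, score, detection_type value, and sample text below are illustrative assumptions; only the field names and nesting follow ContentsDetectionResponseObject and ContentsDetectionResponse above.

# Illustrative only: values are made up; structure mirrors
# ContentsDetectionResponse (one inner list per content item).
example_response = [
    [
        {
            "detection_type": "risk",   # illustrative detection type
            "detection": "unsafe",
            "start": 0,
            "end": 28,                  # spans currently cover the whole input text
            "text": "some potentially unsafe text",
            "score": 0.97,
        },
        {
            "detection_type": "risk",
            "detection": "S1",          # the specific category, now a separate object
            "start": 0,
            "end": 28,
            "text": "some potentially unsafe text",
            "score": 0.97,              # currently reuses the "unsafe" score
        },
    ]
]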
