+# Standard
+from typing import Optional
+import asyncio
+import copy
+
+# Third Party
+from fastapi import Request
+from vllm.entrypoints.openai.protocol import ErrorResponse
+
 # Local
 from vllm_detector_adapter.generative_detectors.base import ChatCompletionDetectionBase
 from vllm_detector_adapter.logging import init_logger
+from vllm_detector_adapter.protocol import (
+    ContentsDetectionRequest,
+    ContentsDetectionResponse,
+)
+from vllm_detector_adapter.utils import DetectorType
 
 logger = init_logger(__name__)
 
@@ -13,5 +27,100 @@ class LlamaGuard(ChatCompletionDetectionBase):
     SAFE_TOKEN = "safe"
     UNSAFE_TOKEN = "unsafe"
 
-    # NOTE: More intelligent template parsing can be done here, potentially
-    # as a regex template for safe vs. unsafe and the 'unsafe' category
+    def __post_process_results(self, results):
+        # NOTE: Llama Guard returns its specific safety categories on the last line
+        # of the response, in CSV format. This is guided by the prompt definition of
+        # the model, so we expect Llama Guard to adhere to it, at least for
+        # Llama-Guard-3 (the latest at the time of writing).
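+        # Illustrative example (assumed output shape, not from a real run): an unsafe
+        # completion typically looks like "unsafe\nS1,S10", so the last line parsed
+        # below would be "S1,S10".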
+
+        # NOTE: The concept of a "choice" does not exist for the content type detector
+        # API, so we flatten out the responses: each category found within a choice is
+        # surfaced as though it were its own choice.
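+        # For the illustrative output above, a single choice with content
+        # "unsafe\nS1,S10" is expanded into three entries ("unsafe", "S1", "S10"),
+        # each carrying the same score as the original choice.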
+
+        (responses, scores, detection_type) = results
+
+        new_choices = []
+        new_scores = []
+
+        for i, choice in enumerate(responses.choices):
+            content = choice.message.content
+            if self.UNSAFE_TOKEN in content:
+                # We create one result for each unsafe category,
+                # in addition to "unsafe" as a category itself.
+                # NOTE: need to deepcopy, otherwise the original choice gets overwritten
+                unsafe_choice = copy.deepcopy(choice)
+                unsafe_choice.message.content = self.UNSAFE_TOKEN
+
+                new_choices.append(unsafe_choice)
+                new_scores.append(scores[i])
+
+                # Fetch the categories from the last line of the response, which is in CSV format
+                for category in content.strip().split("\n")[-1].split(","):
+                    category_choice = copy.deepcopy(choice)
+                    category_choice.message.content = category
+                    new_choices.append(category_choice)
+                    # NOTE: currently using the same score as "unsafe",
+                    # but we should revisit whether a better per-category score can be obtained
+                    new_scores.append(scores[i])
+            else:
+                # "safe" case
+                new_choices.append(choice)
+                new_scores.append(scores[i])
+
+        responses.choices = new_choices
+        return (responses, new_scores, detection_type)
+
+    async def content_analysis(
+        self,
+        request: ContentsDetectionRequest,
+        raw_request: Optional[Request] = None,
+    ):
+        """Function used to call chat detection and provide a /text/contents response"""
+
+        # Apply task template if it exists
+        if self.task_template:
+            request = self.apply_task_template(
+                request, fn_type=DetectorType.TEXT_CONTENT
+            )
+            if isinstance(request, ErrorResponse):
+                # Propagate any request problems that prevent the
+                # task template from being applied
+                return request
+
+        # Since a separate batch processing function does not exist at the time of writing,
+        # we simply collect all the text from the contents request, fire off separate
+        # requests, and wait for them asynchronously.
+        # This mirrors how batching is handled by the run_batch function in
+        # entrypoints/openai/ in the vLLM codebase.
+        completion_requests = self.preprocess_request(
+            request, fn_type=DetectorType.TEXT_CONTENT
+        )
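+        # NOTE: (assumption, based on the batching comment above) preprocess_request
+        # is expected to yield one chat completion request per item in request.contents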
+
+        # Send all the completion requests asynchronously.
+        tasks = [
+            asyncio.create_task(
+                self.process_chat_completion_with_scores(
+                    completion_request, raw_request
+                )
+            )
+            for completion_request in completion_requests
+        ]
+
+        # Gather all the results
+        # NOTE: The results are guaranteed to be in order of the requests
+        results = await asyncio.gather(*tasks)
+
+        # If there is any error, return it; otherwise, return the whole response,
+        # properly formatted.
+        categorized_results = []
+        for result in results:
+            # NOTE: we only return the first error result,
+            # not every error and not a cumulative one
+            if isinstance(result, ErrorResponse):
+                return result
+            else:
+                # Process results to split out safety categories into separate objects
+                categorized_results.append(self.__post_process_results(result))
+
+        return ContentsDetectionResponse.from_chat_completion_response(
+            categorized_results, request.contents
+        )