# Standard
from dataclasses import dataclass
from typing import Optional
from unittest.mock import patch
import asyncio

# Third Party
from vllm.config import MultiModalConfig
from vllm.entrypoints.openai.protocol import (
    ChatCompletionLogProb,
    ChatCompletionLogProbs,
    ChatCompletionLogProbsContent,
    ChatCompletionResponse,
    ChatCompletionResponseChoice,
    ChatMessage,
    UsageInfo,
)
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
import jinja2
import pytest
import pytest_asyncio

# Local
from vllm_detector_adapter.generative_detectors.base import ChatCompletionDetectionBase
from vllm_detector_adapter.protocol import (
    ContentsDetectionRequest,
    ContentsDetectionResponse,
)

MODEL_NAME = "openai-community/gpt2"
CHAT_TEMPLATE = "Dummy chat template for testing {}"
@@ -82,6 +97,55 @@ async def detection_base():
    return _async_serving_detection_completion_init()


@pytest.fixture(scope="module")
def granite_completion_response():
    """Mock chat completion response with 'no'/'yes' choices and top logprobs,
    mimicking what a Granite-style detector model returns."""
    log_probs_content_no = ChatCompletionLogProbsContent(
        token="no",
        logprob=-0.0013,
        # 5 logprobs requested for scoring, skipping bytes for conciseness
        top_logprobs=[
            ChatCompletionLogProb(token="no", logprob=-0.053),
            ChatCompletionLogProb(token="0", logprob=-6.61),
            ChatCompletionLogProb(token="1", logprob=-16.90),
            ChatCompletionLogProb(token="2", logprob=-17.39),
            ChatCompletionLogProb(token="3", logprob=-17.61),
        ],
    )
    log_probs_content_yes = ChatCompletionLogProbsContent(
        token="yes",
        logprob=-0.0013,
        # 5 logprobs requested for scoring, skipping bytes for conciseness
        top_logprobs=[
            ChatCompletionLogProb(token="yes", logprob=-0.0013),
            ChatCompletionLogProb(token="0", logprob=-6.61),
            ChatCompletionLogProb(token="1", logprob=-16.90),
            ChatCompletionLogProb(token="2", logprob=-17.39),
            ChatCompletionLogProb(token="3", logprob=-17.61),
        ],
    )
    choice_0 = ChatCompletionResponseChoice(
        index=0,
        message=ChatMessage(
            role="assistant",
            content="no",
        ),
        logprobs=ChatCompletionLogProbs(content=[log_probs_content_no]),
    )
    choice_1 = ChatCompletionResponseChoice(
        index=1,
        message=ChatMessage(
            role="assistant",
            content="yes",
        ),
        logprobs=ChatCompletionLogProbs(content=[log_probs_content_yes]),
    )
    yield ChatCompletionResponse(
        model=MODEL_NAME,
        choices=[choice_0, choice_1],
        usage=UsageInfo(prompt_tokens=200, total_tokens=206, completion_tokens=6),
    )


### Tests #####################################################################


@@ -97,3 +161,39 @@ def test_async_serving_detection_completion_init(detection_base):
    output_template = detection_completion.output_template
    assert type(output_template) == jinja2.environment.Template
    assert output_template.render({"text": "moose"}) == "bye moose"


def test_content_analysis_success(detection_base, granite_completion_response):
    base_instance = asyncio.run(detection_base)

    content_request = ContentsDetectionRequest(
        contents=["Where do I find geese?", "You could go to Canada"]
    )

    scores = [0.9, 0.1, 0.21, 0.54, 0.33]
    response = (granite_completion_response, scores, "risk")
    # Patch out the chat completion call so every content gets the same mocked
    # (response, scores, detection_type) tuple back
    with patch(
        "vllm_detector_adapter.generative_detectors.base.ChatCompletionDetectionBase.process_chat_completion_with_scores",
        return_value=response,
    ):
        result = asyncio.run(base_instance.content_analysis(content_request))
        assert isinstance(result, ContentsDetectionResponse)
        detections = result.model_dump()
        assert len(detections) == 2
        # For the first content
        assert detections[0][0]["detection"] == "no"
        assert detections[0][0]["score"] == 0.9
        assert detections[0][0]["start"] == 0
        assert detections[0][0]["end"] == len(content_request.contents[0])
        # The second choice surfaces as the second label
        assert detections[0][1]["detection"] == "yes"
        assert detections[0][1]["score"] == 0.1
        assert detections[0][1]["start"] == 0
        assert detections[0][1]["end"] == len(content_request.contents[0])
        # For the second content, only the first detection is checked for simplicity.
        # Note: the detection is the same because the mock returns an identical
        # response for every content.
        assert detections[1][0]["detection"] == "no"
        assert detections[1][0]["score"] == 0.9
        assert detections[1][0]["start"] == 0
        assert detections[1][0]["end"] == len(content_request.contents[1])