11import pytest
2+ import time
23from openai import OpenAI
34from utils import *
45
@@ -10,7 +11,6 @@ def create_server():
1011 global server
1112 server = ServerPreset .tinyllama2 ()
1213
13-
1414@pytest .mark .parametrize ("prompt,n_predict,re_content,n_prompt,n_predicted,truncated" , [
1515 ("I believe the meaning of life is" , 8 , "(going|bed)+" , 18 , 8 , False ),
1616 ("Write a joke about AI from a very long prompt which will not be truncated" , 256 , "(princesses|everyone|kids|Anna|forest)+" , 46 , 64 , False ),
@@ -52,24 +52,6 @@ def test_completion_stream(prompt: str, n_predict: int, re_content: str, n_promp
5252 content += data ["content" ]
5353
5454
55- # FIXME: This test is not working because /completions endpoint is not OAI-compatible
56- @pytest .mark .skip (reason = "Only /chat/completions is OAI-compatible for now" )
57- def test_completion_with_openai_library ():
58- global server
59- server .start ()
60- client = OpenAI (api_key = "dummy" , base_url = f"http://{ server .server_host } :{ server .server_port } " )
61- res = client .completions .create (
62- model = "gpt-3.5-turbo-instruct" ,
63- prompt = "I believe the meaning of life is" ,
64- max_tokens = 8 ,
65- seed = 42 ,
66- temperature = 0.8 ,
67- )
68- print (res )
69- assert res .choices [0 ].finish_reason == "length"
70- assert match_regex ("(going|bed)+" , res .choices [0 ].text )
71-
72-
7355@pytest .mark .parametrize ("n_slots" , [1 , 2 ])
7456def test_consistent_result_same_seed (n_slots : int ):
7557 global server
@@ -121,4 +103,97 @@ def test_consistent_result_different_batch_size(n_batch: int, temperature: float
121103 assert res .body ["content" ] == last_res .body ["content" ]
122104 last_res = res
123105
124- # TODO: add completion with tokens as input, mixed token+string input
106+
def test_completion_with_tokens_input():
    """Exercise /completion with token-id prompts.

    Covers: a single tokenized prompt, a batch of two tokenized prompts,
    a batch mixing a tokenized and a string prompt, and token ids
    interleaved with strings inside one prompt sequence.
    """
    global server
    server.temperature = 0.0
    server.start()

    text = "I believe the meaning of life is"
    tok_res = server.make_request("POST", "/tokenize", data={
        "content": text,
        "add_special": True,
    })
    assert tok_res.status_code == 200
    token_ids = tok_res.body["tokens"]

    def complete(prompt):
        # helper: POST the given prompt payload to /completion
        return server.make_request("POST", "/completion", data={
            "prompt": prompt,
        })

    # a single tokenized prompt yields a single completion object
    single = complete(token_ids)
    assert single.status_code == 200
    assert type(single.body["content"]) == str

    # two-prompt batches (tokens+tokens, then tokens+string) must return a
    # two-element list whose completions agree (temperature is 0)
    for pair in ([token_ids, token_ids], [token_ids, text]):
        batch = complete(pair)
        assert batch.status_code == 200
        assert type(batch.body) == list
        assert len(batch.body) == 2
        assert batch.body[0]["content"] == batch.body[1]["content"]

    # token ids and strings interleaved within one prompt sequence
    mixed = complete([1, 2, 3, 4, 5, 6, text, 7, 8, 9, 10, text])
    assert mixed.status_code == 200
    assert type(mixed.body["content"]) == str
151+
@pytest.mark.parametrize("n_slots,n_requests", [
    (1, 3),
    (2, 2),
    (2, 4),
    (4, 2),  # some slots must be idle
    (4, 6),
])
def test_completion_parallel_slots(n_slots: int, n_requests: int):
    """Fire n_requests completions concurrently against a server configured
    with n_slots slots, verifying both the generated content and the slot
    occupancy reported by /slots while the requests are in flight."""
    global server
    server.n_slots = n_slots
    server.temperature = 0.0
    server.start()

    PROMPTS = [
        ("Write a very long book.", "(very|special|big)+"),
        ("Write another a poem.", "(small|house)+"),
        ("What is LLM?", "(Dad|said)+"),
        ("The sky is blue and I love it.", "(climb|leaf)+"),
        ("Write another very long music lyrics.", "(friends|step|sky)+"),
        ("Write a very long joke.", "(cat|Whiskers)+"),
    ]

    def check_slots_status():
        # With at least as many requests as slots, every slot should be busy.
        # Sleep briefly so the server has time to dispatch before sampling.
        should_all_slots_busy = n_requests >= n_slots
        time.sleep(0.1)
        res = server.make_request("GET", "/slots")
        n_busy = sum(1 for slot in res.body if slot["is_processing"])
        if should_all_slots_busy:
            assert n_busy == n_slots
        else:
            assert n_busy <= n_slots

    tasks = []
    for i in range(n_requests):
        prompt, re_content = PROMPTS[i % len(PROMPTS)]
        tasks.append((server.make_request, ("POST", "/completion", {
            "prompt": prompt,
            "seed": 42,
            "temperature": 1.0,
        })))
    # Queue the status probe ONCE, after all completion tasks: queuing it per
    # request would interleave its (None) results with the completions and
    # break the results[i] indexing below.
    tasks.append((check_slots_status, ()))
    results = parallel_function_calls(tasks)

    # check results: results[0..n_requests-1] are the completion responses,
    # in the same order the tasks were queued
    for i in range(n_requests):
        prompt, re_content = PROMPTS[i % len(PROMPTS)]
        res = results[i]
        assert res.status_code == 200
        assert match_regex(re_content, res.body["content"])