Skip to content

Commit 78e3cb3

Browse files
committed
add parallel completion test
1 parent 1c2f0f7 commit 78e3cb3

File tree

2 files changed

+131
-43
lines changed

2 files changed

+131
-43
lines changed

examples/server/tests/unit/test_completion.py

Lines changed: 95 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import pytest
2+
import time
23
from openai import OpenAI
34
from utils import *
45

@@ -10,7 +11,6 @@ def create_server():
1011
global server
1112
server = ServerPreset.tinyllama2()
1213

13-
1414
@pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated", [
1515
("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False),
1616
("Write a joke about AI from a very long prompt which will not be truncated", 256, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False),
@@ -52,24 +52,6 @@ def test_completion_stream(prompt: str, n_predict: int, re_content: str, n_promp
5252
content += data["content"]
5353

5454

55-
# FIXME: This test is not working because /completions endpoint is not OAI-compatible
56-
@pytest.mark.skip(reason="Only /chat/completions is OAI-compatible for now")
57-
def test_completion_with_openai_library():
58-
global server
59-
server.start()
60-
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
61-
res = client.completions.create(
62-
model="gpt-3.5-turbo-instruct",
63-
prompt="I believe the meaning of life is",
64-
max_tokens=8,
65-
seed=42,
66-
temperature=0.8,
67-
)
68-
print(res)
69-
assert res.choices[0].finish_reason == "length"
70-
assert match_regex("(going|bed)+", res.choices[0].text)
71-
72-
7355
@pytest.mark.parametrize("n_slots", [1, 2])
7456
def test_consistent_result_same_seed(n_slots: int):
7557
global server
@@ -121,4 +103,97 @@ def test_consistent_result_different_batch_size(n_batch: int, temperature: float
121103
assert res.body["content"] == last_res.body["content"]
122104
last_res = res
123105

124-
# TODO: add completion with tokens as input, mixed token+string input
106+
107+
def test_completion_with_tokens_input():
    """Exercise /completion with token-based prompts.

    Covers four prompt shapes: a single token list, a batch of token lists,
    a mixed batch of tokens + string, and a single sequence interleaving
    raw token ids with a string.
    """
    global server
    server.temperature = 0.0  # deterministic sampling so identical prompts yield identical content
    server.start()
    prompt_str = "I believe the meaning of life is"
    res = server.make_request("POST", "/tokenize", data={
        "content": prompt_str,
        "add_special": True,
    })
    assert res.status_code == 200
    tokens = res.body["tokens"]

    # single completion
    res = server.make_request("POST", "/completion", data={
        "prompt": tokens,
    })
    assert res.status_code == 200
    assert isinstance(res.body["content"], str)

    # batch completion: two identical token prompts must produce identical content
    res = server.make_request("POST", "/completion", data={
        "prompt": [tokens, tokens],
    })
    assert res.status_code == 200
    assert isinstance(res.body, list)
    assert len(res.body) == 2
    assert res.body[0]["content"] == res.body[1]["content"]

    # mixed string and tokens: both encode the same prompt, so content must match
    res = server.make_request("POST", "/completion", data={
        "prompt": [tokens, prompt_str],
    })
    assert res.status_code == 200
    assert isinstance(res.body, list)
    assert len(res.body) == 2
    assert res.body[0]["content"] == res.body[1]["content"]

    # mixed string and tokens in one sequence
    res = server.make_request("POST", "/completion", data={
        "prompt": [1, 2, 3, 4, 5, 6, prompt_str, 7, 8, 9, 10, prompt_str],
    })
    assert res.status_code == 200
    assert isinstance(res.body["content"], str)
150+
151+
152+
@pytest.mark.parametrize("n_slots,n_requests", [
    (1, 3),
    (2, 2),
    (2, 4),
    (4, 2), # some slots must be idle
    (4, 6),
])
def test_completion_parallel_slots(n_slots: int, n_requests: int):
    """Fire n_requests completions concurrently at a server with n_slots slots.

    While the requests are in flight, /slots is polled to verify slot
    occupancy; afterwards each completion's content is matched against its
    expected regex.
    """
    global server
    server.n_slots = n_slots
    server.temperature = 0.0
    server.start()

    # (prompt, expected-content regex) pairs — regexes match the model's
    # deterministic output for each prompt
    PROMPTS = [
        ("Write a very long book.", "(very|special|big)+"),
        ("Write another a poem.", "(small|house)+"),
        ("What is LLM?", "(Dad|said)+"),
        ("The sky is blue and I love it.", "(climb|leaf)+"),
        ("Write another very long music lyrics.", "(friends|step|sky)+"),
        ("Write a very long joke.", "(cat|Whiskers)+"),
    ]
    def check_slots_status():
        # with at least as many requests as slots, every slot should be busy
        should_all_slots_busy = n_requests >= n_slots
        time.sleep(0.1)  # give the server a moment to dispatch requests onto slots
        res = server.make_request("GET", "/slots")
        n_busy = sum(1 for slot in res.body if slot["is_processing"])
        if should_all_slots_busy:
            assert n_busy == n_slots
        else:
            assert n_busy <= n_slots

    tasks = []
    for i in range(n_requests):
        prompt, re_content = PROMPTS[i % len(PROMPTS)]
        tasks.append((server.make_request, ("POST", "/completion", {
            "prompt": prompt,
            "seed": 42,
            "temperature": 1.0,
        })))
    # run the slot-occupancy probe concurrently with the completion requests
    tasks.append((check_slots_status, ()))
    results = parallel_function_calls(tasks)

    # check results
    for i in range(n_requests):
        _, re_content = PROMPTS[i % len(PROMPTS)]
        res = results[i]
        assert res.status_code == 200
        assert match_regex(re_content, res.body["content"])

examples/server/tests/utils.py

Lines changed: 36 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import threading
1212
import requests
1313
import time
14+
from concurrent.futures import ThreadPoolExecutor, as_completed
1415
from typing import (
1516
Any,
1617
Callable,
@@ -19,7 +20,7 @@
1920
Iterator,
2021
List,
2122
Literal,
22-
Sequence,
23+
Tuple,
2324
Set,
2425
)
2526
from re import RegexFlag
@@ -28,7 +29,7 @@
2829
class ServerResponse:
2930
headers: dict
3031
status_code: int
31-
body: dict
32+
body: dict | Any
3233

3334

3435
class ServerProcess:
@@ -322,30 +323,42 @@ def jina_reranker_tiny() -> ServerProcess:
322323
return server
323324

324325

325-
def multiple_post_requests(
326-
server: ServerProcess, path: str, data: Sequence[dict], headers: dict | None = None
327-
) -> Sequence[ServerResponse]:
328-
def worker(data_chunk):
329-
try:
330-
return server.make_request("POST", path, data=data_chunk, headers=headers)
331-
except Exception as e:
332-
print(f"Error occurred: {e}", file=sys.stderr)
333-
os._exit(1) # terminate main thread
334-
335-
threads = []
336-
results = []
326+
def parallel_function_calls(function_list: List[Tuple[Callable[..., Any], Tuple[Any, ...]]]) -> List[Any]:
    """
    Run multiple functions in parallel and return results in the same order as calls. Equivalent to Promise.all in JS.

    A call that raises leaves None at its position in the returned list;
    the exceptions are printed rather than re-raised (best-effort contract).

    Example usage:

    results = parallel_function_calls([
        (func1, (arg1, arg2)),
        (func2, (arg3, arg4)),
    ])
    """
    results: List[Any] = [None] * len(function_list)
    exceptions: List[Tuple[int, str]] = []

    with ThreadPoolExecutor() as executor:
        # submit the callables directly; Future.result() below both waits for
        # completion and surfaces any exception, so no worker wrapper or
        # separate as_completed() drain loop is needed
        futures = [executor.submit(func, *args) for func, args in function_list]
        for index, future in enumerate(futures):
            try:
                results[index] = future.result()
            except Exception as e:
                exceptions.append((index, str(e)))

    # Report (but do not re-raise) any exceptions; callers inspect `results`
    if exceptions:
        print("Exceptions occurred:")
        for index, error in exceptions:
            print(f"Function at index {index}: {error}")

    return results
351364

0 commit comments

Comments
 (0)