
Commit 97009ad

[chore] add multithreaded chat benchmarking code
1 parent 091a33d

File tree: 5 files changed, +280 −2 lines


benchmarks/threaded_map.py

Lines changed: 275 additions & 0 deletions
@@ -0,0 +1,275 @@
# Copyright © 2024- Frello Technology Private Limited

from tuneapi.types import Thread, ModelInterface, human, system
from tuneapi.utils import from_json

import queue
import threading
from tqdm import trange
from typing import Any, Callable, List, Optional
from dataclasses import dataclass

from concurrent.futures import ThreadPoolExecutor, as_completed


@dataclass
class Task:
    """One chat call to make, with a counter so failed calls can be requeued."""

    index: int
    model: ModelInterface
    prompt: Thread
    retry_count: int = 0


@dataclass
class Result:
    """Outcome of one chat call; ``data`` holds the output, or the exception on failure."""

    index: int
    data: Any
    success: bool
    error: Optional[Exception] = None

def bulk_chat(
    model: ModelInterface,
    prompts: List[Thread],
    post_logic: Optional[Callable] = None,
    max_threads: int = 10,
    retry: int = 3,
    pbar=True,
):
    task_channel = queue.Queue()
    result_channel = queue.Queue()

    # Initialize the results container
    results = [None for _ in range(len(prompts))]

    def worker():
        while True:
            try:
                task: Task = task_channel.get(timeout=1)
                if task is None:  # poison pill: shut this worker down
                    break

                try:
                    out = task.model.chat(task.prompt)
                    if post_logic:
                        out = post_logic(out)
                    result_channel.put(Result(task.index, out, True))
                except Exception as e:
                    if task.retry_count < retry:
                        # Create a fresh model instance for the retry
                        nm = model.__class__(
                            id=model.model_id,
                            base_url=model.base_url,
                            extra_headers=model.extra_headers,
                        )
                        nm.set_api_token(model.api_token)
                        # Increment the retry count and requeue
                        task_channel.put(
                            Task(task.index, nm, task.prompt, task.retry_count + 1)
                        )
                    else:
                        # Retries exhausted, store the error
                        result_channel.put(Result(task.index, e, False, e))
                finally:
                    task_channel.task_done()
            except queue.Empty:
                continue

    # Create and start the worker threads
    workers = []
    for _ in range(max_threads):
        t = threading.Thread(target=worker)
        t.start()
        workers.append(t)

    # Initialize the progress bar
    _pbar = trange(len(prompts), desc="Processing", unit=" input") if pbar else None

    # Queue the initial tasks, each with its own model instance
    for i, p in enumerate(prompts):
        nm = model.__class__(
            id=model.model_id,
            base_url=model.base_url,
            extra_headers=model.extra_headers,
        )
        nm.set_api_token(model.api_token)
        task_channel.put(Task(i, nm, p))

    # Drain the result channel until every prompt is accounted for
    completed = 0
    while completed < len(prompts):
        try:
            result = result_channel.get(timeout=1)
            results[result.index] = result.data if result.success else result.error
            if _pbar:
                _pbar.update(1)
            completed += 1
            result_channel.task_done()
        except queue.Empty:
            continue

    # Cleanup: send one poison pill per worker, then wait for them to exit
    for _ in workers:
        task_channel.put(None)
    for w in workers:
        w.join()

    if _pbar:
        _pbar.close()

    return results

prompts = []
for x in range(100):
    prompts.append(
        Thread(
            system(
                """## Response schema

Respond with the following schema. **Ensure you send the <json> and </json> tags.**

```
<json>
{{
    "code": "...",
}}
</json>
```
"""
            ),
            human(
                f"What is the value of 10 ^ {max(x, 10)}? Write down the answer in the Indian number system, given in the code tag."
            ),
        )
    )

import random


def get_tagged_section(tag: str, input_str: str):
    # Randomly fail to parse about half the time, which makes ``post_logic``
    # raise downstream — presumably to exercise the retry paths being benchmarked.
    if random.random() > 0.5:
        import re

        html_pattern = re.compile("<" + tag + ">(.*?)</" + tag + ">", re.DOTALL)
        match = html_pattern.search(input_str)
        if match:
            return match.group(1)

        md_pattern = re.compile("```" + tag + "(.*?)```", re.DOTALL)
        match = md_pattern.search(input_str)
        if match:
            return match.group(1)
    return None


post_logic = lambda out: from_json(get_tagged_section("json", out))["code"]

from tuneapi.apis import Openai
from time import time

st = time()
out = bulk_chat(
    Openai(),
    prompts,
    post_logic=post_logic,
    max_threads=5,
    pbar=True,
    retry=3,
)
print(out)
print(len(out))
print(len([x for x in out if x is None]))

print(f"Time taken: {time() - st:0.4f}s")

print("\n\n\n")

from uuid import uuid4
from typing import Generator


def bulk_chat_2(
    model: ModelInterface,
    prompts: List[Thread],
    post_logic: Optional[Callable] = None,
    max_threads: int = 10,
    retry: int = 3,
    pbar=True,
):
    def _chat(model: ModelInterface, prompt: Thread):
        out = model.chat(prompt)
        if post_logic:
            return post_logic(out)  # the mapped function
        return out

    # Create all the inputs, each with its own model instance
    retry = int(retry)  # so False becomes 0 and True becomes 1
    inputs = []
    for p in prompts:
        nm = model.__class__(
            id=model.model_id,
            base_url=model.base_url,
            extra_headers=model.extra_headers,
        )
        nm.set_api_token(model.api_token)
        inputs.append((nm, p))

    # Run the executor
    _name = str(uuid4())
    if isinstance(inputs, Generator):  # guard in case this is ever fed a generator
        inputs = list(inputs)

    results = [None for _ in range(len(inputs))]
    _pbar = trange(len(inputs), desc="Processing", unit=" input") if pbar else None
    with ThreadPoolExecutor(max_workers=max_threads, thread_name_prefix=_name) as exe:
        _fn = lambda x: _chat(*x)
        loop_cntr = 0
        done = False
        inputs = [(i, x) for i, x in enumerate(inputs)]

        # Keep resubmitting the failed inputs until everything succeeds or the
        # retry budget is exhausted
        while not done:
            failed = []
            if _pbar:
                _pbar.set_description(f"Starting master loop #{loop_cntr:02d}")
            futures = {exe.submit(_fn, x): (i, x) for (i, x) in inputs}
            for fut in as_completed(futures):
                i, x = futures[fut]  # original index and input
                try:
                    res = fut.result()
                    if _pbar:
                        _pbar.update(1)
                    results[i] = res
                except Exception:
                    failed.append((i, x))

            # Override the inputs for the next loop
            inputs = failed

            # The done flag
            loop_cntr += 1
            done = len(failed) == 0 or loop_cntr > retry
    return results


st = time()
out = bulk_chat_2(
    Openai(),
    prompts,
    post_logic=post_logic,
    max_threads=5,
    pbar=True,
    retry=3,
)
print(out)
print(len(out))
print(len([x for x in out if x is None]))

print(f"Time taken: {time() - st:0.4f}s")

docs/changelog.rst

Lines changed: 2 additions & 1 deletion
@@ -2,7 +2,7 @@ Changelog
 =========
 
 This package is already used in production at Tune AI, please do not wait for release ``1.x.x`` for stability, or expect
-to reach ``1.0.0``. We do not follow the general rules of semantic versioning, and there can be breaking changes between
+to reach ``1.0.0``. We **do not follow the general rules** of semantic versioning, and there can be breaking changes between
 minor versions.
 
 All relevant steps to be taken will be mentioned here.
@@ -12,6 +12,7 @@ All relevant steps to be taken will be mentioned here.
 
 - ``distributed_chat`` functionality in ``tuneapi.apis.turbo`` support. In all APIs search for ``model.distributed_chat()``
   method. This enables **fault tolerant LLM API calls**.
+- Moved ``tuneapi.types.experimental`` to ``tuneapi.types.evals``
 
 0.5.13
 -----
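
The changelog entry names ``model.distributed_chat()`` but this diff does not show its signature. A hypothetical call, assuming it accepts a list of ``Thread`` objects the way ``bulk_chat`` in the benchmark does — every detail below is an assumption, not taken from the tuneapi source:

# Hypothetical usage — signature assumed from the bulk_chat benchmark above,
# not from the actual tuneapi.apis.turbo implementation.
from tuneapi.apis import Openai
from tuneapi.types import Thread, human

model = Openai()
threads = [Thread(human(f"2 + {i} = ?")) for i in range(10)]
outputs = model.distributed_chat(threads)  # assumed to return one output per thread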

tests/test_tree.py

Lines changed: 2 additions & 0 deletions
@@ -1,3 +1,5 @@
+# Copyright © 2024- Frello Technology Private Limited
+
 import tuneapi.types as tt
 
 from unittest import TestCase, main as ut_main

tuneapi/types/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -14,6 +14,6 @@
     function_resp,
 )
 
-from tuneapi.types.experimental import (
+from tuneapi.types.evals import (
     Evals,
 )
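
Any downstream code importing the old module path breaks with this commit; the fix mirrors the one-line change made here:

# Before this commit:
from tuneapi.types.experimental import Evals
# After this commit:
from tuneapi.types.evals import Evals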
