Commit 73823b9

[0.7.0] adds async support for distributed chat

1 parent ac3af6b · commit 73823b9

11 files changed, +992 −17 lines changed


cookbooks/function_calling.ipynb

Lines changed: 1 addition & 1 deletion

@@ -297,7 +297,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.11.7"
+   "version": "3.13.0"
   }
  },
  "nbformat": 4,

docs/changelog.rst

Lines changed: 6 additions & 0 deletions

@@ -7,6 +7,12 @@ minor versions.
 
 All relevant steps to be taken will be mentioned here.
 
+0.7.0
+-----
+
+- All models now have ``<model>.distributed_chat_async`` that can be used in servers without blocking the main event
+  loop. This will give a much-needed UX improvement to the entire system.
+
 0.6.3
 -----
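
The new entry point can be awaited from inside a running event loop, e.g. an asyncio-based server. A minimal sketch, assuming the ``Anthropic`` model from this commit with default constructor arguments and a prompt list built elsewhere (thread construction is not shown in this diff):

import asyncio

import tuneapi.types as tt
from tuneapi.apis.model_anthropic import Anthropic

async def main():
    model = Anthropic()  # constructor defaults assumed; not shown in this diff
    threads: list[tt.Thread] = [...]  # build tt.Thread prompts here
    # Runs the whole batch without blocking the event loop,
    # unlike the thread-based distributed_chat.
    results = await model.distributed_chat_async(threads, max_threads=4)
    print(results)

asyncio.run(main())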

pyproject.toml

Lines changed: 2 additions & 1 deletion

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "tuneapi"
-version = "0.6.3"
+version = "0.7.0"
 description = "Tune AI APIs."
 authors = ["Frello Technology Private Limited <[email protected]>"]
 license = "MIT"
@@ -17,6 +17,7 @@ tqdm = "^4.66.1"
 snowflake_id = "1.0.2"
 nutree = "0.8.0"
 pillow = "^10.2.0"
+httpx = "^0.28.1"
 protobuf = { version = "^5.27.3", optional = true }
 boto3 = { version = "1.29.6", optional = true }
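
httpx joins the dependency list because the new async methods use ``httpx.AsyncClient`` where the sync paths use ``requests`` (see ``model_anthropic.py`` below). A minimal sketch of that pattern, with a placeholder endpoint:

import asyncio
import httpx

async def fetch() -> None:
    # AsyncClient is the awaitable counterpart of requests.Session; the new
    # *_async model methods in this commit are built on this pattern.
    async with httpx.AsyncClient() as client:
        response = await client.post(
            "https://httpbin.org/post",  # placeholder endpoint
            json={"hello": "world"},
            timeout=httpx.Timeout(30.0, connect=5.0),
        )
        response.raise_for_status()
        print(response.json())

asyncio.run(fetch())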

tuneapi/apis/model_anthropic.py

Lines changed: 161 additions & 4 deletions

@@ -4,13 +4,13 @@
 
 # Copyright © 2024- Frello Technology Private Limited
 
-import json
+import httpx
 import requests
 from typing import Optional, Dict, Any, Tuple, List
 
 import tuneapi.utils as tu
 import tuneapi.types as tt
-from tuneapi.apis.turbo import distributed_chat
+from tuneapi.apis.turbo import distributed_chat, distributed_chat_async
 
 
 class Anthropic(tt.ModelInterface):
@@ -203,7 +203,7 @@ def stream_chat(
 
             try:
                 # print(line)
-                resp = json.loads(line.replace("data:", "").strip())
+                resp = tu.from_json(line.replace("data:", "").strip())
                 if resp["type"] == "content_block_start":
                     if resp["content_block"]["type"] == "tool_use":
                         fn_call = {
@@ -229,20 +229,155 @@ def stream_chat(
                         fn_call["arguments"] += delta["partial_json"]
                 elif resp["type"] == "content_block_stop":
                     if fn_call:
-                        fn_call["arguments"] = json.loads(fn_call["arguments"] or "{}")
+                        fn_call["arguments"] = tu.from_json(
+                            fn_call["arguments"] or "{}"
+                        )
                         yield fn_call
                         fn_call = None
             except:
                 break
         return
 
+    async def chat_async(
+        self,
+        chats: tt.Thread | str,
+        model: Optional[str] = None,
+        max_tokens: int = 1024,
+        temperature: Optional[float] = None,
+        token: Optional[str] = None,
+        return_message: bool = False,
+        extra_headers: Optional[Dict[str, str]] = None,
+        **kwargs,
+    ):
+        output = ""
+        fn_call = None
+        async for i in self.stream_chat_async(
+            chats=chats,
+            model=model,
+            max_tokens=max_tokens,
+            temperature=temperature,
+            token=token,
+            extra_headers=extra_headers,
+            raw=False,
+            **kwargs,
+        ):
+            if isinstance(i, dict):
+                fn_call = i.copy()
+            else:
+                output += i
+        if return_message:
+            return output, fn_call
+        if fn_call:
+            return fn_call
+        return output
+
+    async def stream_chat_async(
+        self,
+        chats: tt.Thread | str,
+        model: Optional[str] = None,
+        max_tokens: int = 1024,
+        temperature: Optional[float] = None,
+        token: Optional[str] = None,
+        timeout=(5, 30),
+        raw: bool = False,
+        debug: bool = False,
+        extra_headers: Optional[Dict[str, str]] = None,
+        **kwargs,
+    ) -> Any:
+
+        tools = []
+        if isinstance(chats, tt.Thread):
+            tools = [x.to_dict() for x in chats.tools]
+            for t in tools:
+                t["input_schema"] = t.pop("parameters")
+        headers, system, claude_messages = self._process_input(chats=chats, token=token)
+        extra_headers = extra_headers or self.extra_headers
+        if extra_headers:
+            headers.update(extra_headers)
+
+        data = {
+            "model": model or self.model_id,
+            "max_tokens": max_tokens,
+            "messages": claude_messages,
+            "system": system,
+            "tools": tools,
+            "stream": True,
+        }
+        if temperature:
+            data["temperature"] = temperature
+        if kwargs:
+            data.update(kwargs)
+
+        if debug:
+            fp = "sample_anthropic.json"
+            print("Saving at path " + fp)
+            tu.to_json(data, fp=fp)
+
+        async with httpx.AsyncClient() as client:
+            response = await client.post(
+                self.base_url,
+                headers=headers,
+                json=data,
+                timeout=timeout,
+            )
+            try:
+                response.raise_for_status()
+            except Exception as e:
+                yield str(e)
+                return
+
+            async for chunk in response.aiter_bytes():
+                for line in chunk.decode("utf-8").splitlines():
+                    line = line.strip()
+                    if not line or not "data:" in line:
+                        continue
+
+                    try:
+                        # print(line)
+                        resp = tu.from_json(line.replace("data:", "").strip())
+                        if resp["type"] == "content_block_start":
+                            if resp["content_block"]["type"] == "tool_use":
+                                fn_call = {
+                                    "name": resp["content_block"]["name"],
+                                    "arguments": "",
+                                }
+                        elif resp["type"] == "content_block_delta":
+                            delta = resp["delta"]
+                            delta_type = delta["type"]
+                            if delta_type == "text_delta":
+                                if raw:
+                                    yield b"data: " + tu.to_json(
+                                        {
+                                            "object": delta_type,
+                                            "choices": [
+                                                {"delta": {"content": delta["text"]}}
+                                            ],
+                                        },
+                                        tight=True,
+                                    ).encode()
+                                    yield b""  # uncomment this line if you want 1:1 with OpenAI
+                                else:
+                                    yield delta["text"]
+                            elif delta_type == "input_json_delta":
+                                fn_call["arguments"] += delta["partial_json"]
+                        elif resp["type"] == "content_block_stop":
+                            if fn_call:
+                                fn_call["arguments"] = tu.from_json(
+                                    fn_call["arguments"] or "{}"
+                                )
+                                yield fn_call
+                                fn_call = None
+                    except:
+                        break
+
     def distributed_chat(
         self,
         prompts: List[tt.Thread],
         post_logic: Optional[callable] = None,
         max_threads: int = 10,
         retry: int = 3,
         pbar=True,
+        debug=False,
         **kwargs,
     ):
         return distributed_chat(
@@ -252,5 +387,27 @@ def distributed_chat(
             max_threads=max_threads,
             retry=retry,
             pbar=pbar,
+            debug=debug,
+            **kwargs,
+        )
+
+    async def distributed_chat_async(
+        self,
+        prompts: List[tt.Thread],
+        post_logic: Optional[callable] = None,
+        max_threads: int = 10,
+        retry: int = 3,
+        pbar=True,
+        debug=False,
+        **kwargs,
+    ):
+        return await distributed_chat_async(
+            self,
+            prompts=prompts,
+            post_logic=post_logic,
+            max_threads=max_threads,
+            retry=retry,
+            pbar=pbar,
+            debug=debug,
             **kwargs,
         )
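
A short consumption sketch for the methods added above. It assumes ``Anthropic()`` resolves its token and model the same way the existing sync API does (not shown in this diff); ``stream_chat_async`` yields text deltas as ``str`` and completed tool calls as ``dict``, so the consumer branches on type:

import asyncio

from tuneapi.apis.model_anthropic import Anthropic

async def demo() -> None:
    model = Anthropic()  # token/model defaults assumed, as in the sync API

    # Stream: str chunks are text deltas, dict chunks are finished tool calls.
    async for chunk in model.stream_chat_async("What is 2 + 2?"):
        if isinstance(chunk, dict):
            print("tool call:", chunk["name"], chunk["arguments"])
        else:
            print(chunk, end="", flush=True)
    print()

    # Or gather everything with a single await; chat_async returns the full
    # text, or the tool-call dict if the model invoked a tool.
    answer = await model.chat_async("What is 2 + 2?")
    print(answer)

asyncio.run(demo())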
