@@ -4,12 +4,12 @@
 
 # Copyright © 2024- Frello Technology Private Limited
 
-import json
+import httpx
 import requests
 from typing import Optional, Dict, Any, Tuple, List
 
 import tuneapi.utils as tu
 import tuneapi.types as tt
-from tuneapi.apis.turbo import distributed_chat
+from tuneapi.apis.turbo import distributed_chat, distributed_chat_async
 
 
 class Anthropic(tt.ModelInterface):
@@ -203,7 +203,7 @@ def stream_chat(
 
             try:
                 # print(line)
-                resp = json.loads(line.replace("data:", "").strip())
+                resp = tu.from_json(line.replace("data:", "").strip())
                 if resp["type"] == "content_block_start":
                     if resp["content_block"]["type"] == "tool_use":
                         fn_call = {
@@ -229,20 +229,155 @@ def stream_chat(
                             fn_call["arguments"] += delta["partial_json"]
                 elif resp["type"] == "content_block_stop":
                     if fn_call:
-                        fn_call["arguments"] = json.loads(fn_call["arguments"] or "{}")
+                        fn_call["arguments"] = tu.from_json(
+                            fn_call["arguments"] or "{}"
+                        )
                         yield fn_call
                         fn_call = None
             except:
                 break
         return
 
+    async def chat_async(
+        self,
+        chats: tt.Thread | str,
+        model: Optional[str] = None,
+        max_tokens: int = 1024,
+        temperature: Optional[float] = None,
+        token: Optional[str] = None,
+        return_message: bool = False,
+        extra_headers: Optional[Dict[str, str]] = None,
+        **kwargs,
+    ):
+        output = ""
+        fn_call = None
+        async for i in self.stream_chat_async(
+            chats=chats,
+            model=model,
+            max_tokens=max_tokens,
+            temperature=temperature,
+            token=token,
+            extra_headers=extra_headers,
+            raw=False,
+            **kwargs,
+        ):
+            if isinstance(i, dict):
+                fn_call = i.copy()
+            else:
+                output += i
+        if return_message:
+            return output, fn_call
+        if fn_call:
+            return fn_call
+        return output
+
+    async def stream_chat_async(
+        self,
+        chats: tt.Thread | str,
+        model: Optional[str] = None,
+        max_tokens: int = 1024,
+        temperature: Optional[float] = None,
+        token: Optional[str] = None,
+        timeout=(5, 30),
+        raw: bool = False,
+        debug: bool = False,
+        extra_headers: Optional[Dict[str, str]] = None,
+    ) -> Any:
+        fn_call = None  # ensure defined before any content_block_stop event
+        tools = []
+        if isinstance(chats, tt.Thread):
+            tools = [x.to_dict() for x in chats.tools]
+            for t in tools:
+                t["input_schema"] = t.pop("parameters")
+        headers, system, claude_messages = self._process_input(chats=chats, token=token)
+        extra_headers = extra_headers or self.extra_headers
+        if extra_headers:
+            headers.update(extra_headers)
+
+        data = {
+            "model": model or self.model_id,
+            "max_tokens": max_tokens,
+            "messages": claude_messages,
+            "system": system,
+            "tools": tools,
+            "stream": True,
+        }
+        if temperature:
+            data["temperature"] = temperature
+        if kwargs:
+            data.update(kwargs)
+
+        if debug:
+            fp = "sample_anthropic.json"
+            print("Saving at path " + fp)
+            tu.to_json(data, fp=fp)
+
+        async with httpx.AsyncClient() as client:
+            response = await client.post(
+                self.base_url,
+                headers=headers,
+                json=data,
+                timeout=timeout,
+            )
+            try:
+                response.raise_for_status()
+            except Exception as e:
+                yield str(e)
+                return
+
+            async for chunk in response.aiter_bytes():
+                for line in chunk.decode("utf-8").splitlines():
+                    line = line.strip()
+                    if not line or "data:" not in line:
+                        continue
+
+                    try:
+                        # print(line)
+                        resp = tu.from_json(line.replace("data:", "").strip())
+                        if resp["type"] == "content_block_start":
+                            if resp["content_block"]["type"] == "tool_use":
+                                fn_call = {
+                                    "name": resp["content_block"]["name"],
+                                    "arguments": "",
+                                }
+                        elif resp["type"] == "content_block_delta":
+                            delta = resp["delta"]
+                            delta_type = delta["type"]
+                            if delta_type == "text_delta":
+                                if raw:
+                                    yield b"data: " + tu.to_json(
+                                        {
+                                            "object": delta_type,
+                                            "choices": [
+                                                {"delta": {"content": delta["text"]}}
+                                            ],
+                                        },
+                                        tight=True,
+                                    ).encode()
+                                    yield b""  # blank separator keeps the SSE framing 1:1 with OpenAI
+                                else:
+                                    yield delta["text"]
+                            elif delta_type == "input_json_delta":
+                                fn_call["arguments"] += delta["partial_json"]
+                        elif resp["type"] == "content_block_stop":
+                            if fn_call:
+                                fn_call["arguments"] = tu.from_json(
+                                    fn_call["arguments"] or "{}"
+                                )
+                                yield fn_call
+                                fn_call = None
+                    except:
+                        break
+
     def distributed_chat(
         self,
         prompts: List[tt.Thread],
         post_logic: Optional[callable] = None,
         max_threads: int = 10,
         retry: int = 3,
         pbar=True,
+        debug=False,
         **kwargs,
     ):
         return distributed_chat(
@@ -252,5 +387,27 @@ def distributed_chat(
             max_threads=max_threads,
             retry=retry,
             pbar=pbar,
+            debug=debug,
+            **kwargs,
+        )
+
+    async def distributed_chat_async(
+        self,
+        prompts: List[tt.Thread],
+        post_logic: Optional[callable] = None,
+        max_threads: int = 10,
+        retry: int = 3,
+        pbar=True,
+        debug=False,
+        **kwargs,
+    ):
+        return await distributed_chat_async(
+            self,
+            prompts=prompts,
+            post_logic=post_logic,
+            max_threads=max_threads,
+            retry=retry,
+            pbar=pbar,
+            debug=debug,
             **kwargs,
         )
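
For context, a minimal usage sketch of the new async surface added above. This assumes `Anthropic` is exported from `tuneapi.apis`, that the default constructor picks a usable model id, and that the API token is resolved the same way as in the sync path (e.g. from the environment); the prompt strings are illustrative.

```python
import asyncio

from tuneapi import apis as ta


async def main():
    model = ta.Anthropic()  # assumption: default model id + env-var token

    # Buffered call: returns the final text, or the tool-call dict if one fired.
    print(await model.chat_async("What is 2 + 2?", max_tokens=64))

    # Streaming call: text deltas arrive as str, tool calls as dict.
    async for chunk in model.stream_chat_async("Tell me a one-line joke."):
        if isinstance(chunk, dict):
            print("\ntool call:", chunk)
        else:
            print(chunk, end="", flush=True)


asyncio.run(main())
```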
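Similarly, a hedged sketch of fanning several threads out through the new `distributed_chat_async` helper; the `tt.human` message constructor and the prompt contents are assumptions for illustration, and the call must run inside an event loop.

```python
import asyncio

from tuneapi import apis as ta
from tuneapi import types as tt


async def main():
    model = ta.Anthropic()  # assumption: default model id + env-var token
    threads = [
        tt.Thread(tt.human(f"Summarize topic #{i} in one line.")) for i in range(5)
    ]
    results = await model.distributed_chat_async(
        threads,
        max_threads=4,  # cap on concurrent requests
        retry=2,        # per-prompt retries before giving up
        pbar=False,
    )
    for r in results:  # one result per input prompt
        print(r)


asyncio.run(main())
```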