77import httpx
88import requests
99from copy import deepcopy
10- from typing import Optional , Dict , Any , List
10+ from typing import Optional , Dict , Any , List , Tuple
1111
1212import tuneapi .utils as tu
1313import tuneapi .types as tt
@@ -24,35 +24,27 @@ def __init__(
2424 ):
2525 self .model_id = id
2626 self .base_url = base_url
27+ self .batch_url = base_url + "/batches"
2728 self .api_token = api_token or tu .ENV .ANTHROPIC_TOKEN ("" )
2829 self .extra_headers = extra_headers
2930
def set_api_token(self, token: str) -> None:
    """Store *token* as the API key used for all subsequent requests."""
    self.api_token = token
3233
33- def _process_input (
34- self ,
35- chats : tt .Thread | str ,
36- model : Optional [str ] = None ,
37- max_tokens : int = 1024 ,
38- temperature : Optional [float ] = None ,
39- token : Optional [str ] = None ,
40- debug : bool = False ,
41- extra_headers : Optional [Dict [str , str ]] = None ,
42- ** kwargs ,
43- ):
34+ def _process_header (self , token : str ) -> Dict [str , str ]:
4435 if not token and not self .api_token : # type: ignore
4536 raise Exception (
4637 "Please set ANTHROPIC_TOKEN environment variable or pass through function"
4738 )
4839 token = token or self .api_token
49- if isinstance ( chats , tt . Thread ):
50- thread = chats
51- elif isinstance ( chats , str ):
52- thread = tt . Thread ( tt . human ( chats ))
53- else :
54- raise Exception ( "Invalid input" )
40+ return {
41+ "x-api-key" : token ,
42+ "Content-Type" : "application/json" ,
43+ "anthropic-version" : "2023-06-01" ,
44+ "anthropic-beta" : "tools-2024-05-16" ,
45+ }
5546
47+ def _process_thread (self , thread : tt .Thread ) -> Tuple [str , List [Dict [str , Any ]]]:
5648 # create the anthropic style data
5749 system = ""
5850 if thread .chats [0 ].role == tt .Message .SYSTEM :
@@ -134,13 +126,29 @@ def _process_input(
134126 raise Exception (f"Unknown role: { m .role } " )
135127 claude_messages .append (msg )
136128
137- headers = {
138- "x-api-key" : token ,
139- "Content-Type" : "application/json" ,
140- "anthropic-version" : "2023-06-01" ,
141- "anthropic-beta" : "tools-2024-05-16" ,
142- }
143- # return headers, system.strip(), claude_messages
129+ return system , claude_messages
130+
131+ def _process_input (
132+ self ,
133+ chats : tt .Thread | str ,
134+ model : Optional [str ] = None ,
135+ max_tokens : int = 1024 ,
136+ temperature : Optional [float ] = None ,
137+ token : Optional [str ] = None ,
138+ debug : bool = False ,
139+ extra_headers : Optional [Dict [str , str ]] = None ,
140+ stream : bool = True ,
141+ ** kwargs ,
142+ ):
143+ if isinstance (chats , tt .Thread ):
144+ thread = chats
145+ elif isinstance (chats , str ):
146+ thread = tt .Thread (tt .human (chats ))
147+ else :
148+ raise Exception ("Invalid input" )
149+
150+ system , claude_messages = self ._process_thread (thread )
151+ headers = self ._process_header (token )
144152
145153 tools = []
146154 if isinstance (chats , tt .Thread ) and chats .tools :
@@ -157,7 +165,7 @@ def _process_input(
157165 "messages" : claude_messages ,
158166 "system" : system ,
159167 "tools" : tools ,
160- "stream" : True ,
168+ "stream" : stream ,
161169 }
162170 if temperature :
163171 data ["temperature" ] = temperature
@@ -274,7 +282,7 @@ def stream_chat(
274282 self ,
275283 chats : tt .Thread | str ,
276284 model : Optional [str ] = None ,
277- max_tokens : int = 1024 ,
285+ max_tokens : int = 4096 ,
278286 temperature : Optional [float ] = None ,
279287 token : Optional [str ] = None ,
280288 debug : bool = False ,
@@ -355,7 +363,7 @@ async def stream_chat_async(
355363 self ,
356364 chats : tt .Thread | str ,
357365 model : Optional [str ] = None ,
358- max_tokens : int = 1024 ,
366+ max_tokens : int = 4096 ,
359367 temperature : Optional [float ] = None ,
360368 token : Optional [str ] = None ,
361369 debug : bool = False ,
@@ -439,3 +447,131 @@ async def distributed_chat_async(
439447 debug = debug ,
440448 ** kwargs ,
441449 )
450+
def submit_batch(
    self,
    threads: List[tt.Thread | str],
    model: Optional[str] = None,
    max_tokens: int = 4096,
    temperature: Optional[float] = None,
    token: Optional[str] = None,
    debug: bool = False,
    extra_headers: Optional[Dict[str, str]] = None,
    timeout=(5, 30),
    raw: bool = False,
    **kwargs,
) -> Tuple[str, List[str]] | Dict:
    """Submit several threads as one request batch to the Anthropic API.

    Each thread is converted via ``self._process_input`` (with
    ``stream=False``) and tagged with a generated ``custom_id`` so the
    results can be matched back up in :meth:`get_batch`.

    Args:
        threads: Threads (or plain prompt strings) to run in the batch.
            Must be non-empty.
        model / max_tokens / temperature / token / extra_headers / kwargs:
            Forwarded per-thread to ``self._process_input``.
        debug: When True, also dump the request body to
            ``sample_anthropic_batch.json``.
        timeout: (connect, read) timeout passed to ``requests.post``.
        raw: When True, return the raw JSON response instead of
            ``(batch_id, custom_ids)``.

    Returns:
        ``(batch_id, custom_ids)``; the raw response dict when ``raw``.

    Raises:
        ValueError: If *threads* is empty.
        requests.HTTPError: If the API rejects the submission.
    """
    if not threads:
        # Fail fast: with zero requests, `headers` below would never be
        # bound (the original code raised NameError at requests.post).
        raise ValueError("`threads` must contain at least one thread")

    bodies = []
    custom_ids = []
    for chats in threads:
        headers, data = self._process_input(
            chats=chats,
            model=model,
            max_tokens=max_tokens,
            temperature=temperature,
            token=token,
            extra_headers=extra_headers,
            stream=False,
            **kwargs,
        )
        custom_id = "tuneapi_" + tu.get_random_string(10)
        custom_ids.append(custom_id)
        bodies.append({"custom_id": custom_id, "params": data})
    body = {"requests": bodies}
    if debug:
        fp = "sample_anthropic_batch.json"
        print("Saving at path " + fp)
        tu.to_json(body, fp=fp)

    r = requests.post(
        url=self.batch_url,
        headers=headers,
        timeout=timeout,
        json=body,
    )
    try:
        r.raise_for_status()
    except Exception as e:
        tu.logger.error(f"Couldn't submit batch: {r.text}")
        raise e
    resp = r.json()

    if raw:
        return resp
    return resp["id"], custom_ids
502+
def get_batch(
    self,
    batch_id: str,
    custom_ids: Optional[List[str]] = None,
    usage: bool = False,
    token: Optional[str] = None,
    raw: bool = False,
    verbose: bool = False,
) -> Tuple[Any, ...]:
    """Fetch and parse the results of a previously submitted batch.

    Args:
        batch_id: Id returned by :meth:`submit_batch`.
        custom_ids: When given, results are re-sorted into this order
            (each result carries the ``custom_id`` it was submitted with).
        usage: When True, append an aggregated ``tt.Usage`` as a third
            tuple element — but only once the batch has ended.
        token: Optional API key override.
        raw: Return the decoded per-request JSON objects as-is instead of
            the parsed text / tool-call payloads.
        verbose: Log an info message when the batch has not finished.

    Returns:
        ``(None, processing_status)`` while the batch is still running
        (NOTE(review): a 2-tuple even when ``usage=True`` — callers that
        unpack three values must check the status first);
        ``(results, None)`` once ended; ``(results, None, usage_total)``
        when ``usage=True``. (The loose ``Tuple[Any, ...]`` annotation
        replaces the original one, which did not cover the ``None`` first
        element or the 3-tuple shape.)

    Raises:
        requests.HTTPError: If either the batch lookup or the results
            download fails.
        ValueError: If a result contains an unknown content type.
    """
    headers = self._process_header(token)
    r = requests.get(self.batch_url + "/" + batch_id, headers=headers)
    try:
        r.raise_for_status()
    except Exception as e:
        tu.logger.error(f"Couldn't get batch: {r.text}")
        raise e
    resp = r.json()
    if resp["processing_status"] != "ended":
        if verbose:
            tu.logger.info(
                f"Batch {batch_id} has not ended. Status: {resp['processing_status']}"
            )
        return None, resp["processing_status"]
    results_url = resp["results_url"]

    # Fetch the results; the response is JSONL, decoded line by line
    # into a list of dicts.
    r = requests.get(results_url, headers=headers)
    try:
        r.raise_for_status()
    except Exception as e:
        tu.logger.error(f"Couldn't get batch results: {r.text}")
        raise e

    output = []
    for line in r.iter_lines():
        if not line:
            continue
        output.append(tu.from_json(line))

    if custom_ids:
        # Each item has a "custom_id" key; restore the submission order.
        output = sorted(output, key=lambda x: custom_ids.index(x["custom_id"]))

    if raw:
        return output, None

    # Aggregate token usage across all responses. The two cache counters
    # are folded into a single cached_tokens figure (read first, else
    # creation).
    _usage = tt.Usage(0, 0)
    for o in output:
        u = o["result"]["message"]["usage"]
        _usage += tt.Usage(
            input_tokens=u.pop("input_tokens"),
            output_tokens=u.pop("output_tokens"),
            cached_tokens=u.get("cache_read_input_tokens", 0)
            or u.get("cache_creation_input_tokens", 0),
            **u,
        )

    # Only the first content element of each message is surfaced.
    parsed_output = [o["result"]["message"]["content"][0] for o in output]
    final_output = []
    for o in parsed_output:
        if o["type"] == "text":
            final_output.append(o["text"])
        elif o["type"] == "tool_use":
            final_output.append(
                {
                    "name": o["name"],
                    "arguments": o["input"],
                }
            )
        else:
            raise ValueError(f"Unknown message content: {o['type']}")

    if usage:
        return final_output, None, _usage
    return final_output, None
0 commit comments