66
77import httpx
88import requests
9+ from copy import deepcopy
910from typing import Optional , Dict , Any , List
1011
1112import tuneapi .utils as tu
@@ -18,11 +19,12 @@ def __init__(
1819 self ,
1920 id : Optional [str ] = "claude-3-haiku-20240307" ,
2021 base_url : str = "https://api.anthropic.com/v1/messages" ,
22+ api_token : Optional [str ] = None ,
2123 extra_headers : Optional [Dict [str , str ]] = None ,
2224 ):
2325 self .model_id = id
2426 self .base_url = base_url
25- self .api_token = tu .ENV .ANTHROPIC_TOKEN ("" )
27+ self .api_token = api_token or tu .ENV .ANTHROPIC_TOKEN ("" )
2628 self .extra_headers = extra_headers
2729
2830 def set_api_token (self , token : str ) -> None :
@@ -60,13 +62,17 @@ def _process_input(
6062 prev_tool_id = tu .get_random_string (5 )
6163 for m in thread .chats [int (system != "" ) :]:
6264 if m .role == tt .Message .HUMAN :
63- msg = {
64- "role" : "user" ,
65- "content" : [{"type" : "text" , "text" : m .value .strip ()}],
66- }
65+ if isinstance (m .value , str ):
66+ content = [{"type" : "text" , "text" : m .value }]
67+ elif isinstance (m .value , list ):
68+ content = deepcopy (m .value )
69+ else :
70+ raise Exception (
71+ f"Unknown message type. Got: '{ type (m .value )} ', expected 'List[Dict[str, Any]]' or 'str'"
72+ )
6773 if m .images :
6874 for i in m .images :
69- msg [ " content" ] .append (
75+ content .append (
7076 {
7177 "type" : "image" ,
7278 "source" : {
@@ -76,14 +82,19 @@ def _process_input(
7682 },
7783 }
7884 )
85+ msg = {"role" : "user" , "content" : content }
7986 elif m .role == tt .Message .GPT :
80- msg = {
81- "role" : "assistant" ,
82- "content" : [{"type" : "text" , "text" : m .value .strip ()}],
83- }
87+ if isinstance (m .value , str ):
88+ content = [{"type" : "text" , "text" : m .value }]
89+ elif isinstance (m .value , list ):
90+ content = deepcopy (m .value )
91+ else :
92+ raise Exception (
93+ f"Unknown message type. Got: '{ type (m .value )} ', expected 'List[Dict[str, Any]]' or 'str'"
94+ )
8495 if m .images :
8596 for i in m .images :
86- msg [ " content" ] .append (
97+ content .append (
8798 {
8899 "type" : "image" ,
89100 "source" : {
@@ -93,6 +104,7 @@ def _process_input(
93104 },
94105 }
95106 )
107+ msg = {"role" : "assistant" , "content" : content }
96108 elif m .role == tt .Message .FUNCTION_CALL :
97109 _m = tu .from_json (m .value ) if isinstance (m .value , str ) else m .value
98110 msg = {
@@ -159,49 +171,64 @@ def _process_input(
159171
160172 return headers , data
161173
def _process_output(self, raw: bool, lines_fn: callable, yield_usage: bool = False):
    """Translate a stream of server-sent-event lines into chat output items.

    Args:
        raw: When true, text deltas (and usage, if requested) are emitted as
            ``b"data: <json>"`` byte frames, each followed by an empty
            ``b""`` frame; otherwise plain Python objects are yielded.
        lines_fn: Zero-argument callable returning an iterable of lines.
            Lines may be ``str`` or ``bytes``; bytes are decoded and stripped.
        yield_usage: When true, a usage record is emitted on ``message_delta``.

    Yields:
        - text deltas: ``str`` (or a ``bytes`` frame when ``raw``),
        - completed tool calls: ``dict`` with ``name`` and parsed ``arguments``,
        - usage: a ``tt.Usage`` (or a ``bytes`` frame when ``raw``), only if
          ``yield_usage`` is true.

    Lines that are empty or do not contain ``data:`` are ignored.
    """
    fn_call = None
    usage_dict = {}
    for line in lines_fn():
        if isinstance(line, bytes):
            line = line.decode().strip()
        if not line or "data:" not in line:
            continue

        # Drop only the leading "data:" field name. A blanket
        # ``replace("data:", "")`` would also delete that substring from
        # inside the JSON payload and corrupt the streamed text.
        payload = line.partition("data:")[2].strip()
        resp = tu.from_json(payload)
        resp_type = resp["type"]

        if resp_type == "message_start":
            usage_dict.update(resp["message"]["usage"])
        elif resp_type == "content_block_start":
            if resp["content_block"]["type"] == "tool_use":
                # Arguments arrive as JSON fragments; accumulate them as text
                # until the block stops.
                fn_call = {
                    "name": resp["content_block"]["name"],
                    "arguments": "",
                }
        elif resp_type == "content_block_delta":
            delta = resp["delta"]
            delta_type = delta["type"]
            if delta_type == "text_delta":
                if raw:
                    yield b"data: " + tu.to_json(
                        {
                            "object": delta_type,
                            "choices": [{"delta": {"content": delta["text"]}}],
                        },
                        tight=True,
                    ).encode()
                    # Empty frame after each data frame; presumably kept for
                    # 1:1 framing with OpenAI-style streams.
                    yield b""
                else:
                    yield delta["text"]
            elif delta_type == "input_json_delta":
                fn_call["arguments"] += delta["partial_json"]
        elif resp_type == "content_block_stop":
            if fn_call:
                fn_call["arguments"] = tu.from_json(fn_call["arguments"] or "{}")
                yield fn_call
                fn_call = None
        elif resp_type == "message_delta":
            # Default to 0 so a generator that never saw ``message_start``
            # (e.g. when each network chunk is processed by a fresh generator)
            # does not die with KeyError.
            # NOTE(review): confirm whether the upstream ``message_delta``
            # count is already cumulative; if so this addition over-counts.
            usage_dict["output_tokens"] = (
                usage_dict.get("output_tokens", 0) + resp["usage"]["output_tokens"]
            )
            if yield_usage:
                cached_tokens = usage_dict.get(
                    "cache_read_input_tokens", 0
                ) or usage_dict.get("cache_creation_input_tokens", 0)
                # Remaining keys of ``usage_dict`` are forwarded verbatim as
                # extra keyword arguments.
                usage_obj = tt.Usage(
                    input_tokens=usage_dict.pop("input_tokens", 0),
                    output_tokens=usage_dict.pop("output_tokens"),
                    cached_tokens=cached_tokens,
                    **usage_dict,
                )
                if raw:
                    yield b"data: " + usage_obj.to_json(tight=True).encode()
                    yield b""
                else:
                    yield usage_obj
205232
206233 # Interaction methods
207234
@@ -212,30 +239,35 @@ def chat(
212239 max_tokens : int = 1024 ,
213240 temperature : Optional [float ] = None ,
214241 token : Optional [str ] = None ,
215- return_message : bool = False ,
242+ usage : bool = False ,
216243 extra_headers : Optional [Dict [str , str ]] = None ,
217244 ** kwargs ,
218245 ):
219246 output = ""
247+ usage_obj = None
220248 fn_call = None
221249 for i in self .stream_chat (
222250 chats = chats ,
223251 model = model ,
224252 max_tokens = max_tokens ,
225253 temperature = temperature ,
226254 token = token ,
255+ usage = usage ,
227256 extra_headers = extra_headers ,
228257 raw = False ,
229258 ** kwargs ,
230259 ):
231260 if isinstance (i , dict ):
232261 fn_call = i .copy ()
262+ elif isinstance (i , tt .Usage ):
263+ usage_obj = i
233264 else :
234265 output += i
235- if return_message :
236- return output , fn_call
237266 if fn_call :
238- return fn_call
267+ output = fn_call
268+
269+ if usage :
270+ return output , usage_obj
239271 return output
240272
241273 def stream_chat (
@@ -246,6 +278,7 @@ def stream_chat(
246278 temperature : Optional [float ] = None ,
247279 token : Optional [str ] = None ,
248280 debug : bool = False ,
281+ usage : bool = False ,
249282 extra_headers : Optional [Dict [str , str ]] = None ,
250283 timeout = (5 , 30 ),
251284 raw : bool = False ,
@@ -262,19 +295,23 @@ def stream_chat(
262295 extra_headers = extra_headers ,
263296 ** kwargs ,
264297 )
265- r = requests .post (
266- self .base_url ,
267- headers = headers ,
268- json = data ,
269- timeout = timeout ,
270- )
271298 try :
299+ r = requests .post (
300+ self .base_url ,
301+ headers = headers ,
302+ json = data ,
303+ timeout = timeout ,
304+ )
272305 r .raise_for_status ()
273306 except Exception as e :
274307 yield r .text
275308 raise e
276309
277- yield from self ._process_output (raw = raw , lines_fn = r .iter_lines )
310+ yield from self ._process_output (
311+ raw = raw ,
312+ lines_fn = r .iter_lines ,
313+ yield_usage = usage ,
314+ )
278315
279316 async def chat_async (
280317 self ,
@@ -283,30 +320,35 @@ async def chat_async(
283320 max_tokens : int = 1024 ,
284321 temperature : Optional [float ] = None ,
285322 token : Optional [str ] = None ,
286- return_message : bool = False ,
323+ usage : bool = False ,
287324 extra_headers : Optional [Dict [str , str ]] = None ,
288325 ** kwargs ,
289326 ):
290327 output = ""
328+ usage_obj = None
291329 fn_call = None
292330 async for i in self .stream_chat_async (
293331 chats = chats ,
294332 model = model ,
295333 max_tokens = max_tokens ,
296334 temperature = temperature ,
297335 token = token ,
336+ usage = usage ,
298337 extra_headers = extra_headers ,
299338 raw = False ,
300339 ** kwargs ,
301340 ):
302341 if isinstance (i , dict ):
303342 fn_call = i .copy ()
343+ elif isinstance (i , tt .Usage ):
344+ usage_obj = i
304345 else :
305346 output += i
306- if return_message :
307- return output , fn_call
347+
308348 if fn_call :
309- return fn_call
349+ output = fn_call
350+ if usage :
351+ return output , usage_obj
310352 return output
311353
312354 async def stream_chat_async (
@@ -317,6 +359,7 @@ async def stream_chat_async(
317359 temperature : Optional [float ] = None ,
318360 token : Optional [str ] = None ,
319361 debug : bool = False ,
362+ usage : bool = False ,
320363 extra_headers : Optional [Dict [str , str ]] = None ,
321364 timeout = (5 , 30 ),
322365 raw : bool = False ,
@@ -351,6 +394,7 @@ async def stream_chat_async(
351394 for x in self ._process_output (
352395 raw = raw ,
353396 lines_fn = chunk .decode ("utf-8" ).splitlines ,
397+ yield_usage = usage ,
354398 ):
355399 yield x
356400
0 commit comments