99from anthropic import Anthropic
1010from openai import OpenAI
1111
12+ from agentlab .llm import tracking
13+
1214from .base_api import BaseModelArgs
1315
1416type ContentItem = Dict [str , Any ]
@@ -269,6 +271,20 @@ def __init__(
269271 max_tokens = max_tokens ,
270272 extra_kwargs = extra_kwargs ,
271273 )
274+
275+ # Get pricing information
276+
277+ try :
278+ pricing = tracking .get_pricing_anthropic ()
279+ self .input_cost = float (pricing [model_name ]["prompt" ])
280+ self .output_cost = float (pricing [model_name ]["completion" ])
281+ except KeyError :
282+ logging .warning (
283+ f"Model { model_name } not found in the pricing information, prices are set to 0. Maybe try upgrading langchain_community."
284+ )
285+ self .input_cost = 0.0
286+ self .output_cost = 0.0
287+
272288 self .client = Anthropic (api_key = api_key )
273289
274290 def _call_api (self , messages : list [dict | MessageBuilder ]) -> dict :
@@ -286,6 +302,17 @@ def _call_api(self, messages: list[dict | MessageBuilder]) -> dict:
286302 max_tokens = self .max_tokens ,
287303 ** self .extra_kwargs ,
288304 )
305+ input_tokens = response .usage .input_tokens
306+ output_tokens = response .usage .output_tokens
307+ cost = input_tokens * self .input_cost + output_tokens * self .output_cost
308+
309+ print (f"response.usage: { response .usage } " )
310+
311+ if hasattr (tracking .TRACKER , "instance" ) and isinstance (
312+ tracking .TRACKER .instance , tracking .LLMTracker
313+ ):
314+ tracking .TRACKER .instance (input_tokens , output_tokens , cost )
315+
289316 return response
290317 except Exception as e :
291318 logging .error (f"Failed to get a response from the API: { e } " )
0 commit comments