Skip to content

Commit 16cc3cd

Browse files
committed
Add pricing tracking for Anthropic model and refactor pricing functions
1 parent 544908e commit 16cc3cd

File tree

2 files changed

+52
-2
lines changed

2 files changed

+52
-2
lines changed

src/agentlab/llm/response_api.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
from anthropic import Anthropic
1010
from openai import OpenAI
1111

12+
from agentlab.llm import tracking
13+
1214
from .base_api import BaseModelArgs
1315

1416
type ContentItem = Dict[str, Any]
@@ -269,6 +271,20 @@ def __init__(
269271
max_tokens=max_tokens,
270272
extra_kwargs=extra_kwargs,
271273
)
274+
275+
# Get pricing information
276+
277+
try:
278+
pricing = tracking.get_pricing_anthropic()
279+
self.input_cost = float(pricing[model_name]["prompt"])
280+
self.output_cost = float(pricing[model_name]["completion"])
281+
except KeyError:
282+
logging.warning(
283+
f"Model {model_name} not found in the pricing information, prices are set to 0. Maybe try upgrading langchain_community."
284+
)
285+
self.input_cost = 0.0
286+
self.output_cost = 0.0
287+
272288
self.client = Anthropic(api_key=api_key)
273289

274290
def _call_api(self, messages: list[dict | MessageBuilder]) -> dict:
@@ -286,6 +302,17 @@ def _call_api(self, messages: list[dict | MessageBuilder]) -> dict:
286302
max_tokens=self.max_tokens,
287303
**self.extra_kwargs,
288304
)
305+
input_tokens = response.usage.input_tokens
306+
output_tokens = response.usage.output_tokens
307+
cost = input_tokens * self.input_cost + output_tokens * self.output_cost
308+
309+
print(f"response.usage: {response.usage}")
310+
311+
if hasattr(tracking.TRACKER, "instance") and isinstance(
312+
tracking.TRACKER.instance, tracking.LLMTracker
313+
):
314+
tracking.TRACKER.instance(input_tokens, output_tokens, cost)
315+
289316
return response
290317
except Exception as e:
291318
logging.error(f"Failed to get a response from the API: {e}")

src/agentlab/llm/tracking.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import os
2+
import re
23
import threading
34
from contextlib import contextmanager
45
from functools import cache
56

67
import requests
7-
from langchain_community.callbacks.openai_info import MODEL_COST_PER_1K_TOKENS
8+
from langchain_community.callbacks import bedrock_anthropic_callback, openai_info
89

910
TRACKER = threading.local()
1011

@@ -85,7 +86,7 @@ def get_pricing_openrouter():
8586

8687

8788
def get_pricing_openai():
88-
cost_dict = MODEL_COST_PER_1K_TOKENS
89+
cost_dict = openai_info.MODEL_COST_PER_1K_TOKENS
8990
cost_dict = {k: v / 1000 for k, v in cost_dict.items()}
9091
res = {}
9192
for k in cost_dict:
@@ -99,3 +100,25 @@ def get_pricing_openai():
99100
"completion": cost_dict[completion_key],
100101
}
101102
return res
103+
104+
105+
def _remove_version_suffix(model_name):
106+
no_version = re.sub(r"-v\d+(?:[.:]\d+)?$", "", model_name)
107+
return re.sub(r"anthropic.", "", no_version)
108+
109+
110+
def get_pricing_anthropic():
    """Return per-token prices for Anthropic models.

    Builds a mapping ``{model_name: {"prompt": cost, "completion": cost}}``
    from langchain's Bedrock Anthropic per-1K-token cost tables. Model names
    are normalized via ``_remove_version_suffix`` and costs are scaled from
    per-1K-tokens down to per-token.
    """
    per_1k_input = bedrock_anthropic_callback.MODEL_COST_PER_1K_INPUT_TOKENS
    per_1k_output = bedrock_anthropic_callback.MODEL_COST_PER_1K_OUTPUT_TOKENS

    # Seed each entry with its prompt price; duplicate normalized names
    # overwrite earlier ones, matching the source table's iteration order.
    pricing = {
        _remove_version_suffix(name): {"prompt": cost / 1000}
        for name, cost in per_1k_input.items()
    }

    # Attach completion prices, creating entries for output-only models.
    for name, cost in per_1k_output.items():
        entry = pricing.setdefault(_remove_version_suffix(name), {})
        entry["completion"] = cost / 1000

    return pricing

0 commit comments

Comments
 (0)