Skip to content

Commit c1e27e1

Browse files
committed
use litellm pricing for Azure and Anthropic ChatModels
1 parent 9c87325 commit c1e27e1

File tree

1 file changed

+41
-7
lines changed

src/agentlab/llm/chat_api.py

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -433,7 +433,7 @@ def __init__(
433433
min_retry_wait_time=min_retry_wait_time,
434434
client_class=OpenAI,
435435
client_args=client_args,
436-
pricing_func=tracking.get_pricing_openai,
436+
pricing_func=tracking.partial(tracking.get_pricing_litellm, model_name=model_name),
437437
log_probs=log_probs,
438438
)
439439

@@ -492,6 +492,7 @@ def __init__(
492492
temperature=0.5,
493493
max_tokens=100,
494494
max_retry=4,
495+
pricing_func=None,
495496
):
496497
self.model_name = model_name
497498
self.temperature = temperature
@@ -501,6 +502,22 @@ def __init__(
501502
api_key = api_key or os.getenv("ANTHROPIC_API_KEY")
502503
self.client = anthropic.Anthropic(api_key=api_key)
503504

505+
# Get pricing information
506+
if pricing_func:
507+
pricings = pricing_func()
508+
try:
509+
self.input_cost = float(pricings[model_name]["prompt"])
510+
self.output_cost = float(pricings[model_name]["completion"])
511+
except KeyError:
512+
logging.warning(
513+
f"Model {model_name} not found in the pricing information, prices are set to 0. Maybe try upgrading langchain_community."
514+
)
515+
self.input_cost = 0.0
516+
self.output_cost = 0.0
517+
else:
518+
self.input_cost = 0.0
519+
self.output_cost = 0.0
520+
504521
def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float = None) -> dict:
505522
# Convert OpenAI format to Anthropic format
506523
system_message = None
@@ -528,13 +545,29 @@ def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float
528545

529546
response = self.client.messages.create(**kwargs)
530547

548+
usage = getattr(response, "usage", {})
549+
new_input_tokens = getattr(usage, "input_tokens", 0)
550+
output_tokens = getattr(usage, "output_tokens", 0)
551+
cache_read_tokens = getattr(usage, "cache_input_tokens", 0)
552+
cache_write_tokens = getattr(usage, "cache_creation_input_tokens", 0)
553+
cache_read_cost = (
554+
self.input_cost * tracking.ANTHROPIC_CACHE_PRICING_FACTOR["cache_read_tokens"]
555+
)
556+
cache_write_cost = (
557+
self.input_cost * tracking.ANTHROPIC_CACHE_PRICING_FACTOR["cache_write_tokens"]
558+
)
559+
cost = (
560+
new_input_tokens * self.input_cost
561+
+ output_tokens * self.output_cost
562+
+ cache_read_tokens * cache_read_cost
563+
+ cache_write_tokens * cache_write_cost
564+
)
565+
531566
# Track usage if available
532-
if hasattr(tracking.TRACKER, "instance"):
533-
tracking.TRACKER.instance(
534-
response.usage.input_tokens,
535-
response.usage.output_tokens,
536-
0, # cost calculation would need pricing info
537-
)
567+
if hasattr(tracking.TRACKER, "instance") and isinstance(
568+
tracking.TRACKER.instance, tracking.LLMTracker
569+
):
570+
tracking.TRACKER.instance(new_input_tokens, output_tokens, cost)
538571

539572
return AIMessage(response.content[0].text)
540573

@@ -552,6 +585,7 @@ def make_model(self):
552585
model_name=self.model_name,
553586
temperature=self.temperature,
554587
max_tokens=self.max_new_tokens,
588+
pricing_func=partial(tracking.get_pricing_litellm, model_name=self.model_name),
555589
)
556590

557591

Comments (0)