Skip to content

Commit 3db3348

Browse files
authored
Merge pull request #119 from AgentOps-AI/cached-tokens
Add support for cached tokens in the costs module.
2 parents 209f987 + af8af44 commit 3db3348

File tree

4 files changed

+107
-10
lines changed

4 files changed

+107
-10
lines changed

.github/workflows/run-tests.yml

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
name: Run Tests
2+
3+
on:
4+
pull_request:
5+
branches: [ main ]
6+
push:
7+
branches: [ main ]
8+
9+
jobs:
10+
test:
11+
runs-on: ubuntu-latest
12+
strategy:
13+
matrix:
14+
python-version: ['3.10', '3.11']
15+
16+
steps:
17+
- uses: actions/checkout@v3
18+
19+
- name: Set up Python ${{ matrix.python-version }}
20+
uses: actions/setup-python@v4
21+
with:
22+
python-version: ${{ matrix.python-version }}
23+
24+
- name: Install dependencies
25+
run: |
26+
python -m pip install --upgrade pip
27+
python -m pip install -e ".[dev]"
28+
29+
- name: Run tests
30+
run: |
31+
python -m pytest tests
32+
env:
33+
# This handles the Anthropic tests that require an API key
34+
# Uses a dummy key for tests since they're skipped without a key
35+
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY || 'dummy-key-for-testing' }}

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ dev = [
3838
"tach>=0.6.9",
3939
"tabulate>=0.9.0",
4040
"pandas>=2.1.0",
41+
"python-dotenv>=1.0.0",
4142
]
4243

4344
[project.urls]

tests/test_costs.py

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
11
#!/usr/bin/env python
22
# -*- coding: utf-8 -*-
33

4+
import os
45
import pytest
56
from decimal import Decimal
7+
from dotenv import load_dotenv
8+
9+
# Load environment variables for ANTHROPIC_API_KEY
10+
load_dotenv()
11+
612
from tokencost.costs import (
713
count_message_tokens,
814
count_string_tokens,
@@ -46,7 +52,11 @@
4652
("gpt-4-vision-preview", 15),
4753
("gpt-4o", 15),
4854
("azure/gpt-4o", 15),
49-
("claude-3-opus-latest", 11),
55+
pytest.param(
    "claude-3-opus-latest",
    11,
    marks=pytest.mark.skipif(
        # Skip when the key is missing. The earlier `bool(...)` condition
        # was inverted: it skipped the case whenever a key WAS present,
        # contradicting the reason string below.
        not os.getenv("ANTHROPIC_API_KEY"),
        reason="ANTHROPIC_API_KEY environment variable not set",
    ),
),
5060
],
5161
)
5262
def test_count_message_tokens(model, expected_output):
@@ -154,8 +164,12 @@ def test_count_string_invalid_model():
154164
(MESSAGES, "gpt-4-1106-preview", Decimal("0.00015")),
155165
(MESSAGES, "gpt-4-vision-preview", Decimal("0.00015")),
156166
(MESSAGES, "gpt-4o", Decimal("0.0000375")),
157-
(MESSAGES, "azure/gpt-4o", Decimal("0.000075")),
158-
(MESSAGES, "claude-3-opus-latest", Decimal("0.000165")),
167+
(MESSAGES, "azure/gpt-4o", Decimal("0.0000375")),
168+
pytest.param(
    MESSAGES,
    "claude-3-opus-latest",
    Decimal("0.000165"),
    marks=pytest.mark.skipif(
        # Skip when the key is missing. The earlier `bool(...)` condition
        # was inverted: it skipped the case whenever a key WAS present,
        # contradicting the reason string below.
        not os.getenv("ANTHROPIC_API_KEY"),
        reason="ANTHROPIC_API_KEY environment variable not set",
    ),
),
159173
(STRING, "text-embedding-ada-002", Decimal("0.0000004")),
160174
],
161175
)
@@ -191,7 +205,7 @@ def test_invalid_prompt_format():
191205
(STRING, "gpt-4-1106-preview", Decimal("0.00012")),
192206
(STRING, "gpt-4-vision-preview", Decimal("0.00012")),
193207
(STRING, "gpt-4o", Decimal("0.00004")),
194-
(STRING, "azure/gpt-4o", Decimal("0.000060")),
208+
(STRING, "azure/gpt-4o", Decimal("0.00004")),
195209
# (STRING, "claude-3-opus-latest", Decimal("0.000096")), # NOTE: Claude only supports messages
196210
(STRING, "text-embedding-ada-002", 0),
197211
],
@@ -230,9 +244,30 @@ def test_calculate_invalid_input_types():
230244
(10, "gpt-3.5-turbo", "input", Decimal("0.0000150")), # Example values
231245
(5, "gpt-4", "output", Decimal("0.00030")), # Example values
232246
(10, "ai21.j2-mid-v1", "input", Decimal("0.0001250")), # Example values
247+
(100, "gpt-4o", "cached", Decimal("0.000125")), # Cache tokens test
233248
],
234249
)
235250
def test_calculate_cost_by_tokens(num_tokens, model, token_type, expected_output):
236251
"""Test that the token cost calculation is correct."""
237252
cost = calculate_cost_by_tokens(num_tokens, model, token_type)
238253
assert cost == expected_output
254+
255+
256+
def test_calculate_cached_tokens_cost():
    """Verify that "cached" token pricing matches the TOKEN_COSTS table."""
    from tokencost.constants import TOKEN_COSTS

    model, num_tokens, token_type = "gpt-4o", 1000, "cached"

    # Build the expectation directly from the pricing table so the test
    # keeps passing when the published per-token rate is updated.
    per_token_rate = TOKEN_COSTS[model]["cache_read_input_token_cost"]
    expected_cost = Decimal(str(per_token_rate)) * Decimal(num_tokens)

    actual_cost = calculate_cost_by_tokens(num_tokens, model, token_type)

    assert actual_cost == expected_cost
    assert actual_cost > 0, "Cache token cost should be greater than zero"

tokencost/costs.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import os
66
import tiktoken
77
import anthropic
8-
from typing import Union, List, Dict
8+
from typing import Union, List, Dict, Literal
99
from .constants import TOKEN_COSTS
1010
from decimal import Decimal
1111
import logging
@@ -16,6 +16,31 @@
1616
# https://github.com/anthropics/anthropic-tokenizer-typescript/blob/main/index.ts
1717

1818

19+
TokenType = Literal["input", "output", "cached"]


def _get_field_from_token_type(token_type: TokenType) -> str:
    """
    Get the field name from the token type.

    Args:
        token_type (TokenType): The token type ("input", "output" or "cached").

    Returns:
        str: The field name to use for the token cost data in the TOKEN_COSTS dictionary.

    Raises:
        ValueError: If `token_type` is not one of the supported token types.
    """
    lookups = {
        "input": "input_cost_per_token",
        "output": "output_cost_per_token",
        "cached": "cache_read_input_token_cost",
    }

    try:
        return lookups[token_type]
    except KeyError:
        # `from None` hides the internal KeyError so the caller's traceback
        # shows only the ValueError describing the bad argument.
        raise ValueError(f"Invalid token type: {token_type}.") from None
42+
43+
1944
def get_anthropic_token_count(messages: List[Dict[str, str]], model: str) -> int:
2045
if not any(
2146
supported_model in model
@@ -160,7 +185,7 @@ def count_string_tokens(prompt: str, model: str) -> int:
160185
return len(encoding.encode(prompt))
161186

162187

163-
def calculate_cost_by_tokens(num_tokens: int, model: str, token_type: str) -> Decimal:
188+
def calculate_cost_by_tokens(num_tokens: int, model: str, token_type: TokenType) -> Decimal:
164189
"""
165190
Calculate the cost based on the number of tokens and the model.
166191
@@ -179,10 +204,11 @@ def calculate_cost_by_tokens(num_tokens: int, model: str, token_type: str) -> De
179204
Double-check your spelling, or submit an issue/PR"""
180205
)
181206

182-
cost_per_token_key = (
183-
"input_cost_per_token" if token_type == "input" else "output_cost_per_token"
184-
)
185-
cost_per_token = TOKEN_COSTS[model][cost_per_token_key]
207+
try:
208+
token_key = _get_field_from_token_type(token_type)
209+
cost_per_token = TOKEN_COSTS[model][token_key]
210+
except KeyError:
211+
raise KeyError(f"Model {model} does not have cost data for `{token_type}` tokens.")
186212

187213
return Decimal(str(cost_per_token)) * Decimal(num_tokens)
188214

0 commit comments

Comments
 (0)