35 changes: 35 additions & 0 deletions .github/workflows/run-tests.yml
@@ -0,0 +1,35 @@
name: Run Tests

on:
pull_request:
branches: [ main ]
push:
branches: [ main ]

jobs:
test:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.10', '3.11']

steps:
- uses: actions/checkout@v3

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}

- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install -e ".[dev]"

- name: Run tests
run: |
python -m pytest tests
env:
# Handles the Anthropic tests that require an API key: a dummy value is used
# as a fallback, and tests that need a real key are skipped when none is set
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY || 'dummy-key-for-testing' }}
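For reference, a minimal sketch of the skip pattern this workflow relies on, written as a reusable pytest marker. The marker name and the example test are hypothetical; the real tests inline the same skipif via pytest.param, as shown in tests/test_costs.py below.

```python
import os

import pytest

# Hypothetical reusable marker equivalent to the inline skipif used in the
# tests: skip when ANTHROPIC_API_KEY is not set in the environment.
requires_anthropic_key = pytest.mark.skipif(
    not os.getenv("ANTHROPIC_API_KEY"),
    reason="ANTHROPIC_API_KEY environment variable not set",
)


@requires_anthropic_key
def test_anthropic_token_count_example():
    # Placeholder body; the real tests exercise the Anthropic token-counting path.
    assert os.getenv("ANTHROPIC_API_KEY")
```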
1 change: 1 addition & 0 deletions pyproject.toml
@@ -38,6 +38,7 @@ dev = [
"tach>=0.6.9",
"tabulate>=0.9.0",
"pandas>=2.1.0",
"python-dotenv>=1.0.0",
]

[project.urls]
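The new python-dotenv dependency is what lets tests/test_costs.py pull ANTHROPIC_API_KEY from a local .env file. A minimal sketch of that mechanism, assuming a .env file in the working directory; the file contents shown are illustrative.

```python
# Illustrative .env contents (not part of the diff):
#   ANTHROPIC_API_KEY=sk-ant-...

import os

from dotenv import load_dotenv

# Reads KEY=value pairs from ./.env into os.environ; it is a no-op if the
# file is absent, so CI (which sets the variable directly) is unaffected.
load_dotenv()
print("ANTHROPIC_API_KEY present:", bool(os.getenv("ANTHROPIC_API_KEY")))
```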
43 changes: 39 additions & 4 deletions tests/test_costs.py
@@ -1,8 +1,14 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import pytest
from decimal import Decimal
from dotenv import load_dotenv

# Load environment variables for ANTHROPIC_API_KEY
load_dotenv()

from tokencost.costs import (
count_message_tokens,
count_string_tokens,
@@ -46,7 +52,11 @@
("gpt-4-vision-preview", 15),
("gpt-4o", 15),
("azure/gpt-4o", 15),
("claude-3-opus-latest", 11),
pytest.param("claude-3-opus-latest", 11,
marks=pytest.mark.skipif(
bool(os.getenv("ANTHROPIC_API_KEY")),
reason="ANTHROPIC_API_KEY environment variable not set"
)),
],
)
def test_count_message_tokens(model, expected_output):
@@ -154,8 +164,12 @@ def test_count_string_invalid_model():
(MESSAGES, "gpt-4-1106-preview", Decimal("0.00015")),
(MESSAGES, "gpt-4-vision-preview", Decimal("0.00015")),
(MESSAGES, "gpt-4o", Decimal("0.0000375")),
(MESSAGES, "azure/gpt-4o", Decimal("0.000075")),
(MESSAGES, "claude-3-opus-latest", Decimal("0.000165")),
(MESSAGES, "azure/gpt-4o", Decimal("0.0000375")),
pytest.param(MESSAGES, "claude-3-opus-latest", Decimal("0.000165"),
    marks=pytest.mark.skipif(
        not os.getenv("ANTHROPIC_API_KEY"),
        reason="ANTHROPIC_API_KEY environment variable not set"
    )),
(STRING, "text-embedding-ada-002", Decimal("0.0000004")),
],
)
@@ -191,7 +205,7 @@ def test_invalid_prompt_format():
(STRING, "gpt-4-1106-preview", Decimal("0.00012")),
(STRING, "gpt-4-vision-preview", Decimal("0.00012")),
(STRING, "gpt-4o", Decimal("0.00004")),
(STRING, "azure/gpt-4o", Decimal("0.000060")),
(STRING, "azure/gpt-4o", Decimal("0.00004")),
# (STRING, "claude-3-opus-latest", Decimal("0.000096")), # NOTE: Claude only supports messages
(STRING, "text-embedding-ada-002", 0),
],
@@ -230,9 +244,30 @@ def test_calculate_invalid_input_types():
(10, "gpt-3.5-turbo", "input", Decimal("0.0000150")), # Example values
(5, "gpt-4", "output", Decimal("0.00030")), # Example values
(10, "ai21.j2-mid-v1", "input", Decimal("0.0001250")), # Example values
(100, "gpt-4o", "cached", Decimal("0.000125")), # Cache tokens test
],
)
def test_calculate_cost_by_tokens(num_tokens, model, token_type, expected_output):
"""Test that the token cost calculation is correct."""
cost = calculate_cost_by_tokens(num_tokens, model, token_type)
assert cost == expected_output


def test_calculate_cached_tokens_cost():
"""Test that cached tokens cost calculation works correctly."""
# Basic test for cache token cost calculation
model = "gpt-4o"
num_tokens = 1000
token_type = "cached"

# Get the expected cost from the TOKEN_COSTS dictionary
from tokencost.constants import TOKEN_COSTS
cache_cost_per_token = TOKEN_COSTS[model]["cache_read_input_token_cost"]
expected_cost = Decimal(str(cache_cost_per_token)) * Decimal(num_tokens)

# Calculate the actual cost
actual_cost = calculate_cost_by_tokens(num_tokens, model, token_type)

# Assert that the costs match
assert actual_cost == expected_cost
assert actual_cost > 0, "Cache token cost should be greater than zero"
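As a sanity check on the new "cached" parametrize case, Decimal("0.000125") for 100 gpt-4o tokens implies a cache-read price of 1.25e-06 USD per token (an assumption read off the expected value, not an authoritative price). The arithmetic mirrors the Decimal conversion used in calculate_cost_by_tokens:

```python
from decimal import Decimal

# Implied by the expected value in the "cached" parametrize case above:
# 100 tokens * 1.25e-06 USD/token == 0.000125 USD.
cache_read_cost_per_token = 1.25e-06  # assumed gpt-4o cache_read_input_token_cost
num_tokens = 100

cost = Decimal(str(cache_read_cost_per_token)) * Decimal(num_tokens)
assert cost == Decimal("0.000125")
```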
38 changes: 32 additions & 6 deletions tokencost/costs.py
@@ -5,7 +5,7 @@
import os
import tiktoken
import anthropic
from typing import Union, List, Dict
from typing import Union, List, Dict, Literal
from .constants import TOKEN_COSTS
from decimal import Decimal
import logging
@@ -16,6 +16,31 @@
# https://github.com/anthropics/anthropic-tokenizer-typescript/blob/main/index.ts


TokenType = Literal["input", "output", "cached"]


def _get_field_from_token_type(token_type: TokenType) -> str:
"""
Get the field name from the token type.

Args:
token_type (TokenType): The token type.

Returns:
str: The field name to use for the token cost data in the TOKEN_COSTS dictionary.
"""
lookups = {
"input": "input_cost_per_token",
"output": "output_cost_per_token",
"cached": "cache_read_input_token_cost",
}

try:
return lookups[token_type]
except KeyError:
raise ValueError(f"Invalid token type: {token_type}.")


def get_anthropic_token_count(messages: List[Dict[str, str]], model: str) -> int:
if not any(
supported_model in model
@@ -160,7 +185,7 @@ def count_string_tokens(prompt: str, model: str) -> int:
return len(encoding.encode(prompt))


def calculate_cost_by_tokens(num_tokens: int, model: str, token_type: str) -> Decimal:
def calculate_cost_by_tokens(num_tokens: int, model: str, token_type: TokenType) -> Decimal:
"""
Calculate the cost based on the number of tokens and the model.

@@ -179,10 +204,11 @@ def calculate_cost_by_tokens(num_tokens: int, model: str, token_type: str) -> De
Double-check your spelling, or submit an issue/PR"""
)

cost_per_token_key = (
"input_cost_per_token" if token_type == "input" else "output_cost_per_token"
)
cost_per_token = TOKEN_COSTS[model][cost_per_token_key]
try:
token_key = _get_field_from_token_type(token_type)
cost_per_token = TOKEN_COSTS[model][token_key]
except KeyError:
raise KeyError(f"Model {model} does not have cost data for `{token_type}` tokens.")

return Decimal(str(cost_per_token)) * Decimal(num_tokens)

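A usage sketch of the extended API introduced in tokencost/costs.py; the printed dollar amount depends on the price table shipped with tokencost, so it is illustrative rather than guaranteed.

```python
from tokencost.costs import calculate_cost_by_tokens

# "cached" is the new token type added here; it routes the lookup through
# cache_read_input_token_cost instead of the input/output price fields.
cost = calculate_cost_by_tokens(1000, "gpt-4o", "cached")
print(f"1,000 cached gpt-4o prompt tokens: ${cost}")

# An unsupported token type raises ValueError from _get_field_from_token_type.
try:
    calculate_cost_by_tokens(1000, "gpt-4o", "unknown")
except ValueError as exc:
    print(exc)
```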