1 | 1 | #!/usr/bin/env python |
2 | 2 | # -*- coding: utf-8 -*- |
3 | 3 |
|
| 4 | +import os |
4 | 5 | import pytest |
5 | 6 | from decimal import Decimal |
| 7 | +from dotenv import load_dotenv |
| 8 | + |
| 9 | +# Load environment variables for ANTHROPIC_API_KEY |
| 10 | +load_dotenv() |
| 11 | + |
6 | 12 | from tokencost.costs import ( |
7 | 13 | count_message_tokens, |
8 | 14 | count_string_tokens, |
|
46 | 52 | ("gpt-4-vision-preview", 15), |
47 | 53 | ("gpt-4o", 15), |
48 | 54 | ("azure/gpt-4o", 15), |
49 | | - ("claude-3-opus-latest", 11), |
| 55 | + pytest.param("claude-3-opus-latest", 11, |
| 56 | + marks=pytest.mark.skipif( |
| 57 | + not os.getenv("ANTHROPIC_API_KEY"), |
| 58 | + reason="ANTHROPIC_API_KEY environment variable not set" |
| 59 | + )), |
50 | 60 | ], |
51 | 61 | ) |
52 | 62 | def test_count_message_tokens(model, expected_output): |
@@ -154,8 +164,12 @@ def test_count_string_invalid_model(): |
154 | 164 | (MESSAGES, "gpt-4-1106-preview", Decimal("0.00015")), |
155 | 165 | (MESSAGES, "gpt-4-vision-preview", Decimal("0.00015")), |
156 | 166 | (MESSAGES, "gpt-4o", Decimal("0.0000375")), |
157 | | - (MESSAGES, "azure/gpt-4o", Decimal("0.000075")), |
158 | | - (MESSAGES, "claude-3-opus-latest", Decimal("0.000165")), |
| 167 | + (MESSAGES, "azure/gpt-4o", Decimal("0.0000375")), |
| 168 | + pytest.param(MESSAGES, "claude-3-opus-latest", Decimal("0.000165"), |
| 169 | + marks=pytest.mark.skipif( |
| 170 | + not os.getenv("ANTHROPIC_API_KEY"), |
| 171 | + reason="ANTHROPIC_API_KEY environment variable not set" |
| 172 | + )), |
159 | 173 | (STRING, "text-embedding-ada-002", Decimal("0.0000004")), |
160 | 174 | ], |
161 | 175 | ) |
@@ -191,7 +205,7 @@ def test_invalid_prompt_format(): |
191 | 205 | (STRING, "gpt-4-1106-preview", Decimal("0.00012")), |
192 | 206 | (STRING, "gpt-4-vision-preview", Decimal("0.00012")), |
193 | 207 | (STRING, "gpt-4o", Decimal("0.00004")), |
194 | | - (STRING, "azure/gpt-4o", Decimal("0.000060")), |
| 208 | + (STRING, "azure/gpt-4o", Decimal("0.00004")), |
195 | 209 | # (STRING, "claude-3-opus-latest", Decimal("0.000096")), # NOTE: Claude only supports messages |
196 | 210 | (STRING, "text-embedding-ada-002", 0), |
197 | 211 | ], |
@@ -230,9 +244,30 @@ def test_calculate_invalid_input_types(): |
230 | 244 | (10, "gpt-3.5-turbo", "input", Decimal("0.0000150")), # Example values |
231 | 245 | (5, "gpt-4", "output", Decimal("0.00030")), # Example values |
232 | 246 | (10, "ai21.j2-mid-v1", "input", Decimal("0.0001250")), # Example values |
| 247 | + (100, "gpt-4o", "cached", Decimal("0.000125")), # Cache tokens test |
233 | 248 | ], |
234 | 249 | ) |
235 | 250 | def test_calculate_cost_by_tokens(num_tokens, model, token_type, expected_output): |
236 | 251 | """Test that the token cost calculation is correct.""" |
237 | 252 | cost = calculate_cost_by_tokens(num_tokens, model, token_type) |
238 | 253 | assert cost == expected_output |
| 254 | + |
| 255 | + |
| 256 | +def test_calculate_cached_tokens_cost(): |
| 257 | + """Test that cached tokens cost calculation works correctly.""" |
| 258 | + # Basic test for cache token cost calculation |
| 259 | + model = "gpt-4o" |
| 260 | + num_tokens = 1000 |
| 261 | + token_type = "cached" |
| 262 | + |
| 263 | + # Get the expected cost from the TOKEN_COSTS dictionary |
| 264 | + from tokencost.constants import TOKEN_COSTS |
| 265 | + cache_cost_per_token = TOKEN_COSTS[model]["cache_read_input_token_cost"] |
| 266 | + expected_cost = Decimal(str(cache_cost_per_token)) * Decimal(num_tokens) |
| 267 | + |
| 268 | + # Calculate the actual cost |
| 269 | + actual_cost = calculate_cost_by_tokens(num_tokens, model, token_type) |
| 270 | + |
| 271 | + # Assert that the costs match |
| 272 | + assert actual_cost == expected_cost |
| 273 | + assert actual_cost > 0, "Cache token cost should be greater than zero" |
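
For reference, a minimal sketch (outside the diff) of the arithmetic behind the new parametrized cached-token case, assuming gpt-4o's per-token cache-read rate in TOKEN_COSTS is 0.00000125 USD, i.e. the rate implied by the expected Decimal("0.000125") for 100 cached tokens:

    from decimal import Decimal

    # Assumed per-token cache-read rate; in the tests this value comes from
    # TOKEN_COSTS["gpt-4o"]["cache_read_input_token_cost"].
    cache_read_rate = Decimal("0.00000125")
    num_cached_tokens = Decimal(100)

    expected_cost = cache_read_rate * num_cached_tokens
    assert expected_cost == Decimal("0.000125")  # matches the (100, "gpt-4o", "cached", ...) case above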