Commit e488312

fix(utils.py): log cache_creation_tokens in prompt token details
Closes LIT-907
1 parent 1162c52 commit e488312

5 files changed: +64 -7 lines changed

litellm/litellm_core_utils/llm_cost_calc/utils.py
Lines changed: 9 additions & 3 deletions

@@ -278,6 +278,13 @@ def generic_cost_per_token(
         )
         or 0
     )
+    cache_creation_tokens = (
+        cast(
+            Optional[int],
+            getattr(usage.prompt_tokens_details, "cache_creation_tokens", 0),
+        )
+        or 0
+    )
     text_tokens = (
         cast(
             Optional[int], getattr(usage.prompt_tokens_details, "text_tokens", None)
@@ -307,9 +314,8 @@ def generic_cost_per_token(
         or 0
     )

-    if getattr(usage, "_cache_creation_input_tokens", 0) is not None:
-        cache_creation_tokens = usage._cache_creation_input_tokens
     ## EDGE CASE - text tokens not set inside PromptTokensDetails
+
     if text_tokens == 0:
         text_tokens = (
             usage.prompt_tokens
@@ -333,7 +339,7 @@ def generic_cost_per_token(
     )

     ### CACHE WRITING COST - Now uses tiered pricing
-    prompt_cost += float(usage._cache_creation_input_tokens or 0) * cache_creation_cost
+    prompt_cost += float(cache_creation_tokens) * cache_creation_cost

     ### CHARACTER COST

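Editorial note (not part of the commit): the new lookup relies on getattr returning its default when usage.prompt_tokens_details is None, and on the trailing "or 0" to normalize a None field, which is why the old read of the private usage._cache_creation_input_tokens attribute can be dropped. A minimal, self-contained sketch of that pattern, using hypothetical stand-in classes rather than litellm's real Usage/PromptTokensDetailsWrapper types:

# Sketch only: FakeDetails/FakeUsage are hypothetical stand-ins, not litellm types.
from typing import Optional, cast


class FakeDetails:
    def __init__(self, cache_creation_tokens: Optional[int]) -> None:
        self.cache_creation_tokens = cache_creation_tokens


class FakeUsage:
    def __init__(self, details: Optional[FakeDetails]) -> None:
        self.prompt_tokens_details = details


def read_cache_creation_tokens(usage: FakeUsage) -> int:
    # Same shape as the lookup in the diff above: getattr returns the default (0)
    # when prompt_tokens_details is None, and "or 0" catches a None field value.
    return (
        cast(
            Optional[int],
            getattr(usage.prompt_tokens_details, "cache_creation_tokens", 0),
        )
        or 0
    )


print(read_cache_creation_tokens(FakeUsage(FakeDetails(2000))))  # -> 2000
print(read_cache_creation_tokens(FakeUsage(None)))               # -> 0
print(read_cache_creation_tokens(FakeUsage(FakeDetails(None))))  # -> 0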

litellm/types/utils.py
Lines changed: 21 additions & 1 deletion

@@ -162,7 +162,9 @@ class ModelInfoBase(ProviderSpecificModelInfo, total=False):
         SearchContextCostPerQuery
     ] # Cost for using web search tool
     citation_cost_per_token: Optional[float] # Cost per citation token for Perplexity
-    tiered_pricing: Optional[List[Dict[str, Any]]] # Tiered pricing structure for models like Dashscope
+    tiered_pricing: Optional[
+        List[Dict[str, Any]]
+    ] # Tiered pricing structure for models like Dashscope
     litellm_provider: Required[str]
     mode: Required[
         Literal[
@@ -880,6 +882,9 @@ class PromptTokensDetailsWrapper(
     video_length_seconds: Optional[float] = None
     """Length of videos sent to the model. Used for Vertex AI multimodal embeddings."""

+    cache_creation_tokens: Optional[int] = None
+    """Number of cache creation tokens sent to the model. Used for Anthropic prompt caching."""
+
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         if self.character_count is None:
@@ -890,6 +895,8 @@ def __init__(self, *args, **kwargs):
             del self.video_length_seconds
         if self.web_search_requests is None:
             del self.web_search_requests
+        if self.cache_creation_tokens is None:
+            del self.cache_creation_tokens


 class ServerToolUse(BaseModel):
@@ -951,6 +958,7 @@ def __init__(
         # handle prompt_tokens_details
         _prompt_tokens_details: Optional[PromptTokensDetailsWrapper] = None

+        # guarantee prompt_token_details is always a PromptTokensDetailsWrapper
         if prompt_tokens_details:
             if isinstance(prompt_tokens_details, dict):
                 _prompt_tokens_details = PromptTokensDetailsWrapper(
@@ -985,6 +993,18 @@ def __init__(
             else:
                 _prompt_tokens_details.cached_tokens = params["cache_read_input_tokens"]

+        if "cache_creation_input_tokens" in params and isinstance(
+            params["cache_creation_input_tokens"], int
+        ):
+            if _prompt_tokens_details is None:
+                _prompt_tokens_details = PromptTokensDetailsWrapper(
+                    cache_creation_tokens=params["cache_creation_input_tokens"]
+                )
+            else:
+                _prompt_tokens_details.cache_creation_tokens = params[
+                    "cache_creation_input_tokens"
+                ]
+
         super().__init__(
             prompt_tokens=prompt_tokens or 0,
             completion_tokens=completion_tokens or 0,

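Editorial note: taken together with the utils.py change, this means an Anthropic-style cache_creation_input_tokens keyword passed to Usage should now be mirrored onto prompt_tokens_details.cache_creation_tokens, which is what the cost calculator reads. A hedged usage sketch, assuming a litellm build that includes this commit (import path taken from the diff, not verified independently):

# Assumes litellm at (or after) this commit; otherwise the field will not exist.
from litellm.types.utils import Usage

usage = Usage(
    prompt_tokens=28436,
    completion_tokens=90,
    total_tokens=28526,
    cache_creation_input_tokens=2000,  # Anthropic-style extra param
)

# The new __init__ branch copies the value into prompt_tokens_details, so the
# cost code no longer has to reach for the private _cache_creation_input_tokens.
print(usage.prompt_tokens_details.cache_creation_tokens)  # expected: 2000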

tests/llm_translation/base_llm_unit_tests.py
Lines changed: 3 additions & 2 deletions

@@ -954,6 +954,7 @@ def test_image_url_string(self):

     @pytest.mark.flaky(retries=4, delay=1)
     def test_prompt_caching(self):
+        print("test_prompt_caching")
         litellm.set_verbose = True
         from litellm.utils import supports_prompt_caching

@@ -1049,8 +1050,8 @@ def test_prompt_caching(self):
             assert (
                 response.usage.prompt_tokens_details.cached_tokens > 0
             ), f"cached_tokens={response.usage.prompt_tokens_details.cached_tokens} should be greater than 0. Got usage={response.usage}"
-        except litellm.InternalServerError:
-            pass
+        except litellm.InternalServerError as e:
+            print("InternalServerError", e)

     @pytest.fixture
     def pdf_messages(self):

tests/local_testing/test_anthropic_prompt_caching.py
Lines changed: 2 additions & 1 deletion

@@ -250,7 +250,7 @@ async def test_anthropic_api_prompt_caching_basic():
                     "type": "text",
                     "text": "Here is the full text of a complex legal agreement"
                     * 400,
-                    "cache_control": {"type": "ephemeral"},
+                    : {"type": "ephemeral"},
                 }
             ],
         },
@@ -510,6 +510,7 @@ async def test_anthropic_api_prompt_caching_streaming():
         if hasattr(chunk, "usage") and hasattr(
             chunk.usage, "cache_creation_input_tokens"
         ):
+            print("chunk.usage", chunk.usage)
             is_cache_creation_input_tokens_in_usage = True

         idx += 1

tests/test_litellm/litellm_core_utils/llm_cost_calc/test_llm_cost_calc_utils.py
Lines changed: 29 additions & 0 deletions

@@ -174,6 +174,35 @@ def test_generic_cost_per_token_anthropic_prompt_caching():
     assert prompt_cost < 0.085


+def test_generic_cost_per_token_anthropic_prompt_caching_with_cache_creation():
+    model = "claude-3-5-haiku-20241022"
+    usage = Usage(
+        completion_tokens=90,
+        prompt_tokens=28436,
+        total_tokens=28526,
+        completion_tokens_details=CompletionTokensDetailsWrapper(
+            accepted_prediction_tokens=None,
+            audio_tokens=None,
+            reasoning_tokens=0,
+            rejected_prediction_tokens=None,
+            text_tokens=None,
+        ),
+        prompt_tokens_details=None,
+        cache_creation_input_tokens=2000,
+    )
+
+    custom_llm_provider = "anthropic"
+
+    prompt_cost, completion_cost = generic_cost_per_token(
+        model=model,
+        usage=usage,
+        custom_llm_provider=custom_llm_provider,
+    )
+
+    print(f"prompt_cost: {prompt_cost}")
+    assert round(prompt_cost, 3) == 0.023
+
+
 def test_string_cost_values():
     """Test that cost values defined as strings are properly converted to floats."""
     from unittest.mock import patch

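Editorial note on the asserted value: 0.023 is consistent with cache-creation tokens being priced separately from ordinary input tokens. A rough back-of-the-envelope check, assuming (not verified against litellm's model cost map) claude-3-5-haiku input at $0.80 per million tokens, cache writes at $1.00 per million, and the 2,000 cache-creation tokens excluded from the plain text-token count:

# Rough check only; the per-token prices below are assumptions, and the real
# numbers come from litellm's model cost map.
input_cost_per_token = 0.8e-6           # assumed $0.80 per 1M input tokens
cache_creation_cost_per_token = 1.0e-6  # assumed $1.00 per 1M cache-write tokens

prompt_tokens = 28436
cache_creation_tokens = 2000
text_tokens = prompt_tokens - cache_creation_tokens  # assumes no double counting

prompt_cost = (
    text_tokens * input_cost_per_token
    + cache_creation_tokens * cache_creation_cost_per_token
)
print(round(prompt_cost, 3))  # 0.023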
