
Commit 4f85e86

lizzij authored and copybara-github committed

feat: Add support for extracting cache-related token counts from LiteLLM usage

Closes #3049
Co-authored-by: Eliza Huang <[email protected]>
PiperOrigin-RevId: 828091671

1 parent abdc2bb, commit 4f85e86

2 files changed: +103 lines, -7 lines
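
Background for the diffs below: LLM providers report prompt-cache hits in several different usage shapes, and the new _extract_cached_prompt_tokens helper added in src/google/adk/models/lite_llm.py normalizes all of them to a single integer. A minimal sketch of those shapes, assuming the helper lands exactly as shown in this commit; the token values are made up for illustration.

# Sketch only: the usage dicts below are illustrative, not taken from any provider.
from google.adk.models.lite_llm import _extract_cached_prompt_tokens

openai_style = {"prompt_tokens_details": {"cached_tokens": 32}}  # OpenAI/Azure style
detail_list = {"prompt_tokens_details": [{"cached_tokens": 20}, {"cached_tokens": 12}]}
litellm_normalized = {"cached_prompt_tokens": 32}  # LiteLLM-normalized for some providers
flat = {"cached_tokens": 32}  # flat shape

assert _extract_cached_prompt_tokens(openai_style) == 32
assert _extract_cached_prompt_tokens(detail_list) == 32  # list entries are summed
assert _extract_cached_prompt_tokens(litellm_normalized) == 32
assert _extract_cached_prompt_tokens(flat) == 32
assert _extract_cached_prompt_tokens({"prompt_tokens": 100}) == 0  # no cache info present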

src/google/adk/models/lite_llm.py (59 additions, 0 deletions)

@@ -101,6 +101,7 @@ class UsageMetadataChunk(BaseModel):
   prompt_tokens: int
   completion_tokens: int
   total_tokens: int
+  cached_prompt_tokens: int = 0


 class LiteLLMClient:
@@ -217,6 +218,59 @@ def _append_fallback_user_content_if_missing(
   )


+def _extract_cached_prompt_tokens(usage: Any) -> int:
+  """Extracts cached prompt tokens from LiteLLM usage.
+
+  Providers expose cached token metrics in different shapes. Common patterns:
+  - usage["prompt_tokens_details"]["cached_tokens"] (OpenAI/Azure style)
+  - usage["prompt_tokens_details"] is a list of dicts with cached_tokens
+  - usage["cached_prompt_tokens"] (LiteLLM-normalized for some providers)
+  - usage["cached_tokens"] (flat)
+
+  Args:
+    usage: Usage dictionary from LiteLLM response.
+
+  Returns:
+    Integer number of cached prompt tokens if present; otherwise 0.
+  """
+  try:
+    usage_dict = usage
+    if hasattr(usage, "model_dump"):
+      usage_dict = usage.model_dump()
+    elif isinstance(usage, str):
+      try:
+        usage_dict = json.loads(usage)
+      except json.JSONDecodeError:
+        return 0
+
+    if not isinstance(usage_dict, dict):
+      return 0
+
+    details = usage_dict.get("prompt_tokens_details")
+    if isinstance(details, dict):
+      value = details.get("cached_tokens")
+      if isinstance(value, int):
+        return value
+    elif isinstance(details, list):
+      total = sum(
+          item.get("cached_tokens", 0)
+          for item in details
+          if isinstance(item, dict)
+          and isinstance(item.get("cached_tokens"), int)
+      )
+      if total > 0:
+        return total
+
+    for key in ("cached_prompt_tokens", "cached_tokens"):
+      value = usage_dict.get(key)
+      if isinstance(value, int):
+        return value
+  except (TypeError, AttributeError) as e:
+    logger.debug("Error extracting cached prompt tokens: %s", e)
+
+  return 0
+
+
 def _content_to_message_param(
     content: types.Content,
 ) -> Union[Message, list[Message]]:
@@ -533,6 +587,7 @@ def _model_response_to_chunk(
         prompt_tokens=response["usage"].get("prompt_tokens", 0),
         completion_tokens=response["usage"].get("completion_tokens", 0),
         total_tokens=response["usage"].get("total_tokens", 0),
+        cached_prompt_tokens=_extract_cached_prompt_tokens(response["usage"]),
     ), None


@@ -576,6 +631,9 @@ def _model_response_to_generate_content_response(
         prompt_token_count=response["usage"].get("prompt_tokens", 0),
         candidates_token_count=response["usage"].get("completion_tokens", 0),
         total_token_count=response["usage"].get("total_tokens", 0),
+        cached_content_token_count=_extract_cached_prompt_tokens(
+            response["usage"]
+        ),
     )
     return llm_response
@@ -965,6 +1023,7 @@ async def generate_content_async(
             prompt_token_count=chunk.prompt_tokens,
             candidates_token_count=chunk.completion_tokens,
             total_token_count=chunk.total_tokens,
+            cached_content_token_count=chunk.cached_prompt_tokens,
         )

         if (
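
With the changes above, both the streaming and non-streaming paths copy the cached count into usage_metadata.cached_content_token_count on the returned LlmResponse. A hedged sketch of how a caller could read it: LiteLlm, LlmRequest, and generate_content_async come from google.adk, while the model id and request contents below are assumptions chosen only for illustration.

import asyncio

from google.adk.models.lite_llm import LiteLlm
from google.adk.models.llm_request import LlmRequest
from google.genai import types


async def main() -> None:
  # Assumed model id and prompt; any LiteLLM-supported provider/model could go here.
  lite_llm = LiteLlm(model="openai/gpt-4o")
  llm_request = LlmRequest(
      contents=[types.Content(role="user", parts=[types.Part(text="Hello")])]
  )
  async for response in lite_llm.generate_content_async(llm_request, stream=False):
    usage = response.usage_metadata
    if usage and usage.cached_content_token_count:
      # After this commit, cached_content_token_count reflects provider cache hits.
      print(
          f"{usage.cached_content_token_count} of {usage.prompt_token_count}"
          " prompt tokens were served from the provider's prompt cache"
      )


asyncio.run(main())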

tests/unittests/models/test_litellm.py (44 additions, 7 deletions)

@@ -1063,6 +1063,7 @@ async def test_generate_content_async_with_usage_metadata(
           "prompt_tokens": 10,
           "completion_tokens": 5,
           "total_tokens": 15,
+          "cached_tokens": 8,
       },
   )
   mock_acompletion.return_value = mock_response_with_usage_metadata
@@ -1083,6 +1084,7 @@ async def test_generate_content_async_with_usage_metadata(
   assert response.usage_metadata.prompt_token_count == 10
   assert response.usage_metadata.candidates_token_count == 5
   assert response.usage_metadata.total_token_count == 15
+  assert response.usage_metadata.cached_content_token_count == 8

   mock_acompletion.assert_called_once()

@@ -1718,37 +1720,42 @@ async def test_generate_content_async_stream_with_usage_metadata(


 @pytest.mark.asyncio
-async def test_generate_content_async_stream_with_usage_metadata_only(
+async def test_generate_content_async_stream_with_usage_metadata(
     mock_completion, lite_llm_instance
 ):
+  """Tests that cached prompt tokens are propagated in streaming mode."""
   streaming_model_response_with_usage_metadata = [
+      *STREAMING_MODEL_RESPONSE,
       ModelResponse(
           usage={
               "prompt_tokens": 10,
               "completion_tokens": 5,
               "total_tokens": 15,
+              "cached_tokens": 8,
           },
           choices=[
              StreamingChoices(
-                  finish_reason="stop",
-                  delta=Delta(content=""),
+                  finish_reason=None,
              )
          ],
      ),
  ]
+
  mock_completion.return_value = iter(
      streaming_model_response_with_usage_metadata
  )

-  unused_responses = [
+  responses = [
      response
      async for response in lite_llm_instance.generate_content_async(
          LLM_REQUEST_WITH_FUNCTION_DECLARATION, stream=True
      )
  ]
-  mock_completion.assert_called_once()
-  _, kwargs = mock_completion.call_args
-  assert kwargs["stream_options"] == {"include_usage": True}
+  assert len(responses) == 4
+  assert responses[3].usage_metadata.prompt_token_count == 10
+  assert responses[3].usage_metadata.candidates_token_count == 5
+  assert responses[3].usage_metadata.total_token_count == 15
+  assert responses[3].usage_metadata.cached_content_token_count == 8


 @pytest.mark.asyncio
@@ -2038,6 +2045,36 @@ def test_function_declaration_to_tool_param_edge_cases():
   assert "required" not in result["function"]["parameters"]


+@pytest.mark.parametrize(
+    "usage, expected_tokens",
+    [
+        ({"prompt_tokens_details": {"cached_tokens": 123}}, 123),
+        (
+            {
+                "prompt_tokens_details": [
+                    {"cached_tokens": 50},
+                    {"cached_tokens": 25},
+                ]
+            },
+            75,
+        ),
+        ({"cached_prompt_tokens": 45}, 45),
+        ({"cached_tokens": 67}, 67),
+        ({"prompt_tokens": 100}, 0),
+        ({}, 0),
+        ("not a dict", 0),
+        (None, 0),
+        ({"prompt_tokens_details": {"cached_tokens": "not a number"}}, 0),
+        (json.dumps({"cached_tokens": 89}), 89),
+        (json.dumps({"some_key": "some_value"}), 0),
+    ],
+)
+def test_extract_cached_prompt_tokens(usage, expected_tokens):
+  from google.adk.models.lite_llm import _extract_cached_prompt_tokens
+
+  assert _extract_cached_prompt_tokens(usage) == expected_tokens
+
+
 def test_gemini_via_litellm_warning(monkeypatch):
   """Test that Gemini via LiteLLM shows warning."""
   # Ensure environment variable is not set
