diff --git a/src/google/adk/models/lite_llm.py b/src/google/adk/models/lite_llm.py
index d8b4d7ce81..3008d02a09 100644
--- a/src/google/adk/models/lite_llm.py
+++ b/src/google/adk/models/lite_llm.py
@@ -85,6 +85,8 @@ class UsageMetadataChunk(BaseModel):
   prompt_tokens: int
   completion_tokens: int
   total_tokens: int
+  # Cached tokens served from provider cache for the prompt portion
+  cached_prompt_tokens: int = 0
 
 
 class LiteLLMClient:
@@ -154,6 +156,46 @@ def _safe_json_serialize(obj) -> str:
     return str(obj)
 
 
+def _extract_cached_prompt_tokens(usage: dict[str, Any]) -> int:
+  """Best-effort extraction of cached prompt tokens from LiteLLM usage.
+
+  Providers expose cached token metrics in different shapes. Common patterns:
+    - usage["prompt_tokens_details"]["cached_tokens"] (OpenAI/Azure style)
+    - usage["prompt_tokens_details"] is a list of dicts with cached_tokens
+    - usage["cached_prompt_tokens"] (LiteLLM-normalized for some providers)
+    - usage["cached_tokens"] (flat)
+
+  Args:
+    usage: Usage dictionary from LiteLLM response.
+
+  Returns:
+    Integer number of cached prompt tokens if present; otherwise 0.
+  """
+
+  try:
+    details = usage.get("prompt_tokens_details")
+    if isinstance(details, dict):
+      value = details.get("cached_tokens")
+      if isinstance(value, int):
+        return value
+    elif isinstance(details, list):
+      total = 0
+      for item in details:
+        if isinstance(item, dict) and isinstance(item.get("cached_tokens"), int):
+          total += item["cached_tokens"]
+      if total:
+        return total
+
+    for key in ("cached_prompt_tokens", "cached_tokens"):
+      value = usage.get(key)
+      if isinstance(value, int):
+        return value
+  except Exception:  # noqa: BLE001 - defensive: provider-specific shapes vary
+    pass
+
+  return 0
+
+
 def _content_to_message_param(
     content: types.Content,
 ) -> Union[Message, list[Message]]:
@@ -463,10 +505,12 @@ def _model_response_to_chunk(
   # finish_reason set. But this is not the case we are observing from litellm.
   # So we are sending it as a separate chunk to be set on the llm_response.
   if response.get("usage", None):
+    usage_dict = response["usage"]
     yield UsageMetadataChunk(
-        prompt_tokens=response["usage"].get("prompt_tokens", 0),
-        completion_tokens=response["usage"].get("completion_tokens", 0),
-        total_tokens=response["usage"].get("total_tokens", 0),
+        prompt_tokens=usage_dict.get("prompt_tokens", 0),
+        completion_tokens=usage_dict.get("completion_tokens", 0),
+        total_tokens=usage_dict.get("total_tokens", 0),
+        cached_prompt_tokens=_extract_cached_prompt_tokens(usage_dict),
     ), None
@@ -491,10 +535,12 @@ def _model_response_to_generate_content_response(
   llm_response = _message_to_generate_content_response(message)
   if response.get("usage", None):
+    usage_dict = response["usage"]
     llm_response.usage_metadata = types.GenerateContentResponseUsageMetadata(
-        prompt_token_count=response["usage"].get("prompt_tokens", 0),
-        candidates_token_count=response["usage"].get("completion_tokens", 0),
-        total_token_count=response["usage"].get("total_tokens", 0),
+        prompt_token_count=usage_dict.get("prompt_tokens", 0),
+        candidates_token_count=usage_dict.get("completion_tokens", 0),
+        total_token_count=usage_dict.get("total_tokens", 0),
+        cached_content_token_count=_extract_cached_prompt_tokens(usage_dict),
     )
   return llm_response
@@ -874,6 +920,7 @@ async def generate_content_async(
                 prompt_token_count=chunk.prompt_tokens,
                 candidates_token_count=chunk.completion_tokens,
                 total_token_count=chunk.total_tokens,
+                cached_content_token_count=chunk.cached_prompt_tokens,
             )
           if (
diff --git a/tests/unittests/models/test_litellm.py b/tests/unittests/models/test_litellm.py
index 84fd7f26d0..4c3883430f 100644
--- a/tests/unittests/models/test_litellm.py
+++ b/tests/unittests/models/test_litellm.py
@@ -1317,6 +1317,88 @@ def test_model_response_to_chunk(response, expected_chunks, expected_finished):
     assert usage_chunk.total_tokens == expected_chunks[1].total_tokens
 
 
+@pytest.mark.asyncio
+async def test_generate_content_async_with_cached_tokens_non_stream(
+    lite_llm_instance, mock_acompletion
+):
+  # Simulate LiteLLM usage shapes that include cached tokens
+  mock_response_with_cached_usage = ModelResponse(
+      choices=[
+          Choices(
+              message=ChatCompletionAssistantMessage(
+                  role="assistant",
+                  content="Test response",
+              )
+          )
+      ],
+      usage={
+          "prompt_tokens": 2100,
+          "completion_tokens": 50,
+          "total_tokens": 2150,
+          # Common provider shapes
+          "prompt_tokens_details": {"cached_tokens": 1800},
+      },
+  )
+  mock_acompletion.return_value = mock_response_with_cached_usage
+
+  llm_request = LlmRequest(
+      contents=[
+          types.Content(role="user", parts=[types.Part.from_text(text="q")])
+      ]
+  )
+
+  results = [
+      r async for r in lite_llm_instance.generate_content_async(llm_request)
+  ]
+  assert len(results) == 1
+  resp = results[0]
+  assert resp.usage_metadata is not None
+  assert resp.usage_metadata.prompt_token_count == 2100
+  assert resp.usage_metadata.candidates_token_count == 50
+  assert resp.usage_metadata.total_token_count == 2150
+  # Key assertion: cached_content_token_count is propagated
+  assert resp.usage_metadata.cached_content_token_count == 1800
+
+
+@pytest.mark.asyncio
+async def test_generate_content_async_with_cached_tokens_stream(
+    mock_completion, lite_llm_instance
+):
+  # Build a stream with final usage chunk that includes cached tokens
+  streaming_with_cached_usage = [
+      *STREAMING_MODEL_RESPONSE,
+      ModelResponse(
+          usage={
+              "prompt_tokens": 2100,
+              "completion_tokens": 50,
+              "total_tokens": 2150,
+              # Alternative flattened shape
+              "cached_prompt_tokens": 1700,
+          },
+          choices=[StreamingChoices(finish_reason=None)],
+      ),
+  ]
+  mock_completion.return_value = iter(streaming_with_cached_usage)
+
+  llm_request = LlmRequest(
+      contents=[
+          types.Content(role="user", parts=[types.Part.from_text(text="q")])
+      ]
+  )
+  responses = [
+      r async for r in lite_llm_instance.generate_content_async(
+          llm_request, stream=True
+      )
+  ]
+  # Final aggregated response carries usage
+  assert len(responses) == 4
+  final_resp = responses[-1]
+  assert final_resp.usage_metadata is not None
+  assert final_resp.usage_metadata.prompt_token_count == 2100
+  assert final_resp.usage_metadata.total_token_count == 2150
+  assert final_resp.usage_metadata.cached_content_token_count == 1700
+
+
 @pytest.mark.asyncio
 async def test_acompletion_additional_args(mock_acompletion, mock_client):
   lite_llm_instance = LiteLlm(
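
As a quick illustration of the extraction order that the new _extract_cached_prompt_tokens helper documents, the sketch below runs it against the provider usage shapes listed in its docstring. This is not part of the patch: it assumes the diff above is applied, imports the module-private helper purely for demonstration, and uses made-up token counts.

# Illustrative sketch only (not part of the patch). Exercises the documented
# extraction order of _extract_cached_prompt_tokens against the usage shapes
# named in its docstring. Importing a private helper is for demonstration.
from google.adk.models.lite_llm import _extract_cached_prompt_tokens

example_usages = [
    # OpenAI/Azure style: nested details dict.
    {"prompt_tokens": 2100, "prompt_tokens_details": {"cached_tokens": 1800}},
    # Details as a list of dicts: cached_tokens values are summed.
    {"prompt_tokens_details": [{"cached_tokens": 1000}, {"cached_tokens": 800}]},
    # LiteLLM-normalized flat key.
    {"cached_prompt_tokens": 1700},
    # Flat provider key.
    {"cached_tokens": 1600},
    # No cache information present: falls back to 0.
    {"prompt_tokens": 42},
]

for usage in example_usages:
  print(usage, "->", _extract_cached_prompt_tokens(usage))
# Expected values, in order: 1800, 1800, 1700, 1600, 0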