
Commit c5b976b

seanzhougoogle authored and copybara-github committed
chore: Create the context cache based on the token count of previous request
Before this change, we estimated the token count of the contents to cache and compared that estimate with the threshold the user set, but the estimate was not precise. We now use the actual prompt token count of the previous LLM request instead. No cache is created for the very first request, since there is no previous request to read the count from.

PiperOrigin-RevId: 814484840
1 parent 420df25 commit c5b976b
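
The decision logic described in the commit message boils down to a simple gate. Here is a minimal, self-contained sketch; the function name should_create_cache and its parameters are illustrative only and are not part of the ADK API:

from typing import Optional


def should_create_cache(
    previous_prompt_tokens: Optional[int], min_tokens: int
) -> bool:
  """Illustrative gate: decide whether creating a context cache is worthwhile.

  previous_prompt_tokens is the prompt_token_count reported by the previous
  LLM response; None means this is the very first request of the session.
  """
  if previous_prompt_tokens is None:
    # No previous response yet -> skip caching for the initial request.
    return False
  # Only cache when the previous prompt was already large enough.
  return previous_prompt_tokens >= min_tokens


# Example decisions under this hypothetical gate:
assert should_create_cache(None, 1024) is False   # initial request
assert should_create_cache(512, 1024) is False    # below min_tokens
assert should_create_cache(2048, 1024) is True    # large enough to cache

The real check lives in _create_new_cache_with_contents (see the gemini_context_cache_manager.py diff below), which compares llm_request.cacheable_contents_token_count against cache_config.min_tokens.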

File tree

5 files changed: +323 -22 lines changed


src/google/adk/flows/llm_flows/context_cache_processor.py

Lines changed: 46 additions & 17 deletions
@@ -62,9 +62,11 @@ async def run_async(
     # Set cache config to request
     llm_request.cache_config = invocation_context.context_cache_config

-    # Find latest cache metadata from session events
-    latest_cache_metadata = self._find_latest_cache_metadata(
-        invocation_context, agent.name, invocation_context.invocation_id
+    # Find latest cache metadata and previous token count from session events
+    latest_cache_metadata, previous_token_count = (
+        self._find_cache_info_from_events(
+            invocation_context, agent.name, invocation_context.invocation_id
+        )
     )

     if latest_cache_metadata:
@@ -77,51 +79,78 @@ async def run_async(
           latest_cache_metadata.cached_contents_count,
       )

+    if previous_token_count is not None:
+      llm_request.cacheable_contents_token_count = previous_token_count
+      logger.debug(
+          'Found previous prompt token count for agent %s: %d',
+          agent.name,
+          previous_token_count,
+      )
+
     logger.debug('Context caching enabled for agent %s', agent.name)

     # This processor yields no events
     return
     yield  # AsyncGenerator requires a yield in function body

-  def _find_latest_cache_metadata(
+  def _find_cache_info_from_events(
       self,
       invocation_context: 'InvocationContext',
       agent_name: str,
       current_invocation_id: str,
-  ) -> Optional[CacheMetadata]:
-    """Find the latest cache metadata from session events.
+  ) -> tuple[Optional[CacheMetadata], Optional[int]]:
+    """Find cache metadata and previous token count from session events.

     Args:
       invocation_context: Context containing session with events
-      agent_name: Name of agent to find cache metadata for
+      agent_name: Name of agent to find cache info for
       current_invocation_id: Current invocation ID to compare for increment

     Returns:
-      Latest cache metadata for the agent (with updated invocations_used
-      if needed), or None if not found
+      Tuple of (cache_metadata, previous_prompt_token_count)
+        cache_metadata: Latest cache metadata with updated invocations_used if needed
+        previous_prompt_token_count: Most recent prompt token count from LLM response
     """
     if not invocation_context.session or not invocation_context.session.events:
-      return None
+      return None, None
+
+    cache_metadata = None
+    previous_token_count = None

     # Search events from most recent to oldest using index traversal
     events = invocation_context.session.events
     for i in range(len(events) - 1, -1, -1):
       event = events[i]
-      if event.cache_metadata is not None and event.author == agent_name:
-
-        cache_metadata = event.cache_metadata
+      if event.author != agent_name:
+        continue

+      # Look for cache metadata (only in actual LLM response events)
+      if cache_metadata is None and event.cache_metadata is not None:
         # Check if this is a different invocation - increment invocations_used
         if event.invocation_id and event.invocation_id != current_invocation_id:
           # Different invocation - increment invocations_used
-          return cache_metadata.model_copy(
-              update={'invocations_used': cache_metadata.invocations_used + 1}
+          cache_metadata = event.cache_metadata.model_copy(
+              update={
+                  'invocations_used': event.cache_metadata.invocations_used + 1
+              }
           )
         else:
           # Same invocation or no invocation_id - return as-is
-          return cache_metadata
+          cache_metadata = event.cache_metadata
+
+      # Look for previous prompt token count (from actual LLM response events)
+      if (
+          previous_token_count is None
+          and event.usage_metadata
+          and event.usage_metadata.prompt_token_count is not None
+      ):
+        previous_token_count = event.usage_metadata.prompt_token_count
+
+      # Stop early if we found both pieces of information
+      if cache_metadata is not None and previous_token_count is not None:
+        break

-    return None
+    return cache_metadata, previous_token_count


 # Create processor instance for use in flows
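
For illustration, the newest-first scan implemented by _find_cache_info_from_events above behaves roughly like the following self-contained sketch. The FakeEvent dataclass and the sample values are invented for the example (they are not the real ADK Event type), and the invocation_id / invocations_used bookkeeping from the real method is omitted for brevity:

from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass
class FakeEvent:  # simplified stand-in for a session event
  author: str
  cache_metadata: Optional[str] = None       # real code stores a CacheMetadata model
  prompt_token_count: Optional[int] = None   # real code reads usage_metadata.prompt_token_count


def find_cache_info(
    events: List[FakeEvent], agent_name: str
) -> Tuple[Optional[str], Optional[int]]:
  """Walk events newest-to-oldest; keep the first match found for each value."""
  cache_metadata = None
  previous_token_count = None
  for event in reversed(events):
    if event.author != agent_name:
      continue
    if cache_metadata is None and event.cache_metadata is not None:
      cache_metadata = event.cache_metadata
    if previous_token_count is None and event.prompt_token_count is not None:
      previous_token_count = event.prompt_token_count
    if cache_metadata is not None and previous_token_count is not None:
      break  # both pieces found, stop early
  return cache_metadata, previous_token_count


events = [
    FakeEvent(author="user"),
    FakeEvent(author="agent_a", cache_metadata="cache-1", prompt_token_count=900),
    FakeEvent(author="agent_a", prompt_token_count=2048),  # most recent LLM response
]
assert find_cache_info(events, "agent_a") == ("cache-1", 2048)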

src/google/adk/models/gemini_context_cache_manager.py

Lines changed: 14 additions & 5 deletions
@@ -257,12 +257,21 @@ async def _create_new_cache_with_contents(
     Returns:
       Cache metadata if successful, None otherwise
     """
-    # Estimate token count for minimum cache size check
-    estimated_tokens = self._estimate_request_tokens(llm_request)
-    if estimated_tokens < llm_request.cache_config.min_tokens:
+    # Check if we have token count from previous response for cache size validation
+    if llm_request.cacheable_contents_token_count is None:
       logger.info(
-          "Request too small for caching (%d < %d tokens)",
-          estimated_tokens,
+          "No previous token count available, skipping cache creation for"
+          " initial request"
+      )
+      return None
+
+    if (
+        llm_request.cacheable_contents_token_count
+        < llm_request.cache_config.min_tokens
+    ):
+      logger.info(
+          "Previous request too small for caching (%d < %d tokens)",
+          llm_request.cacheable_contents_token_count,
           llm_request.cache_config.min_tokens,
       )
       return None

src/google/adk/models/llm_request.py

Lines changed: 3 additions & 0 deletions
@@ -88,6 +88,9 @@ class LlmRequest(BaseModel):
   cache_metadata: Optional[CacheMetadata] = None
   """Cache metadata from previous requests, used for cache management."""

+  cacheable_contents_token_count: Optional[int] = None
+  """Token count from previous request's prompt, used for cache size validation."""
+
   def append_instructions(
       self, instructions: Union[list[str], types.Content]
   ) -> list[types.Content]:
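
As a rough illustration of the new field's semantics (RequestSketch below is a hypothetical stand-in, not the real LlmRequest), an unset value means no previous prompt token count is known yet, which is exactly the "initial request" case the cache manager skips:

from typing import Optional

from pydantic import BaseModel


class RequestSketch(BaseModel):  # hypothetical stand-in for LlmRequest
  cacheable_contents_token_count: Optional[int] = None


req = RequestSketch()
assert req.cacheable_contents_token_count is None  # initial request: nothing recorded yet
req.cacheable_contents_token_count = 2048  # later filled from the previous response's usage metadata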

tests/unittests/agents/test_gemini_context_cache_manager.py

Lines changed: 68 additions & 0 deletions
@@ -121,6 +121,9 @@ async def test_handle_context_caching_new_cache(self):
     )

     llm_request = self.create_llm_request()
+    llm_request.cacheable_contents_token_count = (
+        2048  # Add token count for cache creation
+    )
     start_time = time.time()

     with patch.object(
@@ -194,6 +197,9 @@ async def test_handle_context_caching_invalid_existing_cache(self):
         invocations_used=15
     )  # Exceeds cache_intervals
     llm_request = self.create_llm_request(cache_metadata=existing_cache)
+    llm_request.cacheable_contents_token_count = (
+        2048  # Add token count for cache creation
+    )

     with (
         patch.object(self.manager, "_is_cache_valid", return_value=False),
@@ -521,3 +527,65 @@ def test_parameter_types_enforcement(self):
     assert not hasattr(
         cache_metadata, "usage_metadata"
     )  # CacheMetadata should NOT have this
+
+  def create_llm_request_with_token_count(
+      self, token_count=None, cache_metadata=None
+  ):
+    """Helper to create LlmRequest with cacheable_contents_token_count."""
+    llm_request = self.create_llm_request(cache_metadata=cache_metadata)
+    llm_request.cacheable_contents_token_count = token_count
+    return llm_request
+
+  async def test_cache_creation_with_sufficient_token_count(self):
+    """Test cache creation succeeds when token count meets minimum."""
+    # Setup mocks
+    mock_cached_content = AsyncMock()
+    mock_cached_content.name = (
+        "projects/test/locations/us-central1/cachedContents/token123"
+    )
+    self.manager.genai_client.aio.caches.create = AsyncMock(
+        return_value=mock_cached_content
+    )
+
+    # Create request with sufficient token count
+    llm_request = self.create_llm_request_with_token_count(token_count=2048)
+
+    with patch.object(
+        self.manager, "_generate_cache_fingerprint", return_value="test_fp"
+    ):
+      result = await self.manager.handle_context_caching(llm_request)
+
+    # Should succeed in creating cache
+    assert result is not None
+    assert result.cache_name == mock_cached_content.name
+    self.manager.genai_client.aio.caches.create.assert_called_once()
+
+  async def test_cache_creation_with_insufficient_token_count(self):
+    """Test cache creation fails when token count is below minimum."""
+    # Set higher minimum token requirement
+    self.manager.cache_config = ContextCacheConfig(
+        cache_intervals=10,
+        ttl_seconds=1800,
+        min_tokens=2048,
+    )
+
+    # Create request with insufficient token count
+    llm_request = self.create_llm_request_with_token_count(token_count=1024)
+    llm_request.cache_config = self.manager.cache_config
+
+    result = await self.manager.handle_context_caching(llm_request)

+    # Should not create cache
+    assert result is None
+    self.manager.genai_client.aio.caches.create.assert_not_called()
+
+  async def test_cache_creation_without_token_count(self):
+    """Test cache creation is skipped when no token count is available."""
+    # Create request without token count (initial request)
+    llm_request = self.create_llm_request_with_token_count(token_count=None)
+
+    result = await self.manager.handle_context_caching(llm_request)
+
+    # Should skip cache creation for initial request
+    assert result is None
+    self.manager.genai_client.aio.caches.create.assert_not_called()
