
Commit 84c745c

Add Bedrock/Anthropic prompt caching
1 parent eeeb32d commit 84c745c

File tree

10 files changed: +1176 -48 lines changed
Lines changed: 181 additions & 0 deletions
@@ -0,0 +1,181 @@
"""Example message history processors for automatic cache point insertion.

This module demonstrates how to use message history processors to automatically
insert CachePoint objects for prompt caching optimization.
"""

from typing import Callable

from pydantic_ai.messages import (
    CachePoint,
    ModelMessage,
    ModelRequest,
    SystemPromptPart,
    UserPromptPart,
)


def cache_system_prompt_processor(messages: list[ModelMessage]) -> list[ModelMessage]:
    """Add cache point after the last system prompt.

    This processor finds the last system prompt in the message history and
    adds a cache point to the beginning of the next user message, effectively
    caching all system prompts.

    Args:
        messages: List of model messages to process

    Returns:
        Modified list of messages with cache points added
    """
    result = []
    last_system_idx = -1

    for i, message in enumerate(messages):
        if isinstance(message, ModelRequest):
            for part in message.parts:
                if isinstance(part, SystemPromptPart):
                    last_system_idx = i
        result.append(message)

    # Insert cache point after last system prompt
    if last_system_idx >= 0 and last_system_idx < len(result) - 1:
        next_message = result[last_system_idx + 1]
        if isinstance(next_message, ModelRequest):
            for part in next_message.parts:
                if isinstance(part, UserPromptPart) and isinstance(part.content, list):
                    part.content.insert(0, CachePoint())
                    break
                elif isinstance(part, UserPromptPart) and isinstance(part.content, str):
                    # Convert string content to list and add cache point
                    part.content = [CachePoint(), part.content]
                    break

    return result


def cache_long_context_processor(
    min_tokens: int = 1024,
) -> Callable[[list[ModelMessage]], list[ModelMessage]]:
    """Add cache points before content that likely exceeds token threshold.

    This is a simplified example that estimates content length. In a real
    implementation, you would want to use a proper tokenizer for accurate counts.

    Args:
        min_tokens: Minimum estimated tokens before adding a cache point

    Returns:
        A processor function that adds cache points for long content
    """

    def processor(messages: list[ModelMessage]) -> list[ModelMessage]:
        result = []

        for message in messages:
            if isinstance(message, ModelRequest):
                for part in message.parts:
                    if isinstance(part, UserPromptPart):
                        if isinstance(part.content, str):
                            # Simple estimation: ~4 characters per token
                            if len(part.content) > min_tokens * 4:
                                part.content = [CachePoint(), part.content]
                        elif isinstance(part.content, list):
                            # Look for large text blocks
                            for i, item in enumerate(part.content):
                                if isinstance(item, str) and len(item) > min_tokens * 4:
                                    # Insert cache point before large text
                                    part.content.insert(i, CachePoint())
                                    break
            result.append(message)

        return result

    return processor


def cache_document_context_processor(
    messages: list[ModelMessage],
) -> list[ModelMessage]:
    """Add cache points after document content.

    This processor adds cache points after any document or binary content
    to cache large context documents.

    Args:
        messages: List of model messages to process

    Returns:
        Modified list of messages with cache points added
    """
    result = []

    for message in messages:
        if isinstance(message, ModelRequest):
            for part in message.parts:
                if isinstance(part, UserPromptPart) and isinstance(part.content, list):
                    new_content = []
                    for item in part.content:
                        new_content.append(item)
                        # Add cache point after document/binary content
                        if hasattr(item, 'media_type') or hasattr(item, 'data'):
                            new_content.append(CachePoint())
                    part.content = new_content
        result.append(message)

    return result


def cache_conversation_turns_processor(
    messages: list[ModelMessage],
) -> list[ModelMessage]:
    """Add cache points at regular conversation intervals.

    This processor adds cache points every few conversation turns to cache
    conversational context progressively.

    Args:
        messages: List of model messages to process

    Returns:
        Modified list of messages with cache points added
    """
    result = []
    turn_count = 0

    for message in messages:
        if isinstance(message, ModelRequest):
            for part in message.parts:
                if isinstance(part, UserPromptPart):
                    turn_count += 1
                    # Add cache point every 3 turns
                    if turn_count % 3 == 0:
                        if isinstance(part.content, str):
                            part.content = [CachePoint(), part.content]
                        elif isinstance(part.content, list):
                            part.content.insert(0, CachePoint())
        result.append(message)

    return result


def multi_level_cache_processor(messages: list[ModelMessage]) -> list[ModelMessage]:
    """Example of multiple cache points for hierarchical caching.

    This processor demonstrates adding multiple cache points at different levels:
    - After system prompts
    - After large context
    - At conversation intervals

    Args:
        messages: List of model messages to process

    Returns:
        Modified list of messages with cache points added
    """
    # Apply multiple processors in sequence
    processed = cache_system_prompt_processor(messages)
    processed = cache_long_context_processor(512)(processed)
    processed = cache_conversation_turns_processor(processed)

    return processed
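
For context, a minimal usage sketch (not part of this commit) of how these processors might be wired into an agent. It assumes pydantic_ai's Agent accepts a history_processors list of list[ModelMessage] -> list[ModelMessage] callables and that the Bedrock/Anthropic model support added elsewhere in this commit translates CachePoint markers into cache-control blocks; the module name and model id below are illustrative, since the diff view does not show the file path.

# Usage sketch under the assumptions stated above.
from pydantic_ai import Agent

from cache_processors import (  # hypothetical module name; the file path is not shown in this view
    cache_long_context_processor,
    cache_system_prompt_processor,
)

agent = Agent(
    'bedrock:anthropic.claude-3-5-sonnet-20240620-v1:0',  # illustrative model id
    system_prompt='Long, static instructions that benefit from caching...',
    history_processors=[               # assumed Agent parameter
        cache_system_prompt_processor,       # cache the system-prompt prefix
        cache_long_context_processor(1024),  # cache large pasted documents
    ],
)

result = agent.run_sync('Summarize the attached report.')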
