Commit accb936

Add Bedrock/Anthropic prompt caching
1 parent eeeb32d commit accb936

10 files changed: +1107 -49 lines changed
Lines changed: 169 additions & 0 deletions
@@ -0,0 +1,169 @@
"""
Example message history processors for automatic cache point insertion.

This module demonstrates how to use message history processors to automatically
insert CachePoint objects for prompt caching optimization.
"""

from typing import Callable

from pydantic_ai.messages import CachePoint, ModelMessage, ModelRequest, SystemPromptPart, UserPromptPart


def cache_system_prompt_processor(messages: list[ModelMessage]) -> list[ModelMessage]:
    """Add a cache point after the last system prompt.

    This processor finds the last system prompt in the message history and
    adds a cache point to the beginning of the next user message, effectively
    caching all system prompts.

    Args:
        messages: List of model messages to process

    Returns:
        Modified list of messages with cache points added
    """
    result = []
    last_system_idx = -1

    for i, message in enumerate(messages):
        if isinstance(message, ModelRequest):
            for part in message.parts:
                if isinstance(part, SystemPromptPart):
                    last_system_idx = i
        result.append(message)

    # Insert cache point after last system prompt
    if last_system_idx >= 0 and last_system_idx < len(result) - 1:
        next_message = result[last_system_idx + 1]
        if isinstance(next_message, ModelRequest):
            for part in next_message.parts:
                if isinstance(part, UserPromptPart) and isinstance(part.content, list):
                    part.content.insert(0, CachePoint())
                    break
                elif isinstance(part, UserPromptPart) and isinstance(part.content, str):
                    # Convert string content to list and add cache point
                    part.content = [CachePoint(), part.content]
                    break

    return result

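# Illustration (not part of the committed file): given a history whose first
# request carries a SystemPromptPart and whose second request carries a
# plain-string UserPromptPart, the processor above rewrites the second
# request's content to
#     [CachePoint(), 'original user text']
# which, per the docstring, is intended to let the preceding system prompts
# be served from the provider's prompt cache.
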
def cache_long_context_processor(min_tokens: int = 1024) -> Callable[[list[ModelMessage]], list[ModelMessage]]:
    """Add cache points before content that likely exceeds token threshold.

    This is a simplified example that estimates content length. In a real
    implementation, you would want to use a proper tokenizer for accurate counts.

    Args:
        min_tokens: Minimum estimated tokens before adding a cache point

    Returns:
        A processor function that adds cache points for long content
    """
    def processor(messages: list[ModelMessage]) -> list[ModelMessage]:
        result = []

        for message in messages:
            if isinstance(message, ModelRequest):
                for part in message.parts:
                    if isinstance(part, UserPromptPart):
                        if isinstance(part.content, str):
                            # Simple estimation: ~4 characters per token
                            if len(part.content) > min_tokens * 4:
                                part.content = [CachePoint(), part.content]
                        elif isinstance(part.content, list):
                            # Look for large text blocks
                            for i, item in enumerate(part.content):
                                if isinstance(item, str) and len(item) > min_tokens * 4:
                                    # Insert cache point before large text
                                    part.content.insert(i, CachePoint())
                                    break
            result.append(message)

        return result

    return processor

def cache_document_context_processor(messages: list[ModelMessage]) -> list[ModelMessage]:
    """Add cache points after document content.

    This processor adds cache points after any document or binary content
    to cache large context documents.

    Args:
        messages: List of model messages to process

    Returns:
        Modified list of messages with cache points added
    """
    result = []

    for message in messages:
        if isinstance(message, ModelRequest):
            for part in message.parts:
                if isinstance(part, UserPromptPart) and isinstance(part.content, list):
                    new_content = []
                    for item in part.content:
                        new_content.append(item)
                        # Add cache point after document/binary content
                        if hasattr(item, 'media_type') or hasattr(item, 'data'):
                            new_content.append(CachePoint())
                    part.content = new_content
        result.append(message)

    return result

def cache_conversation_turns_processor(messages: list[ModelMessage]) -> list[ModelMessage]:
    """Add cache points at regular conversation intervals.

    This processor adds cache points every few conversation turns to cache
    conversational context progressively.

    Args:
        messages: List of model messages to process

    Returns:
        Modified list of messages with cache points added
    """
    result = []
    turn_count = 0

    for message in messages:
        if isinstance(message, ModelRequest):
            for part in message.parts:
                if isinstance(part, UserPromptPart):
                    turn_count += 1
                    # Add cache point every 3 turns
                    if turn_count % 3 == 0:
                        if isinstance(part.content, str):
                            part.content = [CachePoint(), part.content]
                        elif isinstance(part.content, list):
                            part.content.insert(0, CachePoint())
        result.append(message)

    return result

def multi_level_cache_processor(messages: list[ModelMessage]) -> list[ModelMessage]:
    """Example of multiple cache points for hierarchical caching.

    This processor demonstrates adding multiple cache points at different levels:
    - After system prompts
    - After large context
    - At conversation intervals

    Args:
        messages: List of model messages to process

    Returns:
        Modified list of messages with cache points added
    """
    # Apply multiple processors in sequence
    processed = cache_system_prompt_processor(messages)
    processed = cache_long_context_processor(512)(processed)
    processed = cache_conversation_turns_processor(processed)

    return processed
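
Below is a minimal usage sketch, not part of the commit, showing how these processors might be applied to a message history before it is sent to a Bedrock/Anthropic model. The import path examples.cache_processors and the sample history are assumptions for illustration; only the pydantic_ai.messages types are taken from the file above.

# Hypothetical usage sketch -- not part of the committed file.
# 'examples.cache_processors' stands in for wherever this module lives in the repo.
from examples.cache_processors import cache_long_context_processor, cache_system_prompt_processor

from pydantic_ai.messages import ModelMessage, ModelRequest, SystemPromptPart, UserPromptPart

history: list[ModelMessage] = [
    ModelRequest(parts=[SystemPromptPart(content='You are a meticulous research assistant.')]),
    ModelRequest(parts=[UserPromptPart(content='Here is a long reference document ...')]),
]

# Apply the processors before making the model request, so providers that
# support prompt caching (e.g. Anthropic on Bedrock) can reuse the cached
# prefix on later turns.
history = cache_system_prompt_processor(history)
history = cache_long_context_processor(min_tokens=512)(history)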
