diff --git a/examples/pydantic_ai_examples/test_prompt_caching.py b/examples/pydantic_ai_examples/test_prompt_caching.py new file mode 100644 index 000000000..f0dee784a --- /dev/null +++ b/examples/pydantic_ai_examples/test_prompt_caching.py @@ -0,0 +1,479 @@ +#!/usr/bin/env python3 +"""Example script to test prompt caching with AWS Bedrock. + +This script demonstrates: +1. Manual cache point insertion +2. Cache usage metrics +3. Cost savings from caching + +Prerequisites: +- AWS credentials configured +- Access to Claude 4 Sonnet on Bedrock EU region +- Install dependencies: pip install "pydantic-ai-slim[bedrock]" +""" + +import asyncio +import os +from typing import Callable + +from pydantic_ai import Agent, RunContext +from pydantic_ai.messages import ( + CachePoint, + ModelMessage, + ModelRequest, + ToolReturnPart, + UserPromptPart, +) +from pydantic_ai.models.bedrock import BedrockConverseModel +from pydantic_ai.providers.bedrock import BedrockProvider + +# Some long context for testing +# Note that Bedrock prompt caching has a minimum token length of 1024 tokens. +LONG_CONTEXT = """ +This is a comprehensive company handbook containing detailed information about our organization. + +COMPANY BACKGROUND AND HISTORY: +We are a leading technology company founded in 2018 that specializes in cutting-edge artificial intelligence and machine learning solutions. Our company was started by a team of former researchers from top universities including MIT, Stanford, and Carnegie Mellon. We began as a small startup with just 5 people working out of a garage in Palo Alto, and have since grown to over 250 employees across multiple offices worldwide. + +Our initial focus was on natural language processing, but we quickly expanded into computer vision, predictive analytics, and reinforcement learning. We've raised over $150 million in funding across Series A, B, and C rounds from leading venture capital firms including Andreessen Horowitz, Sequoia Capital, and Google Ventures. + +PRODUCTS AND SERVICES: +Our flagship products include: + +1. NLP Suite - A comprehensive natural language processing platform that includes: + - Sentiment analysis with 99.2% accuracy across 50+ languages + - Named entity recognition supporting multilingual content + - Text summarization using transformer models and neural architectures + - Question answering systems for enterprise knowledge bases + - Conversational AI chatbots with advanced context understanding + - Document classification and content moderation tools + +2. Vision AI Platform - Computer vision solutions featuring: + - Object detection and classification with real-time processing capabilities + - Facial recognition and emotion detection systems (privacy-compliant) + - Medical image analysis for diagnostic assistance in healthcare + - Quality control systems for manufacturing and production lines + - Autonomous vehicle perception systems for transportation + - Augmented reality content recognition and spatial mapping + +3. 
Predictive Analytics Engine - Advanced forecasting tools including: + - Demand forecasting for retail and e-commerce optimization + - Financial risk assessment and fraud detection algorithms + - Customer churn prediction and retention strategies + - Supply chain optimization and logistics planning + - Energy consumption prediction for smart buildings + - Weather impact modeling for agricultural planning + +INDUSTRY VERTICALS AND CLIENT BASE: +We serve clients across multiple industries with customized AI solutions: + +Healthcare Sector: +- Diagnostic assistance systems for radiologists and pathologists +- Drug discovery acceleration using machine learning algorithms +- Patient outcome prediction models for treatment optimization +- Medical record digitization and automated analysis systems +- Telemedicine platform integration and remote monitoring +- Clinical trial optimization and patient recruitment systems + +Financial Services: +- Risk assessment platforms for loan approvals and credit scoring +- Algorithmic trading systems with advanced market analysis +- Fraud detection and prevention for payment processors +- Regulatory compliance monitoring and automated reporting +- Customer service automation for banking operations +- Investment portfolio optimization using quantitative algorithms + +Retail and E-commerce: +- Demand forecasting and inventory management optimization +- Personalized recommendation engines for online platforms +- Dynamic pricing algorithms for revenue optimization +- Customer behavior analysis and market segmentation +- Supply chain visibility and logistics coordination +- Visual search and product discovery technologies + +Manufacturing and Industrial: +- Computer vision quality control systems for production lines +- Predictive maintenance algorithms for equipment monitoring +- Process optimization for increased efficiency and waste reduction +- Safety monitoring systems using IoT sensors and AI analytics +- Energy consumption optimization for sustainable operations +- Robotic automation and intelligent manufacturing systems + +TEAM AND ORGANIZATIONAL STRUCTURE: +Our diverse team consists of world-class talent: +- 150+ engineers and data scientists with advanced degrees +- 50+ product managers, designers, and business development professionals +- 30+ sales, marketing, and customer success specialists +- 20+ operations, legal, and administrative support staff + +We maintain offices in strategic locations: +- San Francisco, California (Headquarters) - 120 employees +- New York, New York (East Coast Operations) - 60 employees +- London, United Kingdom (European Operations) - 40 employees +- Toronto, Canada (AI Research Laboratory) - 30 employees + +Our company culture emphasizes innovation, ethical AI development, diversity and inclusion, work-life balance, open source contributions, and environmental sustainability. 
+ +TECHNOLOGY STACK AND INFRASTRUCTURE: +Machine Learning Development: +- Python ecosystem with TensorFlow, PyTorch, and scikit-learn +- CUDA and distributed computing for GPU acceleration +- MLflow for experiment tracking and model management +- Apache Airflow for ML pipeline orchestration +- Jupyter notebooks and collaborative development environments + +Cloud Infrastructure: +- Multi-cloud architecture using AWS, Google Cloud, and Azure +- Kubernetes for container orchestration and microservices +- Docker for containerization and application packaging +- Terraform for infrastructure as code automation +- Jenkins and GitLab CI/CD for continuous deployment + +Data Engineering: +- Apache Kafka for real-time data streaming +- Apache Spark for large-scale data processing +- Elasticsearch for search and analytics +- PostgreSQL, MongoDB, and Redis databases +- Apache Hadoop for distributed storage + +Security and Compliance: +- Zero-trust security architecture with end-to-end encryption +- Multi-factor authentication and role-based access control +- GDPR, HIPAA, and SOC2 compliance frameworks +- Regular security audits and penetration testing +- Privacy-preserving machine learning techniques + +MISSION AND VALUES: +Our mission is to democratize artificial intelligence and make advanced machine learning capabilities accessible to businesses of all sizes. We believe AI should augment human capabilities and are committed to developing ethical, explainable, and trustworthy systems that solve real-world problems while considering societal impact. +""" + +# TODO: add simple example for manual cache point insertion + + +async def demo_manual_cache_points() -> None: + """Demonstrate manual cache point insertion.""" + print('=== Manual Cache Points Demo ===') + + bedrock_model = BedrockConverseModel( + 'eu.anthropic.claude-sonnet-4-20250514-v1:0', + provider=BedrockProvider(profile_name='co2-s3'), + ) + + _amazon_model = BedrockConverseModel( + 'eu.amazon.nova-pro-v1:0', + provider=BedrockProvider(profile_name='co2-s3'), + ) + + # anthropic_model = AnthropicModel( + # 'claude-sonnet-4-20250514', + # ) + + agent = Agent( + model=bedrock_model, + system_prompt='You are a helpful assistant with access to company information.', + ) + + print('Running first query with cache point...') + # First query with cache point - this should cache the long context + result1 = await agent.run( + [ + LONG_CONTEXT, + CachePoint(), # Cache everything above this point + 'Read this, and then I will ask you a question.', + ] + ) + + print(f'Response: {result1.output}') + result_1_usage = result1.usage() + if result_1_usage: + print(f'Usage: {result_1_usage}') + if result_1_usage.cache_write_tokens: + print(f'Cache write tokens: {result_1_usage.cache_write_tokens}') + if result_1_usage.cache_read_tokens: + print(f'Cache read tokens: {result_1_usage.cache_read_tokens}') + + print('\nRunning second query (should use cache)...') + # Second query with same cached context + result2 = await agent.run( + [ + LONG_CONTEXT, + CachePoint(), + 'What technology stack does the company use?', + ] + ) + + print(f'Response: {result2.output}') + result_2_usage = result2.usage() + if result_2_usage: + print(f'Usage: {result_2_usage}') + if result_2_usage.cache_write_tokens: + print(f'Cache write tokens: {result_2_usage.cache_write_tokens}') + if result_2_usage.cache_read_tokens: + print(f'Cache read tokens: {result_2_usage.cache_read_tokens}') + + # Calculate potential savings + if result_2_usage.cache_read_tokens and 
result_2_usage.input_tokens: + cache_percentage = ( + result_2_usage.cache_read_tokens / result_2_usage.input_tokens + ) * 100 + print(f'Cache hit rate: {cache_percentage:.1f}% of input tokens') + # Cached tokens typically cost ~10% of normal tokens + savings = result_2_usage.cache_read_tokens * 0.9 + print(f'Estimated savings: ~{savings:.0f} token-equivalents') + + +# TODO(larryhudson): For this to work in a long thread, you also need to add a processor to remove the cache points from the non-last messages +def cache_long_tool_returns( + min_tokens: int = 1024, +) -> Callable[[list[ModelMessage]], list[ModelMessage]]: + """Add cache points after long tool results. + + This processor only examines the last ModelRequest in the message history + and automatically adds cache points after UserPromptPart content that + likely came from ToolReturn.content when the content exceeds the token threshold. + + Args: + min_tokens: Minimum estimated tokens before adding a cache point + + Returns: + A processor function that adds cache points after long tool results + """ + + def processor(messages: list[ModelMessage]) -> list[ModelMessage]: + if not messages: + return messages + + last_message = messages[-1] + if not isinstance(last_message, ModelRequest): + return messages + + tool_return_part = next( + (part for part in last_message.parts if isinstance(part, ToolReturnPart)), + None, + ) + + if not tool_return_part: + return messages + + if len(tool_return_part.model_response_str()) > min_tokens * 4: + last_message.parts.append(UserPromptPart(content=[CachePoint()])) + + return messages + + return processor + + +async def demo_tool_result_caching() -> None: + """Demonstrate caching of long tool results using ToolReturn.content with CachePoint.""" + print('\n=== Tool Result Caching Demo ===') + + bedrock_model = BedrockConverseModel( + 'eu.anthropic.claude-sonnet-4-20250514-v1:0', + provider=BedrockProvider(profile_name='co2-s3'), + ) + + # _anthropic_model = AnthropicModel( + # 'claude-sonnet-4-20250514', + # ) + + _amazon_model = BedrockConverseModel( + 'eu.amazon.nova-pro-v1:0', + provider=BedrockProvider(profile_name='co2-s3'), + ) + + agent = Agent( + model=bedrock_model, + system_prompt='You are a helpful assistant that processes large datasets and documents. When I ask you to use multiple tools, call them all in sequence.', + history_processors=[cache_long_tool_returns(min_tokens=1024)], + ) + + @agent.tool + def fetch_api_documentation(ctx: RunContext) -> str: + """Fetch comprehensive API documentation.""" + # Simulate fetching large API documentation + api_docs = ( + f""" +API Reference Documentation - MyService v2.0 + +{'=' * 60} +AUTHENTICATION +{'=' * 60} +All endpoints require Bearer token authentication. +Header: Authorization: Bearer + +{'=' * 60} +ENDPOINTS +{'=' * 60} +""" + + 'Detailed endpoint documentation with examples... ' * 150 + ) + + return api_docs + + @agent.tool + def fetch_user_guide(ctx: RunContext) -> str: + """Fetch comprehensive user guide that references API docs.""" + # Simulate fetching large user guide + user_guide = ( + f""" +User Guide - MyService Integration + +{'=' * 60} +GETTING STARTED +{'=' * 60} +This guide assumes you have read the API documentation above. + +{'=' * 60} +STEP-BY-STEP TUTORIALS +{'=' * 60} +""" + + 'Detailed step-by-step tutorials and examples... 
' * 120
+        )
+
+        return user_guide
+
+    @agent.tool
+    def generate_code_examples(ctx: RunContext) -> str:
+        """Generate code examples that reference both cached docs."""
+        # Simulate generating code examples
+        code_examples = (
+            f"""
+Code Examples - MyService Integration
+
+{'=' * 60}
+BASIC AUTHENTICATION EXAMPLE
+{'=' * 60}
+Based on the API documentation and user guide above:
+
+import requests
+
+# Use the authentication method described in the API docs
+headers = {{'Authorization': 'Bearer your-token'}}
+response = requests.get('https://api.myservice.com/data', headers=headers)
+
+{'=' * 60}
+ADVANCED EXAMPLES
+{'=' * 60}
+"""
+            + 'More detailed code examples and best practices... ' * 80
+        )
+
+        return code_examples
+
+    print('Demo: Request that triggers multiple tool calls with automatic caching...')
+    print(
+        'The history processor will automatically add cache points after long tool results.'
+    )
+    print('Expected: cache write → cache read + write → cache read + write')
+
+    result = await agent.run(
+        'Please help me integrate with MyService by: 1) fetching the API documentation, '
+        + '2) pausing to think, then fetching and reading the user guide, and 3) finally generating code examples. '
+        + 'Use all three tools in sequence, one at a time.'
+    )
+
+    print(f'Response: {result.output[:200]}...')
+
+    usage = result.usage()
+    if usage:
+        print(f'\nUsage Summary: {usage}')
+        if usage.cache_write_tokens:
+            print(
+                f'āœ… Cache writes: {usage.cache_write_tokens} tokens (automatically cached tool results)'
+            )
+        if usage.cache_read_tokens:
+            print(
+                f'āœ… Cache reads: {usage.cache_read_tokens} tokens (reusing previous tool results)'
+            )
+        if usage.cache_read_tokens and usage.input_tokens:
+            cache_percentage = (usage.cache_read_tokens / usage.input_tokens) * 100
+            print(f'Cache efficiency: {cache_percentage:.1f}% of input was cached')
+            savings = usage.cache_read_tokens * 0.9  # Cached tokens cost ~10% of normal
+            print(f'Estimated savings: ~{savings:.0f} token-equivalents')
+
+    print('\nšŸŽÆ How automatic tool result caching works:')
+    print('1. Tools return large content via ToolReturn.content (no manual CachePoint)')
+    print('2. History processor detects long content and automatically adds CachePoint')
+    print('3. First tool → cache write only (new content)')
+    print('4. Second tool → cache read (previous results) + cache write (new content)')
+    print('5. Third tool → cache read (all previous) + cache write (new content)')
+    print('6. Result shows both cache_read_tokens and cache_write_tokens')
+
+
+async def demo_system_prompt_caching() -> None:
+    """Demonstrate system prompt caching."""
+    print('\n=== System Prompt Caching Demo ===')
+
+    bedrock_model = BedrockConverseModel(
+        'eu.anthropic.claude-sonnet-4-20250514-v1:0',
+        provider=BedrockProvider(profile_name='co2-s3'),
+    )
+
+    agent = Agent(model=bedrock_model, system_prompt=LONG_CONTEXT)
+
+    print('Running first financial query...')
+    result1 = await agent.run(
+        [
+            # By adding a CachePoint at the start of the first user prompt part, the system prompt will be cached.
+            CachePoint(),
+            "What are the key ratios to analyze when evaluating a tech company's financial health?",
+        ]
+    )
+    print(f'Response length: {len(result1.output)} characters')
+    result_1_usage = result1.usage()
+    if result_1_usage:
+        print(f'Usage: {result_1_usage}')
+
+    print('\nRunning second financial query...')
+    result2 = await agent.run(
+        'How should I approach valuing a SaaS startup with recurring revenue?'
+    )
+    print(f'Response length: {len(result2.output)} characters')
+    result_2_usage = result2.usage()
+    if result_2_usage:
+        print(f'Usage: {result_2_usage}')
+        if result_2_usage.cache_read_tokens:
+            print(f'Cache read tokens: {result_2_usage.cache_read_tokens}')
+
+
+# TODO(larryhudson): Make this more minimal and clear
+async def main() -> None:
+    """Run all demos."""
+    print('Prompt Caching Demo with AWS Bedrock')
+    print('====================================')
+
+    # Check if AWS credentials are available
+    if not any(
+        key in os.environ for key in ['AWS_ACCESS_KEY_ID', 'AWS_PROFILE']
+    ) and not os.path.exists(os.path.expanduser('~/.aws/credentials')):
+        print('āš ļø Warning: No AWS credentials found!')
+        print('Please configure AWS credentials using one of:')
+        print('1. AWS CLI: aws configure')
+        print('2. Environment variables: AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY')
+        print('3. IAM role (if running on EC2)')
+        return
+
+    try:
+        await demo_manual_cache_points()
+        await demo_tool_result_caching()
+        await demo_system_prompt_caching()
+
+        print('\nāœ… Demo completed successfully!')
+        print('\nKey takeaways:')
+        print('- Cache points reduce costs by up to 90% for repeated context')
+        print('- Cache read tokens appear in usage metrics')
+        print('- Same context + CachePoint = cache hit')
+        print('- Great for long system prompts, documents, conversation history')
+        print('- ToolReturn.content supports CachePoint for caching tool results')
+        print('- History processors can automatically add cache points to tool results')
+
+    except Exception as e:
+        print(f'āŒ Error running demo: {e}')
+        print('\nTroubleshooting:')
+        print('1. Check AWS credentials and region configuration')
+        print('2. Verify access to Claude Sonnet 4 on Bedrock')
+        print('3. Ensure your AWS account has Bedrock permissions')
+        raise
+
+
+if __name__ == '__main__':
+    asyncio.run(main())
diff --git a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py
index 28447187e..27e1f5161
--- a/pydantic_ai_slim/pydantic_ai/messages.py
+++ b/pydantic_ai_slim/pydantic_ai/messages.py
@@ -423,7 +423,19 @@ def format(self) -> str:
     __repr__ = _utils.dataclasses_no_defaults_repr


-UserContent: TypeAlias = 'str | ImageUrl | AudioUrl | DocumentUrl | VideoUrl | BinaryContent'
+@dataclass
+class CachePoint:
+    """A cache point marker for prompt caching.
+
+    Can be inserted into UserPromptPart.content to mark cache boundaries.
+    Models that don't support caching will filter these out.
+    """
+
+    kind: Literal['cache-point'] = 'cache-point'
+    """Type identifier, this is available on all parts as a discriminator."""
+
+
+UserContent: TypeAlias = 'str | ImageUrl | AudioUrl | DocumentUrl | VideoUrl | BinaryContent | CachePoint'


 @dataclass(repr=False)
diff --git a/pydantic_ai_slim/pydantic_ai/models/anthropic.py b/pydantic_ai_slim/pydantic_ai/models/anthropic.py
index dbc35cfd9..e607cd3f2
--- a/pydantic_ai_slim/pydantic_ai/models/anthropic.py
+++ b/pydantic_ai_slim/pydantic_ai/models/anthropic.py
@@ -20,6 +20,7 @@
     BinaryContent,
     BuiltinToolCallPart,
     BuiltinToolReturnPart,
+    CachePoint,
     DocumentUrl,
     ImageUrl,
     ModelMessage,
@@ -46,6 +47,7 @@ from anthropic.types.beta import (
     BetaBase64PDFBlockParam,
     BetaBase64PDFSourceParam,
+    BetaCacheControlEphemeralParam,
     BetaCitationsDelta,
     BetaCodeExecutionTool20250522Param,
     BetaCodeExecutionToolResultBlock,
@@ -387,7 +389,15 @@ async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[Be
                         system_prompt_parts.append(request_part.content)
                     elif isinstance(request_part, UserPromptPart):
                         async for content in self._map_user_prompt(request_part):
-                            user_content_params.append(content)
+                            if isinstance(content, dict) and content.get('type') == 'ephemeral':
+                                if user_content_params:
+                                    # TODO(larryhudson): Ensure the last user content param supports cache_control
+                                    user_content_params[-1]['cache_control'] = cast(
+                                        BetaCacheControlEphemeralParam, content
+                                    )
+                                continue
+                            else:
+                                user_content_params.append(content)
                     elif isinstance(request_part, ToolReturnPart):
                         tool_result_block_param = BetaToolResultBlockParam(
                             tool_use_id=_guard_tool_call_id(t=request_part),
@@ -476,7 +486,7 @@ async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[Be
     @staticmethod
     async def _map_user_prompt(
         part: UserPromptPart,
-    ) -> AsyncGenerator[BetaContentBlockParam]:
+    ) -> AsyncGenerator[BetaContentBlockParam | BetaCacheControlEphemeralParam]:
         if isinstance(part.content, str):
             if part.content:  # Only yield non-empty text
                 yield BetaTextBlockParam(text=part.content, type='text')
@@ -517,6 +527,8 @@ async def _map_user_prompt(
                     )
                 else:  # pragma: no cover
                     raise RuntimeError(f'Unsupported media type: {item.media_type}')
+            elif isinstance(item, CachePoint):
+                yield BetaCacheControlEphemeralParam(type='ephemeral')
             else:
                 raise RuntimeError(f'Unsupported content type: {type(item)}')  # pragma: no cover
diff --git a/pydantic_ai_slim/pydantic_ai/models/bedrock.py b/pydantic_ai_slim/pydantic_ai/models/bedrock.py
index 995e7e833..d3b74030b
--- a/pydantic_ai_slim/pydantic_ai/models/bedrock.py
+++ b/pydantic_ai_slim/pydantic_ai/models/bedrock.py
@@ -22,6 +22,7 @@
     BinaryContent,
     BuiltinToolCallPart,
     BuiltinToolReturnPart,
+    CachePoint,
     DocumentUrl,
     ImageUrl,
     ModelMessage,
@@ -298,9 +299,12 @@ async def _process_response(self, response: ConverseResponseTypeDef) -> ModelRes
                         tool_call_id=tool_use['toolUseId'],
                     ),
                 )
         u = usage.RequestUsage(
             input_tokens=response['usage']['inputTokens'],
             output_tokens=response['usage']['outputTokens'],
+            # TODO(larryhudson): Failing type check here because boto3 bedrock runtime type defs are not updated yet (needs 1.40.11).
+            cache_read_tokens=response['usage'].get('cacheReadInputTokens', 0),
+            cache_write_tokens=response['usage'].get('cacheWriteInputTokens', 0),
         )
         vendor_id = response.get('ResponseMetadata', {}).get('RequestId', None)
         return ModelResponse(items, usage=u, model_name=self.model_name, provider_request_id=vendor_id)
@@ -416,16 +421,31 @@ async def _map_messages(  # noqa: C901
         Groups consecutive ToolReturnPart objects into a single user message as required by Bedrock Claude/Nova
         models.
         """
         profile = BedrockModelProfile.from_profile(self.profile)
+        supports_caching = profile.bedrock_supports_prompt_caching
         system_prompt: list[SystemContentBlockTypeDef] = []
         bedrock_messages: list[MessageUnionTypeDef] = []
         document_count: Iterator[int] = count(1)
         for message in messages:
             if isinstance(message, ModelRequest):
-                for part in message.parts:
+                for i, part in enumerate(message.parts):
                     if isinstance(part, SystemPromptPart):
                         system_prompt.append({'text': part.content})
                     elif isinstance(part, UserPromptPart):
-                        bedrock_messages.extend(await self._map_user_prompt(part, document_count))
+                        # Handle the case where a UserPromptPart starts with a CachePoint and immediately follows the system prompt
+                        cache_point_for_system_prompt, user_prompt_part = self._extract_leading_cache_point(
+                            part,
+                            supports_caching,
+                            immediately_follows_system_prompt=i > 0
+                            and isinstance(message.parts[i - 1], SystemPromptPart),
+                        )
+                        if cache_point_for_system_prompt is not None:
+                            # TODO: Failing type check here because boto3 bedrock runtime type defs are not updated yet (needs 1.40.11).
+                            system_prompt.append(cache_point_for_system_prompt)
+
+                        if user_prompt_part is not None:
+                            bedrock_messages.extend(
+                                await self._map_user_prompt(user_prompt_part, document_count, supports_caching)
+                            )
                     elif isinstance(part, ToolReturnPart):
                         assert part.tool_call_id is not None
                         bedrock_messages.append(
@@ -516,16 +536,54 @@ async def _map_messages(  # noqa: C901
         if instructions := self._get_instructions(messages):
             system_prompt.insert(0, {'text': instructions})

         return system_prompt, processed_messages

-    @staticmethod
-    async def _map_user_prompt(part: UserPromptPart, document_count: Iterator[int]) -> list[MessageUnionTypeDef]:
+    def _extract_leading_cache_point(
+        self,
+        part: UserPromptPart,
+        supports_caching: bool,
+        immediately_follows_system_prompt: bool,
+    ) -> tuple[SystemContentBlockTypeDef | None, UserPromptPart | None]:
+        """Extract a leading CachePoint from a UserPromptPart if the conditions for system prompt caching are met.
+
+        Returns a tuple of:
+        - cache_point_for_system_prompt: cache point block to add to the system prompt, or None
+        - user_prompt_part: UserPromptPart with the cache point removed, or None if no content remains
+        """
+        # Only remove the cache point if this UserPromptPart immediately follows a SystemPromptPart
+        # within the same message's parts, and the cache point is the first item in the user prompt.
+
+        if (
+            immediately_follows_system_prompt
+            and isinstance(part.content, list)
+            and part.content
+            and isinstance(part.content[0], CachePoint)
+        ):
+            cache_point_for_system_prompt = {'cachePoint': {'type': 'default'}} if supports_caching else None
+
+            remaining_content = part.content[1:]
+            user_prompt_part = UserPromptPart(content=remaining_content) if remaining_content else None
+
+            # TODO: Failing type check here because boto3 bedrock runtime type defs are not updated yet (needs 1.40.11).
+            return cache_point_for_system_prompt, user_prompt_part
+
+        return None, part
+
+    async def _map_user_prompt(
+        self, part: UserPromptPart, document_count: Iterator[int], supports_caching: bool
+    ) -> list[MessageUnionTypeDef]:
         content: list[ContentBlockUnionTypeDef] = []
+
         if isinstance(part.content, str):
             content.append({'text': part.content})
         else:
             for item in part.content:
-                if isinstance(item, str):
+                if isinstance(item, CachePoint):
+                    if supports_caching:
+                        # TODO: update the boto3 bedrock type defs so 'cachePoint' is available
+                        content.append({'cachePoint': {'type': 'default'}})
+                    continue
+                elif isinstance(item, str):
                     content.append({'text': item})
                 elif isinstance(item, BinaryContent):
                     format = item.format
@@ -669,9 +731,12 @@ def model_name(self) -> str:
         return self._model_name

     def _map_usage(self, metadata: ConverseStreamMetadataEventTypeDef) -> usage.RequestUsage:
         return usage.RequestUsage(
             input_tokens=metadata['usage']['inputTokens'],
             output_tokens=metadata['usage']['outputTokens'],
+            # TODO(larryhudson): Failing type check here because boto3 bedrock runtime type defs are not updated yet (needs 1.40.11).
+            cache_write_tokens=metadata['usage'].get('cacheWriteInputTokens', 0),
+            cache_read_tokens=metadata['usage'].get('cacheReadInputTokens', 0),
         )
diff --git a/pydantic_ai_slim/pydantic_ai/providers/bedrock.py b/pydantic_ai_slim/pydantic_ai/providers/bedrock.py
index cf19ce290..bd8a1c12d
--- a/pydantic_ai_slim/pydantic_ai/providers/bedrock.py
+++ b/pydantic_ai_slim/pydantic_ai/providers/bedrock.py
@@ -37,6 +37,45 @@ class BedrockModelProfile(ModelProfile):
     bedrock_supports_tool_choice: bool = True
     bedrock_tool_result_format: Literal['text', 'json'] = 'text'
     bedrock_send_back_thinking_parts: bool = False
+    bedrock_supports_prompt_caching: bool = False
+
+
+# Supported models: https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html
+ANTHROPIC_CACHING_SUPPORTED_MODELS = ['claude-3-5-sonnet', 'claude-3-5-haiku', 'claude-3-7-sonnet', 'claude-sonnet-4']
+
+
+def bedrock_anthropic_model_profile(model_name: str) -> ModelProfile | None:
+    """Create a Bedrock model profile for Anthropic models with caching support where applicable."""
+    # Check if this model supports prompt caching
+    if any(supported in model_name for supported in ANTHROPIC_CACHING_SUPPORTED_MODELS):
+        return BedrockModelProfile(
+            bedrock_supports_tool_choice=False,
+            bedrock_send_back_thinking_parts=True,
+            bedrock_supports_prompt_caching=True,
+        ).update(anthropic_model_profile(model_name))
+    else:
+        return BedrockModelProfile(bedrock_supports_tool_choice=False, bedrock_send_back_thinking_parts=True).update(
+            anthropic_model_profile(model_name)
+        )
+
+
+# Supported models: https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html
+AMAZON_CACHING_SUPPORTED_MODELS = ['nova-micro', 'nova-lite', 'nova-pro', 'nova-premier']
+
+
+# TODO(larryhudson): With Amazon models, you can add 'cachePoint' to user messages but not to tool return messages :(
+# Need to make it so that we avoid adding cache points to tool return messages
+def bedrock_amazon_model_profile(model_name: str) -> ModelProfile | None:
+    """Create a Bedrock model profile for Amazon models with caching support where applicable."""
+    if any(supported in model_name for supported in AMAZON_CACHING_SUPPORTED_MODELS):
+        return BedrockModelProfile(bedrock_supports_prompt_caching=True).update(amazon_model_profile(model_name))
+    else:
+        return amazon_model_profile(model_name)
+
+
+def bedrock_mistral_model_profile(model_name: str) -> ModelProfile | None:
+    """Create a Bedrock model profile for Mistral models."""
+    return BedrockModelProfile(bedrock_tool_result_format='json').update(mistral_model_profile(model_name))


 class BedrockProvider(Provider[BaseClient]):
@@ -56,14 +95,10 @@ def client(self) -> BaseClient:

     def model_profile(self, model_name: str) -> ModelProfile | None:
         provider_to_profile: dict[str, Callable[[str], ModelProfile | None]] = {
-            'anthropic': lambda model_name: BedrockModelProfile(
-                bedrock_supports_tool_choice=False, bedrock_send_back_thinking_parts=True
-            ).update(anthropic_model_profile(model_name)),
-            'mistral': lambda model_name: BedrockModelProfile(bedrock_tool_result_format='json').update(
-                mistral_model_profile(model_name)
-            ),
+            'anthropic': bedrock_anthropic_model_profile,
+            'mistral': bedrock_mistral_model_profile,
             'cohere': cohere_model_profile,
-            'amazon': amazon_model_profile,
+            'amazon': bedrock_amazon_model_profile,
             'meta': meta_model_profile,
             'deepseek': deepseek_model_profile,
         }
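
For orientation, the Bedrock mapping above only ever emits {'cachePoint': {'type': 'default'}} blocks. Below is a rough sketch of the Converse request shape it is aiming for when a system prompt is cached and an inline cache point is used; the text values are illustrative, and the placement assumes a model whose profile sets bedrock_supports_prompt_caching=True.

# Sketch of the intended Converse-API request shape (values illustrative).
converse_request = {
    'system': [
        {'text': 'You are a helpful assistant with access to company information.'},
        # Appended by _extract_leading_cache_point when a CachePoint is the first
        # item of the user prompt that immediately follows the system prompt:
        {'cachePoint': {'type': 'default'}},
    ],
    'messages': [
        {
            'role': 'user',
            'content': [
                {'text': 'A long document to be cached...'},
                # Appended by _map_user_prompt for an inline CachePoint on supported models:
                {'cachePoint': {'type': 'default'}},
                {'text': 'Now answer my question about the document.'},
            ],
        }
    ],
}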
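A minimal end-to-end sketch of the user-facing API follows, condensed from examples/pydantic_ai_examples/test_prompt_caching.py. The Claude Sonnet 4 model ID and the reliance on the default AWS credential chain are assumptions carried over from that script, not requirements of the API, and the document must exceed Bedrock's ~1024-token minimum for caching to take effect.

import asyncio

from pydantic_ai import Agent
from pydantic_ai.messages import CachePoint
from pydantic_ai.models.bedrock import BedrockConverseModel

# Assumed model ID (taken from the example script); requires Bedrock access in your account.
agent = Agent(
    BedrockConverseModel('eu.anthropic.claude-sonnet-4-20250514-v1:0'),
    system_prompt='You answer questions about the supplied document.',
)

LONG_DOCUMENT = 'Plenty of reference text to exceed the ~1024-token caching minimum. ' * 200


async def main() -> None:
    for question in ('Summarise the document.', 'What risks does it mention?'):
        # Everything before the CachePoint is marked as a cache boundary; the second
        # run should report cache_read_tokens instead of paying for the document again.
        result = await agent.run([LONG_DOCUMENT, CachePoint(), question])
        usage = result.usage()
        print(question, usage.cache_write_tokens, usage.cache_read_tokens)


if __name__ == '__main__':
    asyncio.run(main())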