diff --git a/docs/models/anthropic.md b/docs/models/anthropic.md
index 75abd4e82b..d55a84991e 100644
--- a/docs/models/anthropic.md
+++ b/docs/models/anthropic.md
@@ -77,3 +77,133 @@ model = AnthropicModel(
 agent = Agent(model)
 ...
 ```
+
+## Prompt Caching
+
+Anthropic supports [prompt caching](https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching) to reduce costs and latency by caching parts of your prompts. PydanticAI provides three ways to use prompt caching:
+
+### 1. Cache User Messages with `CachePoint`
+
+Insert a [`CachePoint`][pydantic_ai.messages.CachePoint] marker in your user messages to cache everything before it:
+
+```python {test="skip"}
+from pydantic_ai import Agent, CachePoint
+
+agent = Agent('anthropic:claude-sonnet-4-5')
+
+async def main():
+    # Everything before CachePoint will be cached
+    result = await agent.run([
+        'Long context that should be cached...',
+        CachePoint(),
+        'Your question here'
+    ])
+    print(result.output)
+```
+
+### 2. Cache System Instructions
+
+Use `anthropic_cache_instructions=True` to cache your system prompt:
+
+```python {test="skip"}
+from pydantic_ai import Agent
+from pydantic_ai.models.anthropic import AnthropicModelSettings
+
+agent = Agent(
+    'anthropic:claude-sonnet-4-5',
+    system_prompt='Long detailed instructions...',
+    model_settings=AnthropicModelSettings(
+        anthropic_cache_instructions=True
+    ),
+)
+
+async def main():
+    result = await agent.run('Your question')
+    print(result.output)
+```
+
+### 3. Cache Tool Definitions
+
+Use `anthropic_cache_tools=True` to cache your tool definitions:
+
+```python {test="skip"}
+from pydantic_ai import Agent
+from pydantic_ai.models.anthropic import AnthropicModelSettings
+
+agent = Agent(
+    'anthropic:claude-sonnet-4-5',
+    model_settings=AnthropicModelSettings(
+        anthropic_cache_tools=True
+    ),
+)
+
+@agent.tool_plain
+def my_tool() -> str:
+    """Tool definition will be cached."""
+    return 'result'
+
+async def main():
+    result = await agent.run('Use the tool')
+    print(result.output)
+```
+
+### Combining Cache Strategies
+
+You can combine all three caching strategies for maximum savings:
+
+```python {test="skip"}
+from pydantic_ai import Agent, CachePoint, RunContext
+from pydantic_ai.models.anthropic import AnthropicModelSettings
+
+agent = Agent(
+    'anthropic:claude-sonnet-4-5',
+    system_prompt='Detailed instructions...',
+    model_settings=AnthropicModelSettings(
+        anthropic_cache_instructions=True,
+        anthropic_cache_tools=True,
+    ),
+)
+
+@agent.tool
+def search_docs(ctx: RunContext, query: str) -> str:
+    """Search documentation."""
+    return f'Results for {query}'
+
+async def main():
+    # First call - writes to cache
+    result1 = await agent.run([
+        'Long context from documentation...',
+        CachePoint(),
+        'First question'
+    ])
+
+    # Subsequent calls - read from cache (90% cost reduction)
+    result2 = await agent.run([
+        'Long context from documentation...',  # Same content
+        CachePoint(),
+        'Second question'
+    ])
+    print(f'First: {result1.output}')
+    print(f'Second: {result2.output}')
+```
+
+Access cache usage statistics via `result.usage()`:
+
+```python {test="skip"}
+from pydantic_ai import Agent
+from pydantic_ai.models.anthropic import AnthropicModelSettings
+
+agent = Agent(
+    'anthropic:claude-sonnet-4-5',
+    system_prompt='Instructions...',
+    model_settings=AnthropicModelSettings(
+        anthropic_cache_instructions=True
+    ),
+)
+
+async def main():
+    result = await agent.run('Your question')
+    usage = result.usage()
+    print(f'Cache write tokens: {usage.cache_write_tokens}')
+    print(f'Cache 
read tokens: {usage.cache_read_tokens}') +``` diff --git a/pydantic_ai_slim/pydantic_ai/__init__.py b/pydantic_ai_slim/pydantic_ai/__init__.py index 1054cef630..ec0137f856 100644 --- a/pydantic_ai_slim/pydantic_ai/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/__init__.py @@ -42,6 +42,7 @@ BinaryImage, BuiltinToolCallPart, BuiltinToolReturnPart, + CachePoint, DocumentFormat, DocumentMediaType, DocumentUrl, @@ -141,6 +142,7 @@ 'BinaryContent', 'BuiltinToolCallPart', 'BuiltinToolReturnPart', + 'CachePoint', 'DocumentFormat', 'DocumentMediaType', 'DocumentUrl', diff --git a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py index f2e3d5eef8..988430d12a 100644 --- a/pydantic_ai_slim/pydantic_ai/messages.py +++ b/pydantic_ai_slim/pydantic_ai/messages.py @@ -612,8 +612,24 @@ def __init__( raise ValueError('`BinaryImage` must be have a media type that starts with "image/"') # pragma: no cover +@dataclass +class CachePoint: + """A cache point marker for prompt caching. + + Can be inserted into UserPromptPart.content to mark cache boundaries. + Models that don't support caching will filter these out. + + Supported by: + + - Anthropic + """ + + kind: Literal['cache-point'] = 'cache-point' + """Type identifier, this is available on all parts as a discriminator.""" + + MultiModalContent = ImageUrl | AudioUrl | DocumentUrl | VideoUrl | BinaryContent -UserContent: TypeAlias = str | MultiModalContent +UserContent: TypeAlias = str | MultiModalContent | CachePoint @dataclass(repr=False) @@ -730,6 +746,9 @@ def otel_message_parts(self, settings: InstrumentationSettings) -> list[_otel_me if settings.include_content and settings.include_binary_content: converted_part['content'] = base64.b64encode(part.data).decode() parts.append(converted_part) + elif isinstance(part, CachePoint): + # CachePoint is a marker, not actual content - skip it for otel + pass else: parts.append({'type': part.kind}) # pragma: no cover return parts diff --git a/pydantic_ai_slim/pydantic_ai/models/anthropic.py b/pydantic_ai_slim/pydantic_ai/models/anthropic.py index 31351345b0..10e20c5073 100644 --- a/pydantic_ai_slim/pydantic_ai/models/anthropic.py +++ b/pydantic_ai_slim/pydantic_ai/models/anthropic.py @@ -19,6 +19,7 @@ BinaryContent, BuiltinToolCallPart, BuiltinToolReturnPart, + CachePoint, DocumentUrl, FilePart, FinishReason, @@ -58,6 +59,7 @@ from anthropic.types.beta import ( BetaBase64PDFBlockParam, BetaBase64PDFSourceParam, + BetaCacheControlEphemeralParam, BetaCitationsDelta, BetaCodeExecutionTool20250522Param, BetaCodeExecutionToolResultBlock, @@ -148,6 +150,22 @@ class AnthropicModelSettings(ModelSettings, total=False): See [the Anthropic docs](https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking) for more information. """ + anthropic_cache_tools: bool + """Whether to add cache_control to the last tool definition. + + When enabled, the last tool in the tools array will have cache_control set, + allowing Anthropic to cache tool definitions and reduce costs. + See https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching for more information. + """ + + anthropic_cache_instructions: bool + """Whether to add cache_control to the last system prompt block. + + When enabled, the last system prompt will have cache_control set, + allowing Anthropic to cache system instructions and reduce costs. + See https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching for more information. 
+ """ + @dataclass(init=False) class AnthropicModel(Model): @@ -289,7 +307,7 @@ async def _messages_create( model_request_parameters: ModelRequestParameters, ) -> BetaMessage | AsyncStream[BetaRawMessageStreamEvent]: # standalone function to make it easier to override - tools = self._get_tools(model_request_parameters) + tools = self._get_tools(model_request_parameters, model_settings) tools, mcp_servers, beta_features = self._add_builtin_tools(tools, model_request_parameters) tool_choice: BetaToolChoiceParam | None @@ -305,7 +323,7 @@ async def _messages_create( if (allow_parallel_tool_calls := model_settings.get('parallel_tool_calls')) is not None: tool_choice['disable_parallel_tool_use'] = not allow_parallel_tool_calls - system_prompt, anthropic_messages = await self._map_message(messages, model_request_parameters) + system_prompt, anthropic_messages = await self._map_message(messages, model_request_parameters, model_settings) try: extra_headers = model_settings.get('extra_headers', {}) @@ -411,8 +429,19 @@ async def _process_streamed_response( _provider_url=self._provider.base_url, ) - def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[BetaToolUnionParam]: - return [self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()] + def _get_tools( + self, model_request_parameters: ModelRequestParameters, model_settings: AnthropicModelSettings + ) -> list[BetaToolUnionParam]: + tools: list[BetaToolUnionParam] = [ + self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values() + ] + + # Add cache_control to the last tool if enabled + if tools and model_settings.get('anthropic_cache_tools'): + last_tool = cast(dict[str, Any], tools[-1]) + last_tool['cache_control'] = BetaCacheControlEphemeralParam(type='ephemeral') + + return tools def _add_builtin_tools( self, tools: list[BetaToolUnionParam], model_request_parameters: ModelRequestParameters @@ -464,8 +493,11 @@ def _add_builtin_tools( return tools, mcp_servers, beta_features async def _map_message( # noqa: C901 - self, messages: list[ModelMessage], model_request_parameters: ModelRequestParameters - ) -> tuple[str, list[BetaMessageParam]]: + self, + messages: list[ModelMessage], + model_request_parameters: ModelRequestParameters, + model_settings: AnthropicModelSettings, + ) -> tuple[str | list[BetaTextBlockParam], list[BetaMessageParam]]: """Just maps a `pydantic_ai.Message` to a `anthropic.types.MessageParam`.""" system_prompt_parts: list[str] = [] anthropic_messages: list[BetaMessageParam] = [] @@ -477,7 +509,10 @@ async def _map_message( # noqa: C901 system_prompt_parts.append(request_part.content) elif isinstance(request_part, UserPromptPart): async for content in self._map_user_prompt(request_part): - user_content_params.append(content) + if isinstance(content, CachePoint): + self._add_cache_control_to_last_param(user_content_params) + else: + user_content_params.append(content) elif isinstance(request_part, ToolReturnPart): tool_result_block_param = BetaToolResultBlockParam( tool_use_id=_guard_tool_call_id(t=request_part), @@ -637,12 +672,43 @@ async def _map_message( # noqa: C901 if instructions := self._get_instructions(messages, model_request_parameters): system_prompt_parts.insert(0, instructions) system_prompt = '\n\n'.join(system_prompt_parts) + + # If anthropic_cache_instructions is enabled, return system prompt as a list with cache_control + if system_prompt and model_settings.get('anthropic_cache_instructions'): + system_prompt_blocks = [ + 
BetaTextBlockParam( + type='text', text=system_prompt, cache_control=BetaCacheControlEphemeralParam(type='ephemeral') + ) + ] + return system_prompt_blocks, anthropic_messages + return system_prompt, anthropic_messages + @staticmethod + def _add_cache_control_to_last_param(params: list[BetaContentBlockParam]) -> None: + """Add cache control to the last content block param. + + See https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching for more information. + """ + if not params: + raise UserError( + 'CachePoint cannot be the first content in a user message - there must be previous content to attach the CachePoint to.' + ) + + # Only certain types support cache_control + # See https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching#what-can-be-cached + cacheable_types = {'text', 'tool_use', 'server_tool_use', 'image', 'tool_result'} + last_param = cast(dict[str, Any], params[-1]) # Cast to dict for mutation + if last_param['type'] not in cacheable_types: + raise UserError(f'Cache control not supported for param type: {last_param["type"]}') + + # Add cache_control to the last param + last_param['cache_control'] = BetaCacheControlEphemeralParam(type='ephemeral') + @staticmethod async def _map_user_prompt( part: UserPromptPart, - ) -> AsyncGenerator[BetaContentBlockParam]: + ) -> AsyncGenerator[BetaContentBlockParam | CachePoint]: if isinstance(part.content, str): if part.content: # Only yield non-empty text yield BetaTextBlockParam(text=part.content, type='text') @@ -651,6 +717,8 @@ async def _map_user_prompt( if isinstance(item, str): if item: # Only yield non-empty text yield BetaTextBlockParam(text=item, type='text') + elif isinstance(item, CachePoint): + yield item elif isinstance(item, BinaryContent): if item.is_image: yield BetaImageBlockParam( @@ -717,6 +785,8 @@ def _map_usage( key: value for key, value in response_usage.model_dump().items() if isinstance(value, int) } + # Note: genai-prices already extracts cache_creation_input_tokens and cache_read_input_tokens + # from the Anthropic response and maps them to cache_write_tokens and cache_read_tokens return usage.RequestUsage.extract( dict(model=model, usage=details), provider=provider, diff --git a/pydantic_ai_slim/pydantic_ai/models/bedrock.py b/pydantic_ai_slim/pydantic_ai/models/bedrock.py index df8e8746b7..f420b9907e 100644 --- a/pydantic_ai_slim/pydantic_ai/models/bedrock.py +++ b/pydantic_ai_slim/pydantic_ai/models/bedrock.py @@ -19,6 +19,7 @@ BinaryContent, BuiltinToolCallPart, BuiltinToolReturnPart, + CachePoint, DocumentUrl, FinishReason, ImageUrl, @@ -672,6 +673,9 @@ async def _map_user_prompt(part: UserPromptPart, document_count: Iterator[int]) content.append({'video': video}) elif isinstance(item, AudioUrl): # pragma: no cover raise NotImplementedError('Audio is not supported yet.') + elif isinstance(item, CachePoint): + # Bedrock doesn't support prompt caching via CachePoint in this implementation + pass else: assert_never(item) return [{'role': 'user', 'content': content}] diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index afc2bd7156..10c227d0db 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -21,6 +21,7 @@ BinaryContent, BuiltinToolCallPart, BuiltinToolReturnPart, + CachePoint, FilePart, FileUrl, ModelMessage, @@ -391,6 +392,9 @@ async def _map_user_prompt(self, part: UserPromptPart) -> list[_GeminiPartUnion] else: # pragma: lax no cover file_data = 
_GeminiFileDataPart(file_data={'file_uri': item.url, 'mime_type': item.media_type}) content.append(file_data) + elif isinstance(item, CachePoint): + # Gemini doesn't support prompt caching via CachePoint + pass else: assert_never(item) # pragma: lax no cover return content diff --git a/pydantic_ai_slim/pydantic_ai/models/google.py b/pydantic_ai_slim/pydantic_ai/models/google.py index 071f65fa66..3a5cfe9258 100644 --- a/pydantic_ai_slim/pydantic_ai/models/google.py +++ b/pydantic_ai_slim/pydantic_ai/models/google.py @@ -19,6 +19,7 @@ BinaryContent, BuiltinToolCallPart, BuiltinToolReturnPart, + CachePoint, FilePart, FileUrl, FinishReason, @@ -602,6 +603,9 @@ async def _map_user_prompt(self, part: UserPromptPart) -> list[PartDict]: else: file_data_dict: FileDataDict = {'file_uri': item.url, 'mime_type': item.media_type} content.append({'file_data': file_data_dict}) # pragma: lax no cover + elif isinstance(item, CachePoint): + # Google Gemini doesn't support prompt caching via CachePoint + pass else: assert_never(item) return content diff --git a/pydantic_ai_slim/pydantic_ai/models/huggingface.py b/pydantic_ai_slim/pydantic_ai/models/huggingface.py index 7ca3199473..94598aee7e 100644 --- a/pydantic_ai_slim/pydantic_ai/models/huggingface.py +++ b/pydantic_ai_slim/pydantic_ai/models/huggingface.py @@ -18,6 +18,7 @@ BinaryContent, BuiltinToolCallPart, BuiltinToolReturnPart, + CachePoint, DocumentUrl, FilePart, FinishReason, @@ -447,6 +448,9 @@ async def _map_user_prompt(part: UserPromptPart) -> ChatCompletionInputMessage: raise NotImplementedError('DocumentUrl is not supported for Hugging Face') elif isinstance(item, VideoUrl): raise NotImplementedError('VideoUrl is not supported for Hugging Face') + elif isinstance(item, CachePoint): + # Hugging Face doesn't support prompt caching via CachePoint + pass else: assert_never(item) return ChatCompletionInputMessage(role='user', content=content) # type: ignore diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index ed1e711823..ae4ca51122 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -26,6 +26,7 @@ BinaryImage, BuiltinToolCallPart, BuiltinToolReturnPart, + CachePoint, DocumentUrl, FilePart, FinishReason, @@ -860,6 +861,9 @@ async def _map_user_prompt(self, part: UserPromptPart) -> chat.ChatCompletionUse ) elif isinstance(item, VideoUrl): # pragma: no cover raise NotImplementedError('VideoUrl is not supported for OpenAI') + elif isinstance(item, CachePoint): + # OpenAI doesn't support prompt caching via CachePoint, so we filter it out + pass else: assert_never(item) return chat.ChatCompletionUserMessageParam(role='user', content=content) @@ -1598,7 +1602,7 @@ def _map_json_schema(self, o: OutputObjectDefinition) -> responses.ResponseForma return response_format_param @staticmethod - async def _map_user_prompt(part: UserPromptPart) -> responses.EasyInputMessageParam: + async def _map_user_prompt(part: UserPromptPart) -> responses.EasyInputMessageParam: # noqa: C901 content: str | list[responses.ResponseInputContentParam] if isinstance(part.content, str): content = part.content @@ -1673,6 +1677,9 @@ async def _map_user_prompt(part: UserPromptPart) -> responses.EasyInputMessagePa ) elif isinstance(item, VideoUrl): # pragma: no cover raise NotImplementedError('VideoUrl is not supported for OpenAI.') + elif isinstance(item, CachePoint): + # OpenAI doesn't support prompt caching via CachePoint, so we filter it out + pass else: 
assert_never(item) return responses.EasyInputMessageParam(role='user', content=content) diff --git a/tests/models/test_anthropic.py b/tests/models/test_anthropic.py index b7b66404e0..397a8e0979 100644 --- a/tests/models/test_anthropic.py +++ b/tests/models/test_anthropic.py @@ -20,6 +20,7 @@ BinaryContent, BuiltinToolCallPart, BuiltinToolReturnPart, + CachePoint, DocumentUrl, FinalResultEvent, ImageUrl, @@ -292,6 +293,262 @@ async def test_async_request_prompt_caching(allow_model_requests: None): assert last_message.cost().total_price == snapshot(Decimal('0.00002688')) +async def test_cache_point_adds_cache_control(allow_model_requests: None): + """Test that CachePoint correctly adds cache_control to content blocks.""" + c = completion_message( + [BetaTextBlock(text='response', type='text')], + usage=BetaUsage(input_tokens=3, output_tokens=5), + ) + mock_client = MockAnthropic.create_mock(c) + m = AnthropicModel('claude-haiku-4-5', provider=AnthropicProvider(anthropic_client=mock_client)) + agent = Agent(m) + + # Test with CachePoint after text content + await agent.run(['Some context to cache', CachePoint(), 'Now the question']) + + # Verify cache_control was added to the right content block + completion_kwargs = get_mock_chat_completion_kwargs(mock_client)[0] + messages = completion_kwargs['messages'] + assert messages == snapshot( + [ + { + 'role': 'user', + 'content': [ + {'text': 'Some context to cache', 'type': 'text', 'cache_control': {'type': 'ephemeral'}}, + {'text': 'Now the question', 'type': 'text'}, + ], + } + ] + ) + + +async def test_cache_point_multiple_markers(allow_model_requests: None): + """Test multiple CachePoint markers in a single prompt.""" + c = completion_message( + [BetaTextBlock(text='response', type='text')], + usage=BetaUsage(input_tokens=3, output_tokens=5), + ) + mock_client = MockAnthropic.create_mock(c) + m = AnthropicModel('claude-haiku-4-5', provider=AnthropicProvider(anthropic_client=mock_client)) + agent = Agent(m) + + await agent.run(['First chunk', CachePoint(), 'Second chunk', CachePoint(), 'Question']) + + completion_kwargs = get_mock_chat_completion_kwargs(mock_client)[0] + content = completion_kwargs['messages'][0]['content'] + + assert content == snapshot( + [ + {'text': 'First chunk', 'type': 'text', 'cache_control': {'type': 'ephemeral'}}, + {'text': 'Second chunk', 'type': 'text', 'cache_control': {'type': 'ephemeral'}}, + {'text': 'Question', 'type': 'text'}, + ] + ) + + +async def test_cache_point_as_first_content_raises_error(allow_model_requests: None): + """Test that CachePoint as first content raises UserError.""" + c = completion_message( + [BetaTextBlock(text='response', type='text')], + usage=BetaUsage(input_tokens=3, output_tokens=5), + ) + mock_client = MockAnthropic.create_mock(c) + m = AnthropicModel('claude-haiku-4-5', provider=AnthropicProvider(anthropic_client=mock_client)) + agent = Agent(m) + + with pytest.raises( + UserError, + match='CachePoint cannot be the first content in a user message - there must be previous content to attach the CachePoint to.', + ): + await agent.run([CachePoint(), 'This should fail']) + + +async def test_cache_point_with_image_content(allow_model_requests: None): + """Test CachePoint works with image content.""" + c = completion_message( + [BetaTextBlock(text='response', type='text')], + usage=BetaUsage(input_tokens=3, output_tokens=5), + ) + mock_client = MockAnthropic.create_mock(c) + m = AnthropicModel('claude-haiku-4-5', provider=AnthropicProvider(anthropic_client=mock_client)) + agent = 
Agent(m) + + await agent.run( + [ + ImageUrl('https://example.com/image.jpg'), + CachePoint(), + 'What is in this image?', + ] + ) + + completion_kwargs = get_mock_chat_completion_kwargs(mock_client)[0] + content = completion_kwargs['messages'][0]['content'] + + assert content == snapshot( + [ + { + 'source': {'type': 'url', 'url': 'https://example.com/image.jpg'}, + 'type': 'image', + 'cache_control': {'type': 'ephemeral'}, + }, + {'text': 'What is in this image?', 'type': 'text'}, + ] + ) + + +async def test_cache_point_in_otel_message_parts(allow_model_requests: None): + """Test that CachePoint is handled correctly in otel message parts conversion.""" + from pydantic_ai.agent import InstrumentationSettings + from pydantic_ai.messages import UserPromptPart + + # Create a UserPromptPart with CachePoint + part = UserPromptPart(content=['text before', CachePoint(), 'text after']) + + # Convert to otel message parts + settings = InstrumentationSettings(include_content=True) + otel_parts = part.otel_message_parts(settings) + + # Should have 2 text parts, CachePoint is skipped + assert otel_parts == snapshot( + [{'type': 'text', 'content': 'text before'}, {'type': 'text', 'content': 'text after'}] + ) + + +def test_cache_control_unsupported_param_type(): + """Test that cache control raises error for unsupported param types.""" + + from pydantic_ai.exceptions import UserError + from pydantic_ai.models.anthropic import AnthropicModel + + # Create a list with an unsupported param type (document) + # We'll use a mock document block param + params: list[dict[str, Any]] = [{'type': 'document', 'source': {'data': 'test'}}] + + with pytest.raises(UserError, match='Cache control not supported for param type: document'): + AnthropicModel._add_cache_control_to_last_param(params) # type: ignore[arg-type] # Testing internal method + + +async def test_anthropic_cache_tools(allow_model_requests: None): + """Test that anthropic_cache_tools adds cache_control to last tool.""" + c = completion_message( + [BetaTextBlock(text='Tool result', type='text')], + usage=BetaUsage(input_tokens=10, output_tokens=5), + ) + mock_client = MockAnthropic.create_mock(c) + m = AnthropicModel('claude-haiku-4-5', provider=AnthropicProvider(anthropic_client=mock_client)) + agent = Agent( + m, + system_prompt='Test system prompt', + model_settings=AnthropicModelSettings(anthropic_cache_tools=True), + ) + + @agent.tool_plain + def tool_one() -> str: + return 'one' + + @agent.tool_plain + def tool_two() -> str: + return 'two' + + await agent.run('test prompt') + + # Verify cache_control was added to the last tool + completion_kwargs = get_mock_chat_completion_kwargs(mock_client)[0] + tools = completion_kwargs['tools'] + assert tools == snapshot( + [ + { + 'name': 'tool_one', + 'description': '', + 'input_schema': {'additionalProperties': False, 'properties': {}, 'type': 'object'}, + }, + { + 'name': 'tool_two', + 'description': '', + 'input_schema': {'additionalProperties': False, 'properties': {}, 'type': 'object'}, + 'cache_control': {'type': 'ephemeral'}, + }, + ] + ) + + +async def test_anthropic_cache_instructions(allow_model_requests: None): + """Test that anthropic_cache_instructions adds cache_control to system prompt.""" + c = completion_message( + [BetaTextBlock(text='Response', type='text')], + usage=BetaUsage(input_tokens=10, output_tokens=5), + ) + mock_client = MockAnthropic.create_mock(c) + m = AnthropicModel('claude-haiku-4-5', provider=AnthropicProvider(anthropic_client=mock_client)) + agent = Agent( + m, + 
system_prompt='This is a test system prompt with instructions.', + model_settings=AnthropicModelSettings(anthropic_cache_instructions=True), + ) + + await agent.run('test prompt') + + # Verify system is a list with cache_control on last block + completion_kwargs = get_mock_chat_completion_kwargs(mock_client)[0] + system = completion_kwargs['system'] + assert system == snapshot( + [ + { + 'type': 'text', + 'text': 'This is a test system prompt with instructions.', + 'cache_control': {'type': 'ephemeral'}, + } + ] + ) + + +async def test_anthropic_cache_tools_and_instructions(allow_model_requests: None): + """Test that both cache settings work together.""" + c = completion_message( + [BetaTextBlock(text='Response', type='text')], + usage=BetaUsage(input_tokens=10, output_tokens=5), + ) + mock_client = MockAnthropic.create_mock(c) + m = AnthropicModel('claude-haiku-4-5', provider=AnthropicProvider(anthropic_client=mock_client)) + agent = Agent( + m, + system_prompt='System instructions to cache.', + model_settings=AnthropicModelSettings( + anthropic_cache_tools=True, + anthropic_cache_instructions=True, + ), + ) + + @agent.tool_plain + def my_tool(value: str) -> str: + return f'Result: {value}' + + await agent.run('test prompt') + + # Verify both have cache_control + completion_kwargs = get_mock_chat_completion_kwargs(mock_client)[0] + tools = completion_kwargs['tools'] + system = completion_kwargs['system'] + assert tools == snapshot( + [ + { + 'name': 'my_tool', + 'description': '', + 'input_schema': { + 'additionalProperties': False, + 'properties': {'value': {'type': 'string'}}, + 'required': ['value'], + 'type': 'object', + }, + 'cache_control': {'type': 'ephemeral'}, + } + ] + ) + assert system == snapshot( + [{'type': 'text', 'text': 'System instructions to cache.', 'cache_control': {'type': 'ephemeral'}}] + ) + + async def test_async_request_text_response(allow_model_requests: None): c = completion_message( [BetaTextBlock(text='world', type='text')], @@ -4695,14 +4952,14 @@ async def test_anthropic_empty_content_filtering(env: TestEnv): messages_empty_string: list[ModelMessage] = [ ModelRequest(parts=[UserPromptPart(content='')], kind='request'), ] - _, anthropic_messages = await model._map_message(messages_empty_string, ModelRequestParameters()) # type: ignore[attr-defined] + _, anthropic_messages = await model._map_message(messages_empty_string, ModelRequestParameters(), {}) # type: ignore[attr-defined] assert anthropic_messages == snapshot([]) # Empty content should be filtered out # Test _map_message with list containing empty strings in user prompt messages_mixed_content: list[ModelMessage] = [ ModelRequest(parts=[UserPromptPart(content=['', 'Hello', '', 'World'])], kind='request'), ] - _, anthropic_messages = await model._map_message(messages_mixed_content, ModelRequestParameters()) # type: ignore[attr-defined] + _, anthropic_messages = await model._map_message(messages_mixed_content, ModelRequestParameters(), {}) # type: ignore[attr-defined] assert anthropic_messages == snapshot( [{'role': 'user', 'content': [{'text': 'Hello', 'type': 'text'}, {'text': 'World', 'type': 'text'}]}] ) @@ -4713,7 +4970,7 @@ async def test_anthropic_empty_content_filtering(env: TestEnv): ModelResponse(parts=[TextPart(content='')], kind='response'), # Empty response ModelRequest(parts=[UserPromptPart(content='Hello')], kind='request'), ] - _, anthropic_messages = await model._map_message(messages, ModelRequestParameters()) # type: ignore[attr-defined] + _, anthropic_messages = await 
model._map_message(messages, ModelRequestParameters(), {})  # type: ignore[attr-defined]
 
     # The empty assistant message should be filtered out
     assert anthropic_messages == snapshot([{'role': 'user', 'content': [{'text': 'Hello', 'type': 'text'}]}])
@@ -4721,7 +4978,7 @@ async def test_anthropic_empty_content_filtering(env: TestEnv):
     messages_resp: list[ModelMessage] = [
         ModelResponse(parts=[TextPart(content=''), TextPart(content='')], kind='response'),
     ]
-    _, anthropic_messages = await model._map_message(messages_resp, ModelRequestParameters())  # type: ignore[attr-defined]
+    _, anthropic_messages = await model._map_message(messages_resp, ModelRequestParameters(), {})  # type: ignore[attr-defined]
 
     assert len(anthropic_messages) == 0  # No messages should be added
 
diff --git a/tests/models/test_bedrock.py b/tests/models/test_bedrock.py
index ddb60ebf4e..cce18a9227 100644
--- a/tests/models/test_bedrock.py
+++ b/tests/models/test_bedrock.py
@@ -9,6 +9,7 @@
 
 from pydantic_ai import (
     BinaryContent,
+    CachePoint,
     DocumentUrl,
     FinalResultEvent,
     FunctionToolCallEvent,
@@ -1511,3 +1512,14 @@ async def test_bedrock_streaming_error(allow_model_requests: None, bedrock_provi
     assert exc_info.value.status_code == 400
     assert exc_info.value.model_name == model_id
     assert exc_info.value.body.get('Error', {}).get('Message') == 'The provided model identifier is invalid.' 
# type: ignore[union-attr] + + +async def test_cache_point_filtering(): + """Test that CachePoint is filtered out in Bedrock message mapping.""" + from itertools import count + + # Test the static method directly + messages = await BedrockConverseModel._map_user_prompt(UserPromptPart(content=['text', CachePoint()]), count()) # pyright: ignore[reportPrivateUsage] + # CachePoint should be filtered out, message should still be valid + assert len(messages) == 1 + assert messages[0]['role'] == 'user' diff --git a/tests/models/test_google.py b/tests/models/test_google.py index 82332f38ef..5bda43c7b3 100644 --- a/tests/models/test_google.py +++ b/tests/models/test_google.py @@ -3201,3 +3201,19 @@ def _generate_response_with_texts(response_id: str, texts: list[str]) -> Generat ], } ) + + +async def test_cache_point_filtering(): + """Test that CachePoint is filtered out in Google internal method.""" + from pydantic_ai import CachePoint + + # Create a minimal GoogleModel instance to test _map_user_prompt + model = GoogleModel('gemini-1.5-flash', provider=GoogleProvider(api_key='test-key')) + + # Test that CachePoint in a list is handled (triggers line 606) + content = await model._map_user_prompt(UserPromptPart(content=['text before', CachePoint(), 'text after'])) # pyright: ignore[reportPrivateUsage] + + # CachePoint should be filtered out, only text content should remain + assert len(content) == 2 + assert content[0] == {'text': 'text before'} + assert content[1] == {'text': 'text after'} diff --git a/tests/models/test_huggingface.py b/tests/models/test_huggingface.py index 3bbb0d3e7b..16f7d01a1b 100644 --- a/tests/models/test_huggingface.py +++ b/tests/models/test_huggingface.py @@ -8,7 +8,25 @@ from typing import Any, Literal, cast from unittest.mock import Mock +import aiohttp import pytest +from huggingface_hub import ( + AsyncInferenceClient, + ChatCompletionInputMessage, + ChatCompletionOutput, + ChatCompletionOutputComplete, + ChatCompletionOutputFunctionDefinition, + ChatCompletionOutputMessage, + ChatCompletionOutputToolCall, + ChatCompletionOutputUsage, + ChatCompletionStreamOutput, + ChatCompletionStreamOutputChoice, + ChatCompletionStreamOutputDelta, + ChatCompletionStreamOutputDeltaToolCall, + ChatCompletionStreamOutputFunction, + ChatCompletionStreamOutputUsage, +) +from huggingface_hub.errors import HfHubHTTPError from inline_snapshot import snapshot from typing_extensions import TypedDict @@ -16,6 +34,7 @@ Agent, AudioUrl, BinaryContent, + CachePoint, DocumentUrl, ImageUrl, ModelRequest, @@ -31,6 +50,8 @@ VideoUrl, ) from pydantic_ai.exceptions import ModelHTTPError +from pydantic_ai.models.huggingface import HuggingFaceModel +from pydantic_ai.providers.huggingface import HuggingFaceProvider from pydantic_ai.result import RunUsage from pydantic_ai.run import AgentRunResult, AgentRunResultEvent from pydantic_ai.settings import ModelSettings @@ -41,30 +62,10 @@ from .mock_async_stream import MockAsyncStream with try_import() as imports_successful: - import aiohttp - from huggingface_hub import ( - AsyncInferenceClient, - ChatCompletionInputMessage, - ChatCompletionOutput, - ChatCompletionOutputComplete, - ChatCompletionOutputFunctionDefinition, - ChatCompletionOutputMessage, - ChatCompletionOutputToolCall, - ChatCompletionOutputUsage, - ChatCompletionStreamOutput, - ChatCompletionStreamOutputChoice, - ChatCompletionStreamOutputDelta, - ChatCompletionStreamOutputDeltaToolCall, - ChatCompletionStreamOutputFunction, - ChatCompletionStreamOutputUsage, - ) - from huggingface_hub.errors 
import HfHubHTTPError - - from pydantic_ai.models.huggingface import HuggingFaceModel - from pydantic_ai.providers.huggingface import HuggingFaceProvider - - MockChatCompletion = ChatCompletionOutput | Exception - MockStreamEvent = ChatCompletionStreamOutput | Exception + pass + +MockChatCompletion = ChatCompletionOutput | Exception +MockStreamEvent = ChatCompletionStreamOutput | Exception pytestmark = [ pytest.mark.skipif(not imports_successful(), reason='huggingface_hub not installed'), @@ -1016,3 +1017,13 @@ async def test_hf_model_thinking_part_iter(allow_model_requests: None, huggingfa ), ] ) + + +async def test_cache_point_filtering(): + """Test that CachePoint is filtered out in HuggingFace message mapping.""" + # Test the static method directly + msg = await HuggingFaceModel._map_user_prompt(UserPromptPart(content=['text', CachePoint()])) # pyright: ignore[reportPrivateUsage] + + # CachePoint should be filtered out + assert msg['role'] == 'user' + assert len(msg['content']) == 1 # pyright: ignore[reportUnknownArgumentType] diff --git a/tests/models/test_instrumented.py b/tests/models/test_instrumented.py index 8e498188ef..b6a52e0c25 100644 --- a/tests/models/test_instrumented.py +++ b/tests/models/test_instrumented.py @@ -17,6 +17,7 @@ BinaryContent, BuiltinToolCallPart, BuiltinToolReturnPart, + CachePoint, DocumentUrl, FilePart, FinalResultEvent, @@ -1615,3 +1616,78 @@ def test_message_with_builtin_tool_calls(): } ] ) + + +def test_cache_point_in_user_prompt(): + """Test that CachePoint is correctly skipped in OpenTelemetry conversion. + + CachePoint is a marker for prompt caching and should not be included in the + OpenTelemetry message parts output. + """ + messages: list[ModelMessage] = [ + ModelRequest(parts=[UserPromptPart(content=['text before', CachePoint(), 'text after'])]), + ] + settings = InstrumentationSettings() + + # Test otel_message_parts - CachePoint should be skipped + assert settings.messages_to_otel_messages(messages) == snapshot( + [ + { + 'role': 'user', + 'parts': [ + {'type': 'text', 'content': 'text before'}, + {'type': 'text', 'content': 'text after'}, + ], + } + ] + ) + + # Test with multiple CachePoints + messages_multi: list[ModelMessage] = [ + ModelRequest( + parts=[ + UserPromptPart(content=['first', CachePoint(), 'second', CachePoint(), 'third']), + ] + ), + ] + assert settings.messages_to_otel_messages(messages_multi) == snapshot( + [ + { + 'role': 'user', + 'parts': [ + {'type': 'text', 'content': 'first'}, + {'type': 'text', 'content': 'second'}, + {'type': 'text', 'content': 'third'}, + ], + } + ] + ) + + # Test with CachePoint mixed with other content types + messages_mixed: list[ModelMessage] = [ + ModelRequest( + parts=[ + UserPromptPart( + content=[ + 'context', + CachePoint(), + ImageUrl('https://example.com/image.jpg'), + CachePoint(), + 'question', + ] + ), + ] + ), + ] + assert settings.messages_to_otel_messages(messages_mixed) == snapshot( + [ + { + 'role': 'user', + 'parts': [ + {'type': 'text', 'content': 'context'}, + {'type': 'image-url', 'url': 'https://example.com/image.jpg'}, + {'type': 'text', 'content': 'question'}, + ], + } + ] + ) diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py index 0181437cff..e68c64abe3 100644 --- a/tests/models/test_openai.py +++ b/tests/models/test_openai.py @@ -17,6 +17,7 @@ Agent, AudioUrl, BinaryContent, + CachePoint, DocumentUrl, ImageUrl, ModelHTTPError, @@ -3054,3 +3055,33 @@ def test_deprecated_openai_model(openai_api_key: str): provider = 
OpenAIProvider(api_key=openai_api_key) OpenAIModel('gpt-4o', provider=provider) # type: ignore[reportDeprecated] + + +async def test_cache_point_filtering(allow_model_requests: None): + """Test that CachePoint is filtered out in OpenAI Chat Completions requests.""" + c = completion_message(ChatCompletionMessage(content='response', role='assistant')) + mock_client = MockOpenAI.create_mock(c) + m = OpenAIChatModel('gpt-4o', provider=OpenAIProvider(openai_client=mock_client)) + + # Test the instance method directly to trigger line 864 + msg = await m._map_user_prompt(UserPromptPart(content=['text before', CachePoint(), 'text after'])) # pyright: ignore[reportPrivateUsage] + + # CachePoint should be filtered out, only text content should remain + assert msg['role'] == 'user' + assert len(msg['content']) == 2 # type: ignore[reportUnknownArgumentType] + assert msg['content'][0]['text'] == 'text before' # type: ignore[reportUnknownArgumentType] + assert msg['content'][1]['text'] == 'text after' # type: ignore[reportUnknownArgumentType] + + +async def test_cache_point_filtering_responses_model(): + """Test that CachePoint is filtered out in OpenAI Responses API requests.""" + # Test the static method directly to trigger line 1680 + msg = await OpenAIResponsesModel._map_user_prompt( # pyright: ignore[reportPrivateUsage] + UserPromptPart(content=['text before', CachePoint(), 'text after']) + ) + + # CachePoint should be filtered out, only text content should remain + assert msg['role'] == 'user' + assert len(msg['content']) == 2 + assert msg['content'][0]['text'] == 'text before' # type: ignore[reportUnknownArgumentType] + assert msg['content'][1]['text'] == 'text after' # type: ignore[reportUnknownArgumentType]
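
Beyond the single-run examples in the updated docs above, the new pieces also compose with ordinary message history. The sketch below is illustrative only and is not part of the diff: it assumes the `CachePoint` marker and `AnthropicModelSettings` flags introduced here, and the `load_docs()` helper, model name, and prompt text are placeholders.

```python
from pydantic_ai import Agent, CachePoint
from pydantic_ai.models.anthropic import AnthropicModelSettings


def load_docs() -> str:
    # Placeholder for a large, stable block of reference text worth caching.
    return 'Reference documentation...'


agent = Agent(
    'anthropic:claude-sonnet-4-5',
    system_prompt='Answer questions about the reference documentation.',
    model_settings=AnthropicModelSettings(
        anthropic_cache_instructions=True,  # cache the system prompt
        anthropic_cache_tools=True,  # cache tool definitions, if any are registered
    ),
)


async def main():
    # First turn: everything before the CachePoint is written to the cache.
    first = await agent.run([load_docs(), CachePoint(), 'Summarise the key points.'])

    # Follow-up turn: reusing the message history keeps the cached prefix identical,
    # so this request should read from the cache instead of re-processing the context.
    second = await agent.run('What about error handling?', message_history=first.all_messages())
    print(second.output)
    print(second.usage().cache_read_tokens)
```

Whether the second call actually reads from the cache depends on Anthropic's cache TTL and on the mapped request prefix staying byte-identical, so treat this as a starting point rather than a guarantee.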