diff --git a/libs/partners/anthropic/langchain_anthropic/middleware/prompt_caching.py b/libs/partners/anthropic/langchain_anthropic/middleware/prompt_caching.py
index 77b4dc9d85c4a..6f0aba9173a95 100644
--- a/libs/partners/anthropic/langchain_anthropic/middleware/prompt_caching.py
+++ b/libs/partners/anthropic/langchain_anthropic/middleware/prompt_caching.py
@@ -102,6 +102,17 @@ def _apply_cache_control(self, request: ModelRequest) -> None:
         """
         request.model_settings["cache_control"] = {"type": self.type, "ttl": self.ttl}
 
+    def _remove_cache_control(self, request: ModelRequest) -> None:
+        """Remove cache control settings from the request.
+
+        This reverses the changes made by _apply_cache_control().
+        Useful when a fallback to a non-Anthropic model occurs at runtime.
+
+        Args:
+            request: The model request to clean up.
+        """
+        request.model_settings.pop("cache_control", None)
+
     def wrap_model_call(
         self,
         request: ModelRequest,
@@ -117,10 +128,27 @@ def wrap_model_call(
             The model response from the handler.
         """
         if not self._should_apply_caching(request):
+            # Model is not Anthropic - ensure cache_control is removed
+            # This handles the case where ModelFallbackMiddleware switched
+            # to a non-Anthropic model after we already applied cache_control
+            self._remove_cache_control(request)
             return handler(request)
 
+        # Optimistically apply caching (works for Anthropic models)
         self._apply_cache_control(request)
-        return handler(request)
+
+        try:
+            return handler(request)
+        except TypeError as e:
+            # Check if error is specifically about cache_control parameter
+            # This can occur when ModelFallbackMiddleware switches to a
+            # non-Anthropic model at runtime (e.g., Anthropic → OpenAI)
+            if "cache_control" in str(e):
+                # Remove Anthropic-specific settings and retry without caching
+                self._remove_cache_control(request)
+                return handler(request)
+            # Different TypeError - re-raise
+            raise
 
     async def awrap_model_call(
         self,
@@ -137,7 +165,24 @@ async def awrap_model_call(
             The model response from the handler.
""" if not self._should_apply_caching(request): + # Model is not Anthropic - ensure cache_control is removed + # This handles the case where ModelFallbackMiddleware switched + # to a non-Anthropic model after we already applied cache_control + self._remove_cache_control(request) return await handler(request) + # Optimistically apply caching (works for Anthropic models) self._apply_cache_control(request) - return await handler(request) + + try: + return await handler(request) + except TypeError as e: + # Check if error is specifically about cache_control parameter + # This can occur when ModelFallbackMiddleware switches to a + # non-Anthropic model at runtime (e.g., Anthropic → OpenAI) + if "cache_control" in str(e): + # Remove Anthropic-specific settings and retry without caching + self._remove_cache_control(request) + return await handler(request) + # Different TypeError - re-raise + raise diff --git a/libs/partners/anthropic/tests/integration_tests/test_caching_middleware.py b/libs/partners/anthropic/tests/integration_tests/test_caching_middleware.py new file mode 100644 index 0000000000000..5a8165bade878 --- /dev/null +++ b/libs/partners/anthropic/tests/integration_tests/test_caching_middleware.py @@ -0,0 +1,68 @@ +from typing import Any + +import pytest +from langchain.agents import create_agent +from langchain.agents.middleware import ModelFallbackMiddleware +from langchain_core.messages import HumanMessage +from langgraph.checkpoint.memory import MemorySaver + +from langchain_anthropic import ChatAnthropic +from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware + + +class TestMiddlewareWithFallback: + """Test middleware behavior with actual model fallback scenarios.""" + + def test_fallback_to_openai_without_error(self) -> None: + """Test that fallback to OpenAI works without cache_control errors.""" + # This test requires API keys but can be mocked for CI + # Skip if API keys not available + pytest.importorskip("langchain_openai") + + from langchain_openai import ChatOpenAI # type: ignore[import-not-found] + + # Create agent with fallback middleware + # Use invalid Anthropic model to force immediate fallback + agent: Any = create_agent( + model=ChatAnthropic(model="invalid-model-name", api_key="invalid"), + checkpointer=MemorySaver(), + middleware=[ + ModelFallbackMiddleware( + ChatOpenAI(model="gpt-4o-mini") # Will fallback to this + ), + AnthropicPromptCachingMiddleware(ttl="5m"), + ], + ) + + # This should not raise TypeError about cache_control + # It will fail due to invalid API keys, but that's expected + # The key is that it doesn't fail with cache_control error + config: Any = {"configurable": {"thread_id": "test"}} + + with pytest.raises(Exception) as exc_info: + agent.invoke({"messages": [HumanMessage(content="Hello")]}, config) + # Should not be a cache_control TypeError + assert "cache_control" not in str(exc_info.value).lower() + + @pytest.mark.asyncio + async def test_async_fallback_to_openai(self) -> None: + """Test async version with fallback to OpenAI.""" + pytest.importorskip("langchain_openai") + + from langchain_openai import ChatOpenAI # type: ignore[import-not-found] + + agent: Any = create_agent( + model=ChatAnthropic(model="invalid-model", api_key="invalid"), + checkpointer=MemorySaver(), + middleware=[ + ModelFallbackMiddleware(ChatOpenAI(model="gpt-4o-mini")), + AnthropicPromptCachingMiddleware(ttl="5m"), + ], + ) + + config: Any = {"configurable": {"thread_id": "test"}} + + with pytest.raises(Exception) as exc_info: + await 
agent.ainvoke({"messages": [HumanMessage(content="Hello")]}, config) + # Should not be a cache_control TypeError + assert "cache_control" not in str(exc_info.value).lower() diff --git a/libs/partners/anthropic/tests/unit_tests/middleware/test_prompt_caching.py b/libs/partners/anthropic/tests/unit_tests/middleware/test_prompt_caching.py index e06ebebaf5369..537d360220dba 100644 --- a/libs/partners/anthropic/tests/unit_tests/middleware/test_prompt_caching.py +++ b/libs/partners/anthropic/tests/unit_tests/middleware/test_prompt_caching.py @@ -309,3 +309,201 @@ async def mock_handler(req: ModelRequest) -> ModelResponse: assert fake_request.model_settings == { "cache_control": {"type": "ephemeral", "ttl": "5m"} } + + +def test_remove_cache_control_when_present() -> None: + """Test that _remove_cache_control removes cache_control from model_settings.""" + middleware = AnthropicPromptCachingMiddleware() + + fake_request = ModelRequest( + model=MagicMock(spec=ChatAnthropic), + messages=[HumanMessage("Hello")], + system_prompt=None, + tool_choice=None, + tools=[], + response_format=None, + state={"messages": [HumanMessage("Hello")]}, + runtime=cast(Runtime, object()), + model_settings={"cache_control": {"type": "ephemeral", "ttl": "5m"}}, + ) + + assert "cache_control" in fake_request.model_settings + middleware._remove_cache_control(fake_request) + assert "cache_control" not in fake_request.model_settings + + +def test_remove_cache_control_safe_when_absent() -> None: + """Test that _remove_cache_control is safe when cache_control is not present.""" + middleware = AnthropicPromptCachingMiddleware() + + fake_request = ModelRequest( + model=FakeToolCallingModel(), + messages=[HumanMessage("Hello")], + system_prompt=None, + tool_choice=None, + tools=[], + response_format=None, + state={"messages": [HumanMessage("Hello")]}, + runtime=cast(Runtime, object()), + model_settings={}, # Empty, no cache_control + ) + + # Should not raise an error + middleware._remove_cache_control(fake_request) + assert "cache_control" not in fake_request.model_settings + + +def test_wrap_model_call_cleans_up_for_non_anthropic_model() -> None: + """Test cache_control removal on fallback to non-Anthropic model.""" + middleware = AnthropicPromptCachingMiddleware() + + # Simulate post-fallback state: non-Anthropic model with cache_control present + fake_request = ModelRequest( + model=FakeToolCallingModel(), # Non-Anthropic model + messages=[HumanMessage("Hello")], + system_prompt=None, + tool_choice=None, + tools=[], + response_format=None, + state={"messages": [HumanMessage("Hello")]}, + runtime=cast(Runtime, object()), + model_settings={ + "cache_control": {"type": "ephemeral", "ttl": "5m"} + }, # Present from earlier + ) + + def mock_handler(req: ModelRequest) -> ModelResponse: + # Verify cache_control was removed before handler is called + assert "cache_control" not in req.model_settings + return ModelResponse(result=[AIMessage(content="response")]) + + response = middleware.wrap_model_call(fake_request, mock_handler) + assert isinstance(response, ModelResponse) + assert "cache_control" not in fake_request.model_settings + + +def test_wrap_model_call_recovers_from_cache_control_type_error() -> None: + """Test that middleware recovers from cache_control TypeError and retries.""" + middleware = AnthropicPromptCachingMiddleware() + + fake_request = ModelRequest( + model=MagicMock(spec=ChatAnthropic), + messages=[HumanMessage("Hello")], + system_prompt=None, + tool_choice=None, + tools=[], + response_format=None, + 
state={"messages": [HumanMessage("Hello")]}, + runtime=cast(Runtime, object()), + model_settings={}, + ) + + call_count = 0 + mock_response = ModelResponse(result=[AIMessage(content="response")]) + + def mock_handler(req: ModelRequest) -> ModelResponse: + nonlocal call_count + call_count += 1 + if call_count == 1: + # First call: simulate cache_control error + msg = ( + "Completions.create() got an unexpected keyword argument " + "'cache_control'" + ) + raise TypeError(msg) + # Second call: succeed + return mock_response + + response = middleware.wrap_model_call(fake_request, mock_handler) + + # Verify handler was called twice (original + retry) + assert call_count == 2 + # Verify cache_control was removed on retry + assert "cache_control" not in fake_request.model_settings + assert response == mock_response + + +def test_wrap_model_call_reraises_non_cache_control_type_error() -> None: + """Test that non-cache_control TypeErrors are re-raised.""" + middleware = AnthropicPromptCachingMiddleware() + + fake_request = ModelRequest( + model=MagicMock(spec=ChatAnthropic), + messages=[HumanMessage("Hello")], + system_prompt=None, + tool_choice=None, + tools=[], + response_format=None, + state={"messages": [HumanMessage("Hello")]}, + runtime=cast(Runtime, object()), + model_settings={}, + ) + + def mock_handler(req: ModelRequest) -> ModelResponse: + msg = "Some other type error" + raise TypeError(msg) + + # Should re-raise the error + with pytest.raises(TypeError, match="Some other type error"): + middleware.wrap_model_call(fake_request, mock_handler) + + +async def test_awrap_model_call_cleans_up_for_non_anthropic_model() -> None: + """Test that async version also cleans up cache_control.""" + middleware = AnthropicPromptCachingMiddleware() + + fake_request = ModelRequest( + model=FakeToolCallingModel(), + messages=[HumanMessage("Hello")], + system_prompt=None, + tool_choice=None, + tools=[], + response_format=None, + state={"messages": [HumanMessage("Hello")]}, + runtime=cast(Runtime, object()), + model_settings={"cache_control": {"type": "ephemeral", "ttl": "5m"}}, + ) + + async def mock_handler(req: ModelRequest) -> ModelResponse: + # Verify cache_control was removed before handler is called + assert "cache_control" not in req.model_settings + return ModelResponse(result=[AIMessage(content="response")]) + + response = await middleware.awrap_model_call(fake_request, mock_handler) + assert isinstance(response, ModelResponse) + assert "cache_control" not in fake_request.model_settings + + +async def test_awrap_model_call_recovers_from_type_error() -> None: + """Test that async version recovers from cache_control TypeError.""" + middleware = AnthropicPromptCachingMiddleware() + + fake_request = ModelRequest( + model=MagicMock(spec=ChatAnthropic), + messages=[HumanMessage("Hello")], + system_prompt=None, + tool_choice=None, + tools=[], + response_format=None, + state={"messages": [HumanMessage("Hello")]}, + runtime=cast(Runtime, object()), + model_settings={}, + ) + + call_count = 0 + mock_response = ModelResponse(result=[AIMessage(content="response")]) + + async def mock_handler(req: ModelRequest) -> ModelResponse: + nonlocal call_count + call_count += 1 + if call_count == 1: + msg = "got an unexpected keyword argument 'cache_control'" + raise TypeError(msg) + return mock_response + + response = await middleware.awrap_model_call(fake_request, mock_handler) + + # Verify retry happened + assert call_count == 2 + assert "cache_control" not in fake_request.model_settings + assert response == 
mock_response