@@ -102,6 +102,17 @@ def _apply_cache_control(self, request: ModelRequest) -> None:
"""
request.model_settings["cache_control"] = {"type": self.type, "ttl": self.ttl}

def _remove_cache_control(self, request: ModelRequest) -> None:
"""Remove cache control settings from the request.

This reverses the changes made by _apply_cache_control().
Useful when a fallback to a non-Anthropic model occurs at runtime.

Args:
request: The model request to clean up.
"""
request.model_settings.pop("cache_control", None)

def wrap_model_call(
self,
request: ModelRequest,
@@ -117,10 +128,27 @@ def wrap_model_call(
The model response from the handler.
"""
if not self._should_apply_caching(request):
# Model is not Anthropic - ensure cache_control is removed
# This handles the case where ModelFallbackMiddleware switched
# to a non-Anthropic model after we already applied cache_control
self._remove_cache_control(request)
return handler(request)

# Optimistically apply caching (works for Anthropic models)
self._apply_cache_control(request)

try:
return handler(request)
except TypeError as e:
# Check if error is specifically about cache_control parameter
# This can occur when ModelFallbackMiddleware switches to a
# non-Anthropic model at runtime (e.g., Anthropic → OpenAI)
if "cache_control" in str(e):
# Remove Anthropic-specific settings and retry without caching
self._remove_cache_control(request)
return handler(request)
# Different TypeError - re-raise
raise

async def awrap_model_call(
self,
@@ -137,7 +165,24 @@ async def awrap_model_call(
The model response from the handler.
"""
if not self._should_apply_caching(request):
# Model is not Anthropic - ensure cache_control is removed
# This handles the case where ModelFallbackMiddleware switched
# to a non-Anthropic model after we already applied cache_control
self._remove_cache_control(request)
return await handler(request)

# Optimistically apply caching (works for Anthropic models)
self._apply_cache_control(request)

try:
return await handler(request)
except TypeError as e:
# Check if error is specifically about cache_control parameter
# This can occur when ModelFallbackMiddleware switches to a
# non-Anthropic model at runtime (e.g., Anthropic → OpenAI)
if "cache_control" in str(e):
# Remove Anthropic-specific settings and retry without caching
self._remove_cache_control(request)
return await handler(request)
# Different TypeError - re-raise
raise
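
For context, a minimal sketch of the composition this change targets. Assumptions not part of this diff: langchain_openai is installed, real API keys are configured, and the model names are illustrative. The caching middleware is stacked with ModelFallbackMiddleware, so a runtime fallback from Anthropic to OpenAI must not carry cache_control along; the integration test added below exercises the same path with deliberately invalid credentials.

from langchain.agents import create_agent
from langchain.agents.middleware import ModelFallbackMiddleware
from langchain_core.messages import HumanMessage
from langchain_openai import ChatOpenAI

from langchain_anthropic import ChatAnthropic
from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware

agent = create_agent(
    # Primary model is Anthropic, so the caching middleware applies cache_control.
    model=ChatAnthropic(model="claude-sonnet-4-5"),  # illustrative model name
    middleware=[
        # If the Anthropic call fails at runtime, this swaps in an OpenAI model.
        ModelFallbackMiddleware(ChatOpenAI(model="gpt-4o-mini")),
        AnthropicPromptCachingMiddleware(ttl="5m"),
    ],
)

# Previously, a mid-request fallback could raise:
#   TypeError: Completions.create() got an unexpected keyword argument 'cache_control'
# With the retry logic above, cache_control is removed and the call is retried.
agent.invoke({"messages": [HumanMessage(content="Hello")]})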
@@ -0,0 +1,68 @@
from typing import Any

import pytest
from langchain.agents import create_agent
from langchain.agents.middleware import ModelFallbackMiddleware
from langchain_core.messages import HumanMessage
from langgraph.checkpoint.memory import MemorySaver

from langchain_anthropic import ChatAnthropic
from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware


class TestMiddlewareWithFallback:
"""Test middleware behavior with actual model fallback scenarios."""

def test_fallback_to_openai_without_error(self) -> None:
"""Test that fallback to OpenAI works without cache_control errors."""
# Skip when langchain_openai is not installed. The agent call below is
# expected to fail (invalid credentials); the assertion only checks that
# the failure is not a cache_control TypeError.
pytest.importorskip("langchain_openai")

from langchain_openai import ChatOpenAI # type: ignore[import-not-found]

# Create agent with fallback middleware
# Use invalid Anthropic model to force immediate fallback
agent: Any = create_agent(
model=ChatAnthropic(model="invalid-model-name", api_key="invalid"),
checkpointer=MemorySaver(),
middleware=[
ModelFallbackMiddleware(
ChatOpenAI(model="gpt-4o-mini") # Will fall back to this
),
AnthropicPromptCachingMiddleware(ttl="5m"),
],
)

# This should not raise TypeError about cache_control
# It will fail due to invalid API keys, but that's expected
# The key is that it doesn't fail with cache_control error
config: Any = {"configurable": {"thread_id": "test"}}

with pytest.raises(Exception) as exc_info:
agent.invoke({"messages": [HumanMessage(content="Hello")]}, config)
# Should not be a cache_control TypeError
assert "cache_control" not in str(exc_info.value).lower()

@pytest.mark.asyncio
async def test_async_fallback_to_openai(self) -> None:
"""Test async version with fallback to OpenAI."""
pytest.importorskip("langchain_openai")

from langchain_openai import ChatOpenAI # type: ignore[import-not-found]

agent: Any = create_agent(
model=ChatAnthropic(model="invalid-model", api_key="invalid"),
checkpointer=MemorySaver(),
middleware=[
ModelFallbackMiddleware(ChatOpenAI(model="gpt-4o-mini")),
AnthropicPromptCachingMiddleware(ttl="5m"),
],
)

config: Any = {"configurable": {"thread_id": "test"}}

with pytest.raises(Exception) as exc_info:
await agent.ainvoke({"messages": [HumanMessage(content="Hello")]}, config)
# Should not be a cache_control TypeError
assert "cache_control" not in str(exc_info.value).lower()
@@ -309,3 +309,201 @@ async def mock_handler(req: ModelRequest) -> ModelResponse:
assert fake_request.model_settings == {
"cache_control": {"type": "ephemeral", "ttl": "5m"}
}


def test_remove_cache_control_when_present() -> None:
"""Test that _remove_cache_control removes cache_control from model_settings."""
middleware = AnthropicPromptCachingMiddleware()

fake_request = ModelRequest(
model=MagicMock(spec=ChatAnthropic),
messages=[HumanMessage("Hello")],
system_prompt=None,
tool_choice=None,
tools=[],
response_format=None,
state={"messages": [HumanMessage("Hello")]},
runtime=cast(Runtime, object()),
model_settings={"cache_control": {"type": "ephemeral", "ttl": "5m"}},
)

assert "cache_control" in fake_request.model_settings
middleware._remove_cache_control(fake_request)
assert "cache_control" not in fake_request.model_settings


def test_remove_cache_control_safe_when_absent() -> None:
"""Test that _remove_cache_control is safe when cache_control is not present."""
middleware = AnthropicPromptCachingMiddleware()

fake_request = ModelRequest(
model=FakeToolCallingModel(),
messages=[HumanMessage("Hello")],
system_prompt=None,
tool_choice=None,
tools=[],
response_format=None,
state={"messages": [HumanMessage("Hello")]},
runtime=cast(Runtime, object()),
model_settings={}, # Empty, no cache_control
)

# Should not raise an error
middleware._remove_cache_control(fake_request)
assert "cache_control" not in fake_request.model_settings


def test_wrap_model_call_cleans_up_for_non_anthropic_model() -> None:
"""Test cache_control removal on fallback to non-Anthropic model."""
middleware = AnthropicPromptCachingMiddleware()

# Simulate post-fallback state: non-Anthropic model with cache_control present
fake_request = ModelRequest(
model=FakeToolCallingModel(), # Non-Anthropic model
messages=[HumanMessage("Hello")],
system_prompt=None,
tool_choice=None,
tools=[],
response_format=None,
state={"messages": [HumanMessage("Hello")]},
runtime=cast(Runtime, object()),
model_settings={
"cache_control": {"type": "ephemeral", "ttl": "5m"}
}, # Present from earlier
)

def mock_handler(req: ModelRequest) -> ModelResponse:
# Verify cache_control was removed before handler is called
assert "cache_control" not in req.model_settings
return ModelResponse(result=[AIMessage(content="response")])

response = middleware.wrap_model_call(fake_request, mock_handler)
assert isinstance(response, ModelResponse)
assert "cache_control" not in fake_request.model_settings


def test_wrap_model_call_recovers_from_cache_control_type_error() -> None:
"""Test that middleware recovers from cache_control TypeError and retries."""
middleware = AnthropicPromptCachingMiddleware()

fake_request = ModelRequest(
model=MagicMock(spec=ChatAnthropic),
messages=[HumanMessage("Hello")],
system_prompt=None,
tool_choice=None,
tools=[],
response_format=None,
state={"messages": [HumanMessage("Hello")]},
runtime=cast(Runtime, object()),
model_settings={},
)

call_count = 0
mock_response = ModelResponse(result=[AIMessage(content="response")])

def mock_handler(req: ModelRequest) -> ModelResponse:
nonlocal call_count
call_count += 1
if call_count == 1:
# First call: simulate cache_control error
msg = (
"Completions.create() got an unexpected keyword argument "
"'cache_control'"
)
raise TypeError(msg)
# Second call: succeed
return mock_response

response = middleware.wrap_model_call(fake_request, mock_handler)

# Verify handler was called twice (original + retry)
assert call_count == 2
# Verify cache_control was removed on retry
assert "cache_control" not in fake_request.model_settings
assert response == mock_response


def test_wrap_model_call_reraises_non_cache_control_type_error() -> None:
"""Test that non-cache_control TypeErrors are re-raised."""
middleware = AnthropicPromptCachingMiddleware()

fake_request = ModelRequest(
model=MagicMock(spec=ChatAnthropic),
messages=[HumanMessage("Hello")],
system_prompt=None,
tool_choice=None,
tools=[],
response_format=None,
state={"messages": [HumanMessage("Hello")]},
runtime=cast(Runtime, object()),
model_settings={},
)

def mock_handler(req: ModelRequest) -> ModelResponse:
msg = "Some other type error"
raise TypeError(msg)

# Should re-raise the error
with pytest.raises(TypeError, match="Some other type error"):
middleware.wrap_model_call(fake_request, mock_handler)


async def test_awrap_model_call_cleans_up_for_non_anthropic_model() -> None:
"""Test that async version also cleans up cache_control."""
middleware = AnthropicPromptCachingMiddleware()

fake_request = ModelRequest(
model=FakeToolCallingModel(),
messages=[HumanMessage("Hello")],
system_prompt=None,
tool_choice=None,
tools=[],
response_format=None,
state={"messages": [HumanMessage("Hello")]},
runtime=cast(Runtime, object()),
model_settings={"cache_control": {"type": "ephemeral", "ttl": "5m"}},
)

async def mock_handler(req: ModelRequest) -> ModelResponse:
# Verify cache_control was removed before handler is called
assert "cache_control" not in req.model_settings
return ModelResponse(result=[AIMessage(content="response")])

response = await middleware.awrap_model_call(fake_request, mock_handler)
assert isinstance(response, ModelResponse)
assert "cache_control" not in fake_request.model_settings


async def test_awrap_model_call_recovers_from_type_error() -> None:
"""Test that async version recovers from cache_control TypeError."""
middleware = AnthropicPromptCachingMiddleware()

fake_request = ModelRequest(
model=MagicMock(spec=ChatAnthropic),
messages=[HumanMessage("Hello")],
system_prompt=None,
tool_choice=None,
tools=[],
response_format=None,
state={"messages": [HumanMessage("Hello")]},
runtime=cast(Runtime, object()),
model_settings={},
)

call_count = 0
mock_response = ModelResponse(result=[AIMessage(content="response")])

async def mock_handler(req: ModelRequest) -> ModelResponse:
nonlocal call_count
call_count += 1
if call_count == 1:
msg = "got an unexpected keyword argument 'cache_control'"
raise TypeError(msg)
return mock_response

response = await middleware.awrap_model_call(fake_request, mock_handler)

# Verify retry happened
assert call_count == 2
assert "cache_control" not in fake_request.model_settings
assert response == mock_response