diff --git a/libs/langchain_v1/langchain/agents/middleware/__init__.py b/libs/langchain_v1/langchain/agents/middleware/__init__.py index 44389008aa41b..cffdf98283aee 100644 --- a/libs/langchain_v1/langchain/agents/middleware/__init__.py +++ b/libs/langchain_v1/langchain/agents/middleware/__init__.py @@ -1,5 +1,6 @@ """Entrypoint to using [middleware](https://docs.langchain.com/oss/python/langchain/middleware) plugins with [Agents](https://docs.langchain.com/oss/python/langchain/agents).""" # noqa: E501 +from .conditional_model_settings import ConditionalModelSettingsMiddleware from .context_editing import ( ClearToolUsesEdit, ContextEditingMiddleware, @@ -45,6 +46,7 @@ "AgentState", "ClearToolUsesEdit", "CodexSandboxExecutionPolicy", + "ConditionalModelSettingsMiddleware", "ContextEditingMiddleware", "DockerExecutionPolicy", "FilesystemFileSearchMiddleware", diff --git a/libs/langchain_v1/langchain/agents/middleware/conditional_model_settings.py b/libs/langchain_v1/langchain/agents/middleware/conditional_model_settings.py new file mode 100644 index 0000000000000..ddef84eb1bf14 --- /dev/null +++ b/libs/langchain_v1/langchain/agents/middleware/conditional_model_settings.py @@ -0,0 +1,204 @@ +"""Conditional model settings middleware for dynamic configuration.""" + +from __future__ import annotations + +from inspect import iscoroutinefunction +from typing import TYPE_CHECKING, Any + +from langchain.agents.middleware.types import ( + AgentMiddleware, + ModelCallResult, + ModelRequest, + ModelResponse, +) + +if TYPE_CHECKING: + from collections.abc import Awaitable, Callable + + +class _ConditionBuilder: + """Builder for fluent API: middleware.when(condition).use(settings).""" + + def __init__( + self, + middleware: ConditionalModelSettingsMiddleware, + condition: Callable[[ModelRequest], bool | Awaitable[bool]], + ) -> None: + self._middleware = middleware + self._condition = condition + + def use( + self, + settings: dict[str, Any] | Callable[[ModelRequest], dict[str, Any]], + ) -> ConditionalModelSettingsMiddleware: + """Apply settings when condition is met. + + Args: + settings: Dict of model settings or callable returning settings dict. + + Returns: + Parent middleware instance for chaining. + """ + self._middleware._conditions.append((self._condition, settings)) + return self._middleware + + +class ConditionalModelSettingsMiddleware(AgentMiddleware): + """Dynamically configure model bind settings based on runtime conditions. + + This middleware allows you to apply different `model_settings` (passed to + `model.bind_tools()` or `model.bind()`) based on conditions evaluated at runtime. + All matching conditions have their settings applied cumulatively (later settings + override earlier ones for the same keys). + + Note: `model_settings` are parameters passed to the model's bind method, such as + `parallel_tool_calls`, `strict`, etc. For model inference parameters like + `temperature`, `max_tokens`, use the model configuration directly in `create_agent`. + + Examples: + !!! example "Simple usage - disable parallel tool calls" + + ```python + from langchain.agents import create_agent + from langchain.agents.middleware import ConditionalModelSettingsMiddleware + + # Disable parallel tool calls for long conversations + middleware = ConditionalModelSettingsMiddleware() + middleware.when(lambda req: len(req.messages) > 10).use({"parallel_tool_calls": False}) + + agent = create_agent( + model="openai:gpt-4o", + tools=[some_tool], + middleware=[middleware], + ) + ``` + + !!! example "Using function for condition" + + ```python + def needs_sequential_execution(req: ModelRequest) -> bool: + # Check if state indicates sequential execution needed + return req.state.get("execution_mode") == "sequential" + + + middleware = ConditionalModelSettingsMiddleware() + middleware.when(needs_sequential_execution).use({"parallel_tool_calls": False}) + + agent = create_agent( + model="openai:gpt-4o", + tools=[tool1, tool2], + middleware=[middleware], + ) + ``` + + !!! example "Multiple conditions with cumulative application" + + ```python + middleware = ConditionalModelSettingsMiddleware() + + # Base setting: all long conversations + middleware.when(lambda req: len(req.messages) > 10).use({"parallel_tool_calls": False}) + + # Additional setting: emergency mode (applied on top if both match) + middleware.when(lambda req: req.state.get("emergency")).use({"strict": True}) + + agent = create_agent( + model="openai:gpt-4o", + tools=[tool1, tool2, tool3], + middleware=[middleware], + ) + + # If messages > 10 AND emergency=True: + # Result: {"parallel_tool_calls": False, "strict": True} + # Both conditions apply cumulatively + ``` + + !!! example "Dynamic settings with callable" + + ```python + def compute_settings(req: ModelRequest) -> dict[str, Any]: + # Enable parallel calls only for short conversations + if len(req.messages) < 5: + return {"parallel_tool_calls": True} + return {"parallel_tool_calls": False} + + + middleware = ConditionalModelSettingsMiddleware() + middleware.when(lambda req: True).use(compute_settings) + + agent = create_agent( + model="openai:gpt-4o", + tools=[tool1, tool2], + middleware=[middleware], + ) + ``` + """ + + def __init__(self) -> None: + """Initialize middleware. Settings are merged with existing model_settings.""" + super().__init__() + self._conditions: list[ + tuple[ + Callable[[ModelRequest], bool | Awaitable[bool]], + dict[str, Any] | Callable[[ModelRequest], dict[str, Any]], + ] + ] = [] + + def when( + self, + condition: Callable[[ModelRequest], bool | Awaitable[bool]], + ) -> _ConditionBuilder: + """Register condition for applying settings. + + Args: + condition: Callable taking ModelRequest and returning bool (sync or async). + + Returns: + Builder object with .use() method. + """ + return _ConditionBuilder(self, condition) + + def _apply_settings( + self, + request: ModelRequest, + settings: dict[str, Any] | Callable[[ModelRequest], dict[str, Any]], + ) -> None: + """Apply settings to request.""" + resolved_settings = settings(request) if callable(settings) else settings + request.model_settings = {**request.model_settings, **resolved_settings} + + def wrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], ModelResponse], + ) -> ModelCallResult: + """Apply conditional settings before calling model.""" + for condition, settings in self._conditions: + if iscoroutinefunction(condition): + msg = ( + "Async condition function detected in sync execution path. " + "Use sync condition or invoke agent with `astream()`/`ainvoke()`." + ) + raise RuntimeError(msg) + + if condition(request): + self._apply_settings(request, settings) + + return handler(request) + + async def awrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], Awaitable[ModelResponse]], + ) -> ModelCallResult: + """Apply conditional settings before calling model (async).""" + for condition, settings in self._conditions: + if iscoroutinefunction(condition): + result = await condition(request) + else: + result = condition(request) + + if result: + self._apply_settings(request, settings) + + return await handler(request) diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_conditional_model_setting.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_conditional_model_setting.py new file mode 100644 index 0000000000000..4651ea01f92e1 --- /dev/null +++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_conditional_model_setting.py @@ -0,0 +1,307 @@ +"""Tests for ConditionalModelSettingsMiddleware.""" + +import pytest +from langchain_core.messages import HumanMessage + +from langchain.agents.middleware.conditional_model_settings import ( + ConditionalModelSettingsMiddleware, +) +from langchain.agents.middleware.types import ModelRequest +from langchain.agents.factory import create_agent + +from tests.unit_tests.agents.model import FakeToolCallingModel + + +class TestConditionalModelSettingsBasic: + """Test basic functionality.""" + + def test_cumulative_application(self): + """Test that multiple matching conditions apply cumulatively.""" + middleware = ConditionalModelSettingsMiddleware() + + request = ModelRequest( + model=FakeToolCallingModel(), + system_prompt=None, + messages=[HumanMessage(content=f"Message {i}") for i in range(15)], + tool_choice=None, + tools=[], + response_format=None, + state={"messages": [], "emergency": True}, + runtime=None, # type: ignore + model_settings={}, + ) + + # First condition + middleware._apply_settings(request, {"parallel_tool_calls": False}) + assert request.model_settings == {"parallel_tool_calls": False} + + # Second condition (cumulative) + middleware._apply_settings(request, {"strict": True}) + assert request.model_settings == {"parallel_tool_calls": False, "strict": True} + + def test_later_settings_override_earlier(self): + """Test that later settings override earlier ones for the same key.""" + middleware = ConditionalModelSettingsMiddleware() + + request = ModelRequest( + model=FakeToolCallingModel(), + system_prompt=None, + messages=[HumanMessage(content="test")], + tool_choice=None, + tools=[], + response_format=None, + state={"messages": []}, + runtime=None, # type: ignore + model_settings={}, + ) + + middleware._apply_settings(request, {"parallel_tool_calls": False, "strict": False}) + middleware._apply_settings(request, {"strict": True}) + assert request.model_settings == {"parallel_tool_calls": False, "strict": True} + + def test_callable_settings(self): + """Test that callable settings are resolved correctly.""" + + def compute_settings(req: ModelRequest) -> dict: + return {"parallel_tool_calls": len(req.messages) < 5} + + request = ModelRequest( + model=FakeToolCallingModel(), + system_prompt=None, + messages=[HumanMessage(content=f"Message {i}") for i in range(3)], + tool_choice=None, + tools=[], + response_format=None, + state={"messages": []}, + runtime=None, # type: ignore + model_settings={}, + ) + + middleware = ConditionalModelSettingsMiddleware() + middleware._apply_settings(request, compute_settings) + assert request.model_settings == {"parallel_tool_calls": True} + + def test_builder_pattern_chaining(self): + """Test that builder pattern returns middleware for chaining.""" + middleware = ConditionalModelSettingsMiddleware() + result = middleware.when(lambda req: True).use({"setting1": "value1"}) + assert result is middleware + assert len(middleware._conditions) == 1 + + +class TestWrapModelCall: + """Test wrap_model_call behavior.""" + + def test_multiple_conditions_cumulative(self): + """Test that multiple conditions apply cumulatively.""" + middleware = ConditionalModelSettingsMiddleware() + middleware.when(lambda req: len(req.messages) > 5).use({"parallel_tool_calls": False}) + middleware.when(lambda req: req.state.get("emergency")).use({"strict": True}) + + request = ModelRequest( + model=FakeToolCallingModel(), + system_prompt=None, + messages=[HumanMessage(content=f"Message {i}") for i in range(10)], + tool_choice=None, + tools=[], + response_format=None, + state={"messages": [], "emergency": True}, + runtime=None, # type: ignore + model_settings={}, + ) + + middleware.wrap_model_call(request, lambda req: None) + assert request.model_settings == {"parallel_tool_calls": False, "strict": True} + + def test_async_condition_raises_error(self): + """Test that async condition in sync mode raises RuntimeError.""" + middleware = ConditionalModelSettingsMiddleware() + + async def async_condition(req): + return True + + middleware.when(async_condition).use({"parallel_tool_calls": False}) + + request = ModelRequest( + model=FakeToolCallingModel(), + system_prompt=None, + messages=[HumanMessage(content="test")], + tool_choice=None, + tools=[], + response_format=None, + state={"messages": []}, + runtime=None, # type: ignore + model_settings={}, + ) + + with pytest.raises(RuntimeError, match="Async condition function detected"): + middleware.wrap_model_call(request, lambda req: None) + + +class TestAwrapModelCall: + """Test awrap_model_call behavior.""" + + @pytest.mark.asyncio + async def test_async_condition(self): + """Test that async conditions work.""" + middleware = ConditionalModelSettingsMiddleware() + + async def async_condition(req): + return len(req.messages) > 5 + + middleware.when(async_condition).use({"parallel_tool_calls": False}) + + request = ModelRequest( + model=FakeToolCallingModel(), + system_prompt=None, + messages=[HumanMessage(content=f"Message {i}") for i in range(10)], + tool_choice=None, + tools=[], + response_format=None, + state={"messages": []}, + runtime=None, # type: ignore + model_settings={}, + ) + + async def async_handler(req): + return None + + await middleware.awrap_model_call(request, async_handler) + assert request.model_settings == {"parallel_tool_calls": False} + + @pytest.mark.asyncio + async def test_sync_condition_in_async(self): + """Test that sync conditions work in async mode.""" + middleware = ConditionalModelSettingsMiddleware() + middleware.when(lambda req: len(req.messages) > 5).use({"parallel_tool_calls": False}) + + request = ModelRequest( + model=FakeToolCallingModel(), + system_prompt=None, + messages=[HumanMessage(content=f"Message {i}") for i in range(10)], + tool_choice=None, + tools=[], + response_format=None, + state={"messages": []}, + runtime=None, # type: ignore + model_settings={}, + ) + + async def async_handler(req): + return None + + await middleware.awrap_model_call(request, async_handler) + assert request.model_settings == {"parallel_tool_calls": False} + + +class TestCreateAgentIntegration: + """Test middleware behavior in create_agent.""" + + def test_single_condition_lambda(self): + """Test single condition with lambda in agent.""" + captured_settings = {} + + class CapturingMiddleware(ConditionalModelSettingsMiddleware): + def wrap_model_call(self, request, handler): + result = super().wrap_model_call(request, handler) + captured_settings.update(request.model_settings) + return result + + middleware = CapturingMiddleware() + middleware.when(lambda req: len(req.messages) > 2).use({"parallel_tool_calls": False}) + + model = FakeToolCallingModel() + agent = create_agent(model=model, middleware=[middleware]) + + # Short conversation - condition should not match + agent.invoke({"messages": [HumanMessage(content="Hello")]}) + assert "parallel_tool_calls" not in captured_settings + + # Long conversation - condition should match + captured_settings.clear() + agent.invoke({"messages": [HumanMessage(content=f"Msg {i}") for i in range(5)]}) + assert captured_settings.get("parallel_tool_calls") is False + + def test_single_condition_function(self): + """Test single condition with function in agent.""" + captured_settings = {} + + class CapturingMiddleware(ConditionalModelSettingsMiddleware): + def wrap_model_call(self, request, handler): + result = super().wrap_model_call(request, handler) + captured_settings.update(request.model_settings) + return result + + def is_long_conversation(req: ModelRequest) -> bool: + return len(req.messages) > 3 + + middleware = CapturingMiddleware() + middleware.when(is_long_conversation).use({"parallel_tool_calls": False}) + + model = FakeToolCallingModel() + agent = create_agent(model=model, middleware=[middleware]) + + # Short conversation - condition should not match + agent.invoke({"messages": [HumanMessage(content=f"Msg {i}") for i in range(2)]}) + assert "parallel_tool_calls" not in captured_settings + + # Long conversation - condition should match + captured_settings.clear() + agent.invoke({"messages": [HumanMessage(content=f"Msg {i}") for i in range(5)]}) + assert captured_settings.get("parallel_tool_calls") is False + + def test_multiple_conditions_cumulative(self): + """Test multiple conditions apply cumulatively in agent.""" + captured_settings = {} + + class CapturingMiddleware(ConditionalModelSettingsMiddleware): + def wrap_model_call(self, request, handler): + result = super().wrap_model_call(request, handler) + captured_settings.update(request.model_settings) + return result + + middleware = CapturingMiddleware() + middleware.when(lambda req: len(req.messages) > 2).use({"parallel_tool_calls": False}) + middleware.when(lambda req: len(req.messages) > 4).use({"strict": True}) + + model = FakeToolCallingModel() + agent = create_agent(model=model, middleware=[middleware]) + + # Only first condition matches + agent.invoke({"messages": [HumanMessage(content=f"Msg {i}") for i in range(3)]}) + assert captured_settings.get("parallel_tool_calls") is False + assert "strict" not in captured_settings + + # Both conditions match - should apply both settings + captured_settings.clear() + agent.invoke({"messages": [HumanMessage(content=f"Msg {i}") for i in range(6)]}) + assert captured_settings.get("parallel_tool_calls") is False + assert captured_settings.get("strict") is True + + def test_callable_settings_function(self): + """Test callable settings (use with function) in agent.""" + captured_settings = {} + + class CapturingMiddleware(ConditionalModelSettingsMiddleware): + def wrap_model_call(self, request, handler): + result = super().wrap_model_call(request, handler) + captured_settings.update(request.model_settings) + return result + + def compute_settings(req: ModelRequest) -> dict: + return {"parallel_tool_calls": len(req.messages) < 3} + + middleware = CapturingMiddleware() + middleware.when(lambda req: True).use(compute_settings) + + model = FakeToolCallingModel() + agent = create_agent(model=model, middleware=[middleware]) + + # Short conversation - should enable parallel + agent.invoke({"messages": [HumanMessage(content="Test")]}) + assert captured_settings.get("parallel_tool_calls") is True + + # Long conversation - should disable parallel + captured_settings.clear() + agent.invoke({"messages": [HumanMessage(content=f"Msg {i}") for i in range(5)]}) + assert captured_settings.get("parallel_tool_calls") is False