diff --git a/libs/langchain_v1/langchain/agents/factory.py b/libs/langchain_v1/langchain/agents/factory.py index eb07787929124..b0e1bcfe75d4f 100644 --- a/libs/langchain_v1/langchain/agents/factory.py +++ b/libs/langchain_v1/langchain/agents/factory.py @@ -2,7 +2,10 @@ from __future__ import annotations +import heapq import itertools +from collections import deque +from dataclasses import dataclass, field from typing import ( TYPE_CHECKING, Annotated, @@ -29,9 +32,13 @@ AgentMiddleware, AgentState, JumpTo, + MergeStrategy, + MiddlewareOrderCycleError, + MiddlewareSpec, ModelRequest, ModelResponse, OmitFromSchema, + OrderingConstraints, ResponseT, StateT_co, _InputAgentState, @@ -541,6 +548,367 @@ async def call_inner(req: ToolCallRequest) -> ToolMessage | Command: return result +@dataclass(slots=True) +class _ResolvedMiddlewareNode: + """Internal representation of middleware during dependency resolution.""" + + instance: AgentMiddleware[Any, Any] + id: str + order_index: int + insertion_seq: int + priority_value: float + priority_raw: Any + tags: tuple[str, ...] + before: set[str] = field(default_factory=set) + after: set[str] = field(default_factory=set) + requires_before: set[str] = field(default_factory=set) + processed: bool = False + + +@dataclass(slots=True) +class _MiddlewareIdGenerator: + counts: dict[str, int] = field(default_factory=dict) + + def assign(self, instance: AgentMiddleware[Any, Any]) -> None: + if instance.id: + return + + qualified = f"{instance.__class__.__module__}.{instance.__class__.__qualname__}" + count = self.counts.get(qualified, 0) + self.counts[qualified] = count + 1 + + identifier = instance.name if count == 0 else f"{qualified}#{count + 1}" + + instance.id = identifier + + +def _normalize_tags(tags: Sequence[str] | None) -> tuple[str, ...]: + if not tags: + return () + return tuple(dict.fromkeys(tags)) + + +def _merge_tags(existing: tuple[str, ...], new: tuple[str, ...]) -> tuple[str, ...]: + if not new: + return existing + return tuple(dict.fromkeys((*existing, *new))) + + +def _normalize_priority(value: Any) -> tuple[float, Any]: + if value is None: + return 0.0, None + try: + return float(value), value + except (TypeError, ValueError) as exc: + msg = f"Middleware priority {value!r} cannot be converted to a float" + raise TypeError(msg) from exc + + +def _get_middleware_id(instance: AgentMiddleware[Any, Any]) -> str: + identifier = instance.id if instance.id is not None else instance.name + instance.id = identifier + return identifier + + +def _ensure_instance_metadata( + instance: AgentMiddleware[Any, Any], +) -> tuple[str, tuple[str, ...], float, Any]: + identifier = _get_middleware_id(instance) + tags_tuple = _normalize_tags(getattr(instance, "tags", ())) + instance.tags = tags_tuple + priority_value, priority_raw = _normalize_priority(getattr(instance, "priority", None)) + instance.priority = priority_raw + return identifier, tags_tuple, priority_value, priority_raw + + +def _prepare_spec_instance( + spec: MiddlewareSpec[Any, Any], +) -> tuple[AgentMiddleware[Any, Any], OrderingConstraints | None, MergeStrategy]: + instance: AgentMiddleware[Any, Any] | None + if spec.middleware is not None: + instance = spec.middleware + elif spec.factory is not None: + instance = spec.factory() + else: # pragma: no cover - guarded by MiddlewareSpec + instance = None + + if instance is None: + msg = "MiddlewareSpec factory returned None" + raise ValueError(msg) + + if spec.id is not None: + instance.id = spec.id + if spec.priority is not None: + instance.priority = spec.priority + if spec.tags is not None: + instance.tags = tuple(spec.tags) + + return instance, spec.ordering, spec.merge_strategy + + +def _register_middleware_node( + *, + instance: AgentMiddleware[Any, Any], + origin_index: int, + merge_strategy: MergeStrategy, + ordering: OrderingConstraints | None, + nodes: dict[str, _ResolvedMiddlewareNode], + insertion_counter: list[int], + id_generator: _MiddlewareIdGenerator, +) -> tuple[_ResolvedMiddlewareNode, bool, bool]: + id_generator.assign(instance) + identifier, tags_tuple, priority_value, priority_raw = _ensure_instance_metadata(instance) + ordering_before = ordering.before if ordering else () + ordering_after = ordering.after if ordering else () + + node = nodes.get(identifier) + if node is None: + node = _ResolvedMiddlewareNode( + instance=instance, + id=identifier, + order_index=origin_index, + insertion_seq=insertion_counter[0], + priority_value=priority_value, + priority_raw=priority_raw, + tags=tags_tuple, + ) + insertion_counter[0] += 1 + node.before.update(ordering_before) + node.after.update(ordering_after) + nodes[identifier] = node + return node, True, False + + node.order_index = min(node.order_index, origin_index) + existing_tags = node.tags + replaced = False + + if merge_strategy == "error" and node.instance is not instance: + msg = ( + f"Duplicate middleware id '{identifier}' encountered without a merge strategy. " + "Set merge_strategy to 'first_wins' or 'last_wins', or provide a unique id." + ) + raise ValueError(msg) + + if merge_strategy == "last_wins" and node.instance is not instance: + node.instance = instance + node.priority_value = priority_value + node.priority_raw = priority_raw + node.processed = False + replaced = True + elif merge_strategy in {"first_wins", "error"}: + pass + else: # pragma: no cover - guarded by Literal typing but keeps runtime safe + msg = f"Unknown merge strategy '{merge_strategy}' for middleware '{identifier}'." + raise ValueError(msg) + + node.tags = _merge_tags(existing_tags, tags_tuple) + node.before.update(ordering_before) + node.after.update(ordering_after) + + return node, False, replaced + + +def _resolve_order_targets( + token: str, + source_id: str, + tag_map: dict[str, set[str]], + nodes: dict[str, _ResolvedMiddlewareNode], +) -> set[str]: + if token.startswith("tag:"): + tag = token[4:] + targets = set(tag_map.get(tag, set())) + if not targets: + msg = f"Ordering constraint on middleware '{source_id}' references unknown tag '{tag}'." + raise ValueError(msg) + targets.discard(source_id) + if not targets: + msg = ( + "Ordering constraint on middleware " + f"'{source_id}' cannot target itself via tag '{tag}'." + ) + raise ValueError(msg) + return targets + + if token not in nodes: + msg = ( + "Ordering constraint on middleware " + f"'{source_id}' references unknown middleware id '{token}'." + ) + raise ValueError(msg) + + if token == source_id: + msg = f"Ordering constraint on middleware '{source_id}' cannot reference itself." + raise ValueError(msg) + + return {token} + + +def _collect_middleware_edges(nodes: dict[str, _ResolvedMiddlewareNode]) -> set[tuple[str, str]]: + edges: set[tuple[str, str]] = set() + tag_map: dict[str, set[str]] = {} + + for identifier, node in nodes.items(): + for dependency in node.requires_before: + edges.add((dependency, identifier)) + + for tag in node.tags: + tag_map.setdefault(tag, set()).add(identifier) + + for identifier, node in nodes.items(): + for target in node.before: + for resolved in _resolve_order_targets(target, identifier, tag_map, nodes): + edges.add((identifier, resolved)) + + for target in node.after: + for resolved in _resolve_order_targets(target, identifier, tag_map, nodes): + edges.add((resolved, identifier)) + + return edges + + +def _find_cycle(adjacency: dict[str, set[str]]) -> list[str] | None: + visited: set[str] = set() + stack: set[str] = set() + path: list[str] = [] + + def dfs(node_id: str) -> list[str] | None: + visited.add(node_id) + stack.add(node_id) + path.append(node_id) + + for neighbour in adjacency[node_id]: + if neighbour not in visited: + result = dfs(neighbour) + if result: + return result + elif neighbour in stack: + cycle_start = path.index(neighbour) + return [*path[cycle_start:], neighbour] + + stack.remove(node_id) + path.pop() + return None + + for identifier in adjacency: + if identifier not in visited: + result = dfs(identifier) + if result: + return result + return None + + +def _topologically_sort_middleware( + nodes: dict[str, _ResolvedMiddlewareNode], + edges: set[tuple[str, str]], +) -> list[str]: + adjacency: dict[str, set[str]] = {identifier: set() for identifier in nodes} + indegree = cast("dict[str, int]", dict.fromkeys(nodes, 0)) + + for source, target in edges: + if source not in adjacency: + msg = f"Ordering constraint references unknown middleware id '{source}'." + raise ValueError(msg) + if target not in adjacency: + msg = f"Ordering constraint references unknown middleware id '{target}'." + raise ValueError(msg) + if source == target: + msg = f"Detected cycle in middleware ordering: '{source}' depends on itself." + raise MiddlewareOrderCycleError(msg) + if target not in adjacency[source]: + adjacency[source].add(target) + indegree[target] += 1 + + heap: list[tuple[int, float, int, str]] = [] + for identifier, node in nodes.items(): + if indegree[identifier] == 0: + heapq.heappush( + heap, + (node.order_index, -node.priority_value, node.insertion_seq, identifier), + ) + + ordered: list[str] = [] + while heap: + _, _, _, identifier = heapq.heappop(heap) + ordered.append(identifier) + for target in adjacency[identifier]: + indegree[target] -= 1 + if indegree[target] == 0: + node = nodes[target] + heapq.heappush( + heap, + (node.order_index, -node.priority_value, node.insertion_seq, target), + ) + + if len(ordered) != len(nodes): + cycle = _find_cycle(adjacency) + cycle_path = " -> ".join(cycle) if cycle else "unknown cycle" + msg = f"Detected cycle in middleware ordering: {cycle_path}" + raise MiddlewareOrderCycleError(msg) + + return ordered + + +def _resolve_middleware( + middleware: Sequence[AgentMiddleware[StateT_co, ContextT]], +) -> list[AgentMiddleware[StateT_co, ContextT]]: + nodes: dict[str, _ResolvedMiddlewareNode] = {} + queue: deque[str] = deque() + insertion_counter = [0] + id_generator = _MiddlewareIdGenerator() + + for index, instance in enumerate(middleware): + node, created, replaced = _register_middleware_node( + instance=instance, + origin_index=index, + merge_strategy="error", + ordering=None, + nodes=nodes, + insertion_counter=insertion_counter, + id_generator=id_generator, + ) + if created or replaced or not node.processed: + queue.append(node.id) + + while queue: + node_id = queue.popleft() + node = nodes[node_id] + if node.processed: + continue + + node.processed = True + node.requires_before.clear() + + dependencies = node.instance.requires() or () + for spec in dependencies: + instance, ordering, merge_strategy = _prepare_spec_instance(spec) + dep_node, created, replaced = _register_middleware_node( + instance=instance, + origin_index=node.order_index, + merge_strategy=merge_strategy, + ordering=ordering, + nodes=nodes, + insertion_counter=insertion_counter, + id_generator=id_generator, + ) + node.requires_before.add(dep_node.id) + + if created or replaced or not dep_node.processed: + queue.append(dep_node.id) + + edges = _collect_middleware_edges(nodes) + ordered_ids = _topologically_sort_middleware(nodes, edges) + + resolved: list[AgentMiddleware[StateT_co, ContextT]] = [] + for identifier in ordered_ids: + node = nodes[identifier] + node.instance.id = node.id + node.instance.tags = node.tags + node.instance.priority = node.priority_raw + resolved.append(node.instance) + + return resolved + + def create_agent( # noqa: PLR0915 model: str | BaseChatModel, tools: Sequence[BaseTool | Callable | dict[str, Any]] | None = None, @@ -722,6 +1090,8 @@ def check_weather(location: str) -> str: for response_schema in tool_strategy_for_setup.schema_specs: structured_tool_info = OutputToolBinding.from_schema_spec(response_schema) structured_output_tools[structured_tool_info.tool.name] = structured_tool_info + middleware = _resolve_middleware(list(middleware)) + middleware_tools = [t for m in middleware for t in getattr(m, "tools", [])] # Collect middleware with wrap_tool_call or awrap_tool_call hooks diff --git a/libs/langchain_v1/langchain/agents/middleware/__init__.py b/libs/langchain_v1/langchain/agents/middleware/__init__.py index 44389008aa41b..04a4271870a52 100644 --- a/libs/langchain_v1/langchain/agents/middleware/__init__.py +++ b/libs/langchain_v1/langchain/agents/middleware/__init__.py @@ -28,8 +28,11 @@ from .types import ( AgentMiddleware, AgentState, + MiddlewareOrderCycleError, + MiddlewareSpec, ModelRequest, ModelResponse, + OrderingConstraints, after_agent, after_model, before_agent, @@ -53,10 +56,13 @@ "InterruptOnConfig", "LLMToolEmulator", "LLMToolSelectorMiddleware", + "MiddlewareOrderCycleError", + "MiddlewareSpec", "ModelCallLimitMiddleware", "ModelFallbackMiddleware", "ModelRequest", "ModelResponse", + "OrderingConstraints", "PIIDetectionError", "PIIMiddleware", "RedactionRule", diff --git a/libs/langchain_v1/langchain/agents/middleware/types.py b/libs/langchain_v1/langchain/agents/middleware/types.py index 2ec6ac5dbaf52..8bc65382cd52f 100644 --- a/libs/langchain_v1/langchain/agents/middleware/types.py +++ b/libs/langchain_v1/langchain/agents/middleware/types.py @@ -2,7 +2,7 @@ from __future__ import annotations -from collections.abc import Awaitable, Callable +from collections.abc import Awaitable, Callable, Sequence from dataclasses import dataclass, field, replace from inspect import iscoroutinefunction from typing import ( @@ -46,9 +46,12 @@ "AgentMiddleware", "AgentState", "ContextT", + "MiddlewareOrderCycleError", + "MiddlewareSpec", "ModelRequest", "ModelResponse", "OmitFromSchema", + "OrderingConstraints", "ResponseT", "StateT_co", "ToolCallRequest", @@ -59,6 +62,7 @@ "before_model", "dynamic_prompt", "hook_config", + "wrap_model_call", "wrap_tool_call", ] @@ -219,6 +223,15 @@ class AgentMiddleware(Generic[StateT, ContextT]): tools: list[BaseTool] """Additional tools registered by the middleware.""" + id: str | None = None + """Optional unique identifier used for deduplication and ordering.""" + + priority: float | int | None = None + """Optional priority used as a tie-breaker when ordering middleware.""" + + tags: tuple[str, ...] = () + """Tags that can be referenced by ordering constraints.""" + @property def name(self) -> str: """The name of the middleware instance. @@ -227,6 +240,10 @@ def name(self) -> str: """ return self.__class__.__name__ + def requires(self) -> Sequence[MiddlewareSpec[StateT, ContextT]]: + """Return additional middleware specifications required by this instance.""" + return () + def before_agent(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None: """Logic to run before the agent execution starts. @@ -560,6 +577,64 @@ async def awrap_tool_call(self, request, handler): raise NotImplementedError(msg) +MergeStrategy = Literal["first_wins", "last_wins", "error"] + + +@dataclass(slots=True) +class OrderingConstraints: + """Ordering rules for middleware resolution.""" + + before: tuple[str, ...] = field(default_factory=tuple) + """Identifiers or tag references that must execute after this middleware.""" + + after: tuple[str, ...] = field(default_factory=tuple) + """Identifiers or tag references that must execute before this middleware.""" + + def __post_init__(self) -> None: + """Normalize incoming ordering tokens to tuples.""" + self.before = tuple(self.before) + self.after = tuple(self.after) + + +@dataclass(slots=True) +class MiddlewareSpec(Generic[StateT, ContextT]): + """Specification describing middleware dependencies.""" + + factory: Callable[[], AgentMiddleware[StateT, ContextT]] | None = None + """Factory used to instantiate the middleware dependency.""" + + middleware: AgentMiddleware[StateT, ContextT] | None = None + """Pre-instantiated middleware instance.""" + + id: str | None = None + """Optional identifier override for the dependency.""" + + priority: float | int | None = None + """Optional priority override for tie-breaking.""" + + tags: Sequence[str] | None = None + """Optional tag override for the dependency.""" + + ordering: OrderingConstraints | None = None + """Additional ordering constraints for the dependency.""" + + merge_strategy: MergeStrategy = "first_wins" + """Strategy for handling duplicate middleware identifiers.""" + + def __post_init__(self) -> None: + """Validate spec inputs and normalize metadata.""" + if self.factory is None and self.middleware is None: + msg = "MiddlewareSpec requires either 'factory' or 'middleware'." + raise ValueError(msg) + + if self.tags is not None: + self.tags = tuple(self.tags) + + +class MiddlewareOrderCycleError(ValueError): + """Raised when middleware ordering introduces a cycle.""" + + class _CallableWithStateAndRuntime(Protocol[StateT_contra, ContextT]): """Callable with `AgentState` and `Runtime` as arguments.""" diff --git a/libs/langchain_v1/tests/unit_tests/agents/test_middleware_resolution.py b/libs/langchain_v1/tests/unit_tests/agents/test_middleware_resolution.py new file mode 100644 index 0000000000000..854e46057235f --- /dev/null +++ b/libs/langchain_v1/tests/unit_tests/agents/test_middleware_resolution.py @@ -0,0 +1,350 @@ +from __future__ import annotations + +from functools import partial +from typing import Sequence + +import pytest + +from langchain.agents.factory import _resolve_middleware, create_agent +from langchain.agents.middleware.types import ( + AgentMiddleware, + AgentState, + MiddlewareOrderCycleError, + MiddlewareSpec, + OrderingConstraints, +) +from langchain_core.language_models import FakeListChatModel +from langchain_core.messages import HumanMessage + + +class BareMiddleware(AgentMiddleware): + """Minimal middleware used for unit tests.""" + + def __init__( + self, + mid: str, + *, + priority: float | None = None, + tags: Sequence[str] | None = None, + ) -> None: + self.id = mid + self.priority = priority + self.tags = tuple(tags or ()) + self.tools = [] + + @property + def name(self) -> str: # pragma: no cover - exercised indirectly + return self.id or super().name + + +class SharedMiddleware(BareMiddleware): + def __init__(self, marker: str) -> None: + super().__init__("shared") + self.marker = marker + + +class ParentWithShared(BareMiddleware): + def __init__(self, mid: str, marker: str, *, strategy: str = "first_wins") -> None: + super().__init__(mid) + self._marker = marker + self._strategy = strategy + + def requires(self) -> Sequence[MiddlewareSpec[AgentState, object]]: + return [ + MiddlewareSpec( + factory=lambda: SharedMiddleware(self._marker), + merge_strategy=self._strategy, + ) + ] + + +class DependencyMiddleware(BareMiddleware): + def __init__(self) -> None: + super().__init__("dependency", tags=("helper",)) + + +class ControllerMiddleware(BareMiddleware): + def __init__(self) -> None: + super().__init__("controller") + + def requires(self) -> Sequence[MiddlewareSpec[AgentState, object]]: + return [MiddlewareSpec(factory=DependencyMiddleware)] + + +class OrderedControllerMiddleware(ControllerMiddleware): + def requires(self) -> Sequence[MiddlewareSpec[AgentState, object]]: + return [ + MiddlewareSpec( + factory=DependencyMiddleware, + ordering=OrderingConstraints( + after=("start",), + before=("tag:exit",), + ), + ) + ] + + +class PriorityDependency(BareMiddleware): + def __init__(self, mid: str, priority: float) -> None: + super().__init__(mid, priority=priority) + + +class PriorityParent(BareMiddleware): + def __init__(self) -> None: + super().__init__("priority-parent") + + def requires(self) -> Sequence[MiddlewareSpec[AgentState, object]]: + return [ + MiddlewareSpec(factory=lambda: PriorityDependency("low", 1)), + MiddlewareSpec(factory=lambda: PriorityDependency("high", 10)), + ] + + +class CycleA(BareMiddleware): + def __init__(self) -> None: + super().__init__("cycle-a") + + def requires(self) -> Sequence[MiddlewareSpec[AgentState, object]]: + return [MiddlewareSpec(factory=CycleB)] + + +class CycleB(BareMiddleware): + def __init__(self) -> None: + super().__init__("cycle-b") + + def requires(self) -> Sequence[MiddlewareSpec[AgentState, object]]: + return [MiddlewareSpec(factory=CycleA)] + + +class RecordingMiddleware(AgentMiddleware): + """Middleware that records hook execution order.""" + + def __init__( + self, + mid: str, + log: list[str], + *, + priority: float | None = None, + tags: Sequence[str] | None = None, + ) -> None: + self.id = mid + self._name = mid + self.priority = priority + self.tags = tuple(tags or ()) + self._log = log + self.tools = [] + + @property + def name(self) -> str: + return self._name + + def before_model(self, state, runtime) -> None: # noqa: ARG002 + self._log.append(f"{self.id}:before_model") + + def after_model(self, state, runtime) -> None: # noqa: ARG002 + self._log.append(f"{self.id}:after_model") + + def wrap_tool_call(self, request, handler): # noqa: ARG002 + return handler(request) + + async def awrap_tool_call(self, request, handler): # noqa: ARG002 + return await handler(request) + + +class DeepMiddleware(RecordingMiddleware): + def __init__(self, log: list[str]) -> None: + super().__init__("deep", log) + + +class InnerMiddleware(RecordingMiddleware): + def __init__(self, log: list[str]) -> None: + super().__init__("inner", log) + + def requires(self) -> Sequence[MiddlewareSpec[AgentState, object]]: + return [MiddlewareSpec(factory=partial(DeepMiddleware, self._log))] + + +class OuterMiddleware(RecordingMiddleware): + def __init__(self, log: list[str]) -> None: + super().__init__("outer", log) + + def requires(self) -> Sequence[MiddlewareSpec[AgentState, object]]: + return [MiddlewareSpec(factory=partial(InnerMiddleware, self._log))] + + +class SharedLoggingMiddleware(RecordingMiddleware): + def __init__(self, log: list[str], marker: str) -> None: + super().__init__("shared", log, tags=("shared",)) + self.marker = marker + + def before_model(self, state, runtime) -> None: # noqa: ARG002 + self._log.append(f"shared:{self.marker}:before_model") + + def after_model(self, state, runtime) -> None: # noqa: ARG002 + self._log.append(f"shared:{self.marker}:after_model") + + +class ParentOne(RecordingMiddleware): + def __init__(self, log: list[str]) -> None: + super().__init__("parent-one", log) + + def requires(self) -> Sequence[MiddlewareSpec[AgentState, object]]: + return [ + MiddlewareSpec( + factory=partial(SharedLoggingMiddleware, self._log, "first"), + merge_strategy="first_wins", + ) + ] + + +class ParentTwo(RecordingMiddleware): + def __init__(self, log: list[str]) -> None: + super().__init__("parent-two", log) + + def requires(self) -> Sequence[MiddlewareSpec[AgentState, object]]: + return [ + MiddlewareSpec( + factory=partial(SharedLoggingMiddleware, self._log, "second"), + merge_strategy="first_wins", + ) + ] + + +class AnonymousMiddleware(AgentMiddleware): + def __init__(self, marker: str) -> None: + self.marker = marker + self.tools: list = [] + + +def test_requires_flattens_dependencies() -> None: + resolved = _resolve_middleware([ControllerMiddleware()]) + ids = [mw.id for mw in resolved] + assert ids == ["dependency", "controller"] + + +def test_first_wins_deduplication_keeps_initial_instance() -> None: + resolved = _resolve_middleware( + [ParentWithShared("one", "first"), ParentWithShared("two", "second")] + ) + shared = [mw for mw in resolved if mw.id == "shared"] + assert len(shared) == 1 + assert shared[0].marker == "first" + assert [mw.id for mw in resolved] == ["shared", "one", "two"] + + +def test_last_wins_replaces_existing_instance() -> None: + resolved = _resolve_middleware( + [ + ParentWithShared("one", "first", strategy="last_wins"), + ParentWithShared("two", "second", strategy="last_wins"), + ] + ) + shared = [mw for mw in resolved if mw.id == "shared"] + assert len(shared) == 1 + assert shared[0].marker == "second" + + +def test_duplicate_without_merge_strategy_errors() -> None: + with pytest.raises(ValueError, match="Duplicate middleware id 'shared'"): + _resolve_middleware( + [ + ParentWithShared("one", "first", strategy="error"), + ParentWithShared("two", "second", strategy="error"), + ] + ) + + +def test_anonymous_instances_preserve_order_without_conflict() -> None: + first = AnonymousMiddleware("first") + second = AnonymousMiddleware("second") + + resolved = _resolve_middleware([first, second]) + + assert resolved[:2] == [first, second] + assert [mw.marker for mw in resolved] == ["first", "second"] + assert resolved[0].id == "AnonymousMiddleware" + assert resolved[1].id.startswith( + f"{AnonymousMiddleware.__module__}.{AnonymousMiddleware.__qualname__}#" + ) + + +def test_ordering_constraints_with_ids_and_tags() -> None: + start = BareMiddleware("start", tags=("entry",)) + terminus = BareMiddleware("terminus", tags=("exit",)) + controller = OrderedControllerMiddleware() + + resolved = _resolve_middleware([start, terminus, controller]) + assert [mw.id for mw in resolved] == ["start", "dependency", "terminus", "controller"] + + +def test_ordering_constraint_unknown_tag_raises() -> None: + class TagController(BareMiddleware): + def __init__(self) -> None: + super().__init__("tag-controller") + + def requires(self) -> Sequence[MiddlewareSpec[AgentState, object]]: + return [ + MiddlewareSpec( + factory=DependencyMiddleware, + ordering=OrderingConstraints(after=("tag:missing",)), + ) + ] + + with pytest.raises(ValueError, match="unknown tag 'missing'"): + _resolve_middleware([TagController()]) + + +def test_priority_breaks_ties_when_order_equal() -> None: + resolved = _resolve_middleware([PriorityParent()]) + ids = [mw.id for mw in resolved] + assert ids == ["high", "low", "priority-parent"] + + +def test_cycle_detection_raises_error() -> None: + with pytest.raises(MiddlewareOrderCycleError, match="cycle-a"): + _resolve_middleware([CycleA()]) + + +def test_nested_dependencies_execute_in_order() -> None: + log: list[str] = [] + agent = create_agent( + model=FakeListChatModel(responses=["done"]), + middleware=[OuterMiddleware(log)], + ) + + agent.invoke({"messages": [HumanMessage(content="hi")]}) + + assert log[:3] == [ + "deep:before_model", + "inner:before_model", + "outer:before_model", + ] + assert log[3:] == [ + "outer:after_model", + "inner:after_model", + "deep:after_model", + ] + + +def test_shared_dependency_runs_once_with_first_wins() -> None: + log: list[str] = [] + agent = create_agent( + model=FakeListChatModel(responses=["done"]), + middleware=[ParentOne(log), ParentTwo(log)], + ) + + agent.invoke({"messages": [HumanMessage(content="go")]}) + + before_entries = [entry for entry in log if entry.endswith("before_model")] + after_entries = [entry for entry in log if entry.endswith("after_model")] + + assert before_entries == [ + "shared:first:before_model", + "parent-one:before_model", + "parent-two:before_model", + ] + assert after_entries == [ + "parent-two:after_model", + "parent-one:after_model", + "shared:first:after_model", + ] diff --git a/libs/langchain_v1/uv.lock b/libs/langchain_v1/uv.lock index c9157176b917d..716e73f5a7316 100644 --- a/libs/langchain_v1/uv.lock +++ b/libs/langchain_v1/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.10.0, <4.0.0" resolution-markers = [ "python_full_version >= '3.14' and platform_python_implementation == 'PyPy'", @@ -2395,7 +2395,7 @@ wheels = [ [[package]] name = "langchain-openai" -version = "1.0.2" +version = "1.0.3" source = { editable = "../partners/openai" } dependencies = [ { name = "langchain-core" },