Commit 62ccf7e

feat(langchain_v1): simplify to use ONE agent (#33302)
This reduces confusion with types like `AgentState`, different arg names, etc. Second attempt, following #33249.

* Ability to pass through `cache` and `name` in `create_agent` as compilation args for the agent
* For now, removing `test_react_agent.py`, but we should add these tests back as implemented with the new agent
* Add a conditional edge when structured output tools are present to allow for retries
* Rename `tracking` to `model_call_limit` to be consistent with tool call limits

Needed in the future (I'm happy to own):

* Significant test refactor
* Significant test overhaul where we emphasize and enforce coverage
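For orientation, a minimal sketch of the call surface after this change, assuming the `create_agent` signature shown in the factory.py diff below (the model string, prompt, and name are illustrative values):

```python
# Hedged sketch: create_agent now returns a CompiledStateGraph directly,
# so there is no separate .compile() step for callers.
from langchain.agents import create_agent

agent = create_agent(
    model="openai:gpt-4o",  # string ids are resolved via init_chat_model
    tools=[],  # None or [] yields a single LLM node without tool calling
    system_prompt="You are a helpful assistant",
    name="assistant",  # used when nesting this agent as a subgraph node
)
agent.invoke({"messages": [{"role": "user", "content": "hello"}]})
```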
1 parent 0ff2bc8 commit 62ccf7e

16 files changed: 1254 additions, 3134 deletions

.claude/settings.local.json
Lines changed: 4 additions & 1 deletion

@@ -7,7 +7,10 @@
       "WebFetch(domain:ai.pydantic.dev)",
       "WebFetch(domain:openai.github.io)",
       "Bash(uv run:*)",
-      "Bash(python3:*)"
+      "Bash(python3:*)",
+      "WebFetch(domain:github.com)",
+      "Bash(gh pr view:*)",
+      "Bash(gh pr diff:*)"
     ],
     "deny": [],
     "ask": []

libs/langchain_v1/langchain/agents/__init__.py
Lines changed: 2 additions & 1 deletion

@@ -1,6 +1,7 @@
 """langgraph.prebuilt exposes a higher-level API for creating and executing agents and tools."""
 
-from langchain.agents.react_agent import AgentState, create_agent
+from langchain.agents.factory import create_agent
+from langchain.agents.middleware.types import AgentState
 
 __all__ = [
     "AgentState",

libs/langchain_v1/langchain/agents/middleware_agent.py renamed to libs/langchain_v1/langchain/agents/factory.py
Lines changed: 185 additions & 27 deletions
@@ -1,19 +1,27 @@
-"""Middleware agent implementation."""
+"""Agent factory for creating agents with middleware support."""
+
+from __future__ import annotations
 
 import itertools
-from collections.abc import Callable, Sequence
-from typing import Annotated, Any, cast, get_args, get_origin, get_type_hints
+from typing import (
+    TYPE_CHECKING,
+    Annotated,
+    Any,
+    cast,
+    get_args,
+    get_origin,
+    get_type_hints,
+)
 
 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import AIMessage, AnyMessage, SystemMessage, ToolMessage
-from langchain_core.runnables import Runnable
 from langchain_core.tools import BaseTool
 from langgraph._internal._runnable import RunnableCallable
 from langgraph.constants import END, START
 from langgraph.graph.state import StateGraph
-from langgraph.runtime import Runtime
+from langgraph.runtime import Runtime  # noqa: TC002
 from langgraph.types import Send
-from langgraph.typing import ContextT
+from langgraph.typing import ContextT  # noqa: TC002
 from typing_extensions import NotRequired, Required, TypedDict, TypeVar
 
 from langchain.agents.middleware.types import (
@@ -37,6 +45,15 @@
 from langchain.chat_models import init_chat_model
 from langchain.tools import ToolNode
 
+if TYPE_CHECKING:
+    from collections.abc import Callable, Sequence
+
+    from langchain_core.runnables import Runnable
+    from langgraph.cache.base import BaseCache
+    from langgraph.graph.state import CompiledStateGraph
+    from langgraph.store.base import BaseStore
+    from langgraph.types import Checkpointer
+
 STRUCTURED_OUTPUT_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes."
 
 ResponseT = TypeVar("ResponseT")
@@ -176,17 +193,92 @@ def _handle_structured_output_error(
 
 
 def create_agent(  # noqa: PLR0915
-    *,
     model: str | BaseChatModel,
     tools: Sequence[BaseTool | Callable | dict[str, Any]] | ToolNode | None = None,
+    *,
     system_prompt: str | None = None,
     middleware: Sequence[AgentMiddleware[AgentState[ResponseT], ContextT]] = (),
     response_format: ResponseFormat[ResponseT] | type[ResponseT] | None = None,
     context_schema: type[ContextT] | None = None,
-) -> StateGraph[
+    checkpointer: Checkpointer | None = None,
+    store: BaseStore | None = None,
+    interrupt_before: list[str] | None = None,
+    interrupt_after: list[str] | None = None,
+    debug: bool = False,
+    name: str | None = None,
+    cache: BaseCache | None = None,
+) -> CompiledStateGraph[
     AgentState[ResponseT], ContextT, PublicAgentState[ResponseT], PublicAgentState[ResponseT]
 ]:
-    """Create a middleware agent graph."""
+    """Creates an agent graph that calls tools in a loop until a stopping condition is met.
+
+    For more details on using ``create_agent``,
+    see the [Agents](https://docs.langchain.com/oss/python/langchain/agents) documentation.
+
+    Args:
+        model: The language model for the agent. Can be a string identifier
+            (e.g., ``"openai:gpt-4"``) or a chat model instance (e.g., ``ChatOpenAI()``).
+        tools: A list of tools or a ToolNode instance. If ``None`` or an empty list,
+            the agent will consist of a single LLM node without tool calling.
+        system_prompt: An optional system prompt for the LLM. If provided as a string,
+            it will be converted to a SystemMessage and added to the beginning
+            of the message list.
+        middleware: A sequence of middleware instances to apply to the agent.
+            Middleware can intercept and modify agent behavior at various stages.
+        response_format: An optional configuration for structured responses.
+            Can be a ToolStrategy, ProviderStrategy, or a Pydantic model class.
+            If provided, the agent will handle structured output during the
+            conversation flow. Raw schemas will be wrapped in an appropriate strategy
+            based on model capabilities.
+        context_schema: An optional schema for runtime context.
+        checkpointer: An optional checkpoint saver object. This is used for persisting
+            the state of the graph (e.g., as chat memory) for a single thread
+            (e.g., a single conversation).
+        store: An optional store object. This is used for persisting data
+            across multiple threads (e.g., multiple conversations / users).
+        interrupt_before: An optional list of node names to interrupt before.
+            This is useful if you want to add a user confirmation or other interrupt
+            before taking an action.
+        interrupt_after: An optional list of node names to interrupt after.
+            This is useful if you want to return directly or run additional processing
+            on an output.
+        debug: A flag indicating whether to enable debug mode.
+        name: An optional name for the CompiledStateGraph.
+            This name will be automatically used when adding the agent graph to
+            another graph as a subgraph node - particularly useful for building
+            multi-agent systems.
+        cache: An optional BaseCache instance to enable caching of graph execution.
+
+    Returns:
+        A compiled StateGraph that can be used for chat interactions.
+
+    The agent node calls the language model with the messages list (after applying
+    the system prompt). If the resulting AIMessage contains ``tool_calls``, the graph will
+    then call the tools. The tools node executes the tools and adds the responses
+    to the messages list as ``ToolMessage`` objects. The agent node then calls the
+    language model again. The process repeats until no more ``tool_calls`` are
+    present in the response. The agent then returns the full list of messages.
+
+    Example:
+        ```python
+        from langchain.agents import create_agent
+
+
+        def check_weather(location: str) -> str:
+            '''Return the weather forecast for the specified location.'''
+            return f"It's always sunny in {location}"
+
+
+        graph = create_agent(
+            model="anthropic:claude-3-7-sonnet-latest",
+            tools=[check_weather],
+            system_prompt="You are a helpful assistant",
+        )
+        inputs = {"messages": [{"role": "user", "content": "what is the weather in sf"}]}
+        for chunk in graph.stream(inputs, stream_mode="updates"):
+            print(chunk)
+        ```
+    """
     # init chat model
     if isinstance(model, str):
         model = init_chat_model(model)
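The docstring example above covers the tool loop; the new `response_format` path is not exemplified there. A hedged sketch, assuming a plain Pydantic class is accepted as documented (the `Forecast` schema and model id are illustrative):

```python
from pydantic import BaseModel

from langchain.agents import create_agent


class Forecast(BaseModel):
    """Illustrative structured output schema."""

    location: str
    summary: str


agent = create_agent(
    model="openai:gpt-4o",  # illustrative model id
    tools=[],
    response_format=Forecast,  # raw schema wrapped in a strategy per model capabilities
)
result = agent.invoke({"messages": [{"role": "user", "content": "Weather in SF?"}]})
# Per the state key used later in this diff, the parsed object is expected
# under result["structured_response"].
```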
@@ -230,7 +322,7 @@ def create_agent(  # noqa: PLR0915
     # Setup tools
     tool_node: ToolNode | None = None
     if isinstance(tools, list):
-        # Extract built-in provider tools (dict format) and regular tools (BaseTool)
+        # Extract built-in provider tools (dict format) and regular tools (BaseTool/callables)
         built_in_tools = [t for t in tools if isinstance(t, dict)]
         regular_tools = [t for t in tools if not isinstance(t, dict)]
@@ -241,9 +333,13 @@
         tool_node = ToolNode(tools=available_tools) if available_tools else None
 
         # Default tools for ModelRequest initialization
-        # Include built-ins and regular tools (can be changed dynamically by middleware)
+        # Use converted BaseTool instances from ToolNode (not raw callables)
+        # Include built-ins and converted tools (can be changed dynamically by middleware)
         # Structured tools are NOT included - they're added dynamically based on response_format
-        default_tools = regular_tools + middleware_tools + built_in_tools
+        if tool_node:
+            default_tools = list(tool_node.tools_by_name.values()) + built_in_tools
+        else:
+            default_tools = list(built_in_tools)
     elif isinstance(tools, ToolNode):
         tool_node = tools
         if tool_node:
@@ -252,10 +348,15 @@
             tool_node = ToolNode(available_tools)
 
             # default_tools includes all client-side tools (no built-ins or structured tools)
-            default_tools = available_tools
+            default_tools = list(tool_node.tools_by_name.values())
+        else:
+            default_tools = []
+    # No tools provided, only middleware_tools available
+    elif middleware_tools:
+        tool_node = ToolNode(middleware_tools)
+        default_tools = list(tool_node.tools_by_name.values())
     else:
-        # No tools provided, only middleware_tools available
-        default_tools = middleware_tools
+        default_tools = []
 
     # validate middleware
     assert len({m.name for m in middleware}) == len(middleware), (  # noqa: S101
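The rewritten branches above source `default_tools` from `tool_node.tools_by_name`, so plain callables are normalized into `BaseTool` instances before middleware sees them. A minimal sketch of that conversion, using the `ToolNode` import already present in this module (the tool body is illustrative):

```python
from langchain.tools import ToolNode


def check_weather(location: str) -> str:
    """Return the weather forecast for the specified location."""
    return f"It's always sunny in {location}"


# ToolNode wraps the raw callable in a BaseTool, keyed by function name
node = ToolNode(tools=[check_weather])
assert "check_weather" in node.tools_by_name
default_tools = list(node.tools_by_name.values())  # BaseTool instances, not callables
```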
@@ -425,16 +526,19 @@ def _get_bound_model(request: ModelRequest) -> tuple[Runnable, ResponseFormat |
         is the actual strategy used (may differ from initial if auto-detected).
     """
     # Validate ONLY client-side tools that need to exist in tool_node
-    # Build map of available client-side tools (regular_tools + middleware_tools)
-    available_tools_by_name = {t.name: t for t in default_tools if isinstance(t, BaseTool)}
+    # Build map of available client-side tools from the ToolNode
+    # (which has already converted callables)
+    available_tools_by_name = {}
+    if tool_node:
+        available_tools_by_name = tool_node.tools_by_name.copy()
 
     # Check if any requested tools are unknown CLIENT-SIDE tools
     unknown_tool_names = []
     for t in request.tools:
         # Only validate BaseTool instances (skip built-in dict tools)
         if isinstance(t, dict):
             continue
-        if t.name not in available_tools_by_name:
+        if isinstance(t, BaseTool) and t.name not in available_tools_by_name:
             unknown_tool_names.append(t.name)
 
     if unknown_tool_names:
@@ -468,7 +572,8 @@ def _get_bound_model(request: ModelRequest) -> tuple[Runnable, ResponseFormat |
         effective_response_format = request.response_format
 
     # Build final tools list including structured output tools
-    # request.tools already contains both BaseTool and dict (built-in) tools
+    # request.tools now only contains BaseTool instances (converted from callables)
+    # and dicts (built-ins)
     final_tools = list(request.tools)
     if isinstance(effective_response_format, ToolStrategy):
         # Add structured output tools to final tools list
@@ -767,6 +872,12 @@ async def amodel_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[
             ),
             [loop_entry_node, "tools", exit_node],
         )
+    elif len(structured_output_tools) > 0:
+        graph.add_conditional_edges(
+            loop_exit_node,
+            _make_model_to_model_edge(loop_entry_node, exit_node),
+            [loop_entry_node, exit_node],
+        )
     elif loop_exit_node == "model_request":
         # If no tools and no after_model, go directly to exit_node
         graph.add_edge(loop_exit_node, exit_node)
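This new branch gives structured output a retry path: when only structured output tools exist, the model node gets a conditional edge back to itself. A hedged, self-contained sketch of the same LangGraph routing pattern (the node name, state schema, and no-op node are illustrative, not the real agent graph):

```python
from typing import Any

from langgraph.constants import END, START
from langgraph.graph.state import StateGraph
from typing_extensions import TypedDict


class State(TypedDict, total=False):
    structured_response: Any


def route_after_model(state: State) -> str:
    # Mirrors _make_model_to_model_edge: exit once a structured response
    # exists; otherwise loop back to the model node and retry.
    return END if "structured_response" in state else "model_request"


builder = StateGraph(State)
builder.add_node("model_request", lambda state: {})  # stand-in for the model node
builder.add_edge(START, "model_request")
builder.add_conditional_edges("model_request", route_after_model, ["model_request", END])
graph = builder.compile()
```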
@@ -857,7 +968,15 @@ async def amodel_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[
             can_jump_to=_get_can_jump_to(middleware_w_after_agent[0], "after_agent"),
         )
 
-    return graph
+    return graph.compile(
+        checkpointer=checkpointer,
+        store=store,
+        interrupt_before=interrupt_before,
+        interrupt_after=interrupt_after,
+        debug=debug,
+        name=name,
+        cache=cache,
+    )
 
 
 def _resolve_jump(jump_to: JumpTo | None, first_node: str) -> str | None:
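Because compilation now happens inside `create_agent`, the persistence kwargs take effect immediately on the returned graph. A hedged usage sketch, assuming langgraph's in-memory checkpointer (`MemorySaver`) and an arbitrary thread id:

```python
from langgraph.checkpoint.memory import MemorySaver  # assumed in-memory checkpointer

from langchain.agents import create_agent

agent = create_agent(
    model="openai:gpt-4o",  # illustrative model id
    tools=[],
    checkpointer=MemorySaver(),  # forwarded to graph.compile() per this diff
)
# With a checkpointer attached at compile time, per-thread chat memory is
# keyed by the thread_id in the run config.
config = {"configurable": {"thread_id": "conversation-1"}}
agent.invoke({"messages": [{"role": "user", "content": "hi"}]}, config)
```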
@@ -891,8 +1010,10 @@ def _make_model_to_tools_edge(
     structured_output_tools: dict[str, OutputToolBinding],
     tool_node: ToolNode,
     exit_node: str,
-) -> Callable[[dict[str, Any]], str | list[Send] | None]:
-    def model_to_tools(state: dict[str, Any]) -> str | list[Send] | None:
+) -> Callable[[dict[str, Any], Runtime[ContextT]], str | list[Send] | None]:
+    def model_to_tools(
+        state: dict[str, Any], runtime: Runtime[ContextT]
+    ) -> str | list[Send] | None:
         # 1. if there's an explicit jump_to in the state, use it
         if jump_to := state.get("jump_to"):
             return _resolve_jump(jump_to, first_node)
@@ -914,36 +1035,69 @@ def model_to_tools(state: dict[str, Any]) -> str | list[Send] | None:
         # 3. if there are pending tool calls, jump to the tool node
         if pending_tool_calls:
             pending_tool_calls = [
-                tool_node.inject_tool_args(call, state, None) for call in pending_tool_calls
+                tool_node.inject_tool_args(call, state, runtime.store)
+                for call in pending_tool_calls
             ]
             return [Send("tools", [tool_call]) for tool_call in pending_tool_calls]
 
-        # 4. AIMessage has tool calls, but there are no pending tool calls
+        # 4. if there is a structured response, exit the loop
+        if "structured_response" in state:
+            return exit_node
+
+        # 5. AIMessage has tool calls, but there are no pending tool calls
         # which suggests the injection of artificial tool messages. jump to the first node
         return first_node
 
     return model_to_tools
 
 
+def _make_model_to_model_edge(
+    first_node: str,
+    exit_node: str,
+) -> Callable[[dict[str, Any], Runtime[ContextT]], str | list[Send] | None]:
+    def model_to_model(
+        state: dict[str, Any],
+        runtime: Runtime[ContextT],  # noqa: ARG001
+    ) -> str | list[Send] | None:
+        # 1. Priority: Check for explicit jump_to directive from middleware
+        if jump_to := state.get("jump_to"):
+            return _resolve_jump(jump_to, first_node)
+
+        # 2. Exit condition: A structured response was generated
+        if "structured_response" in state:
+            return exit_node
+
+        # 3. Default: Continue the loop, there may have been an issue
+        # with structured output generation, so we need to retry
+        return first_node
+
+    return model_to_model
+
+
 def _make_tools_to_model_edge(
     tool_node: ToolNode,
     next_node: str,
     structured_output_tools: dict[str, OutputToolBinding],
     exit_node: str,
-) -> Callable[[dict[str, Any]], str | None]:
-    def tools_to_model(state: dict[str, Any]) -> str | None:
+) -> Callable[[dict[str, Any], Runtime[ContextT]], str | None]:
+    def tools_to_model(state: dict[str, Any], runtime: Runtime[ContextT]) -> str | None:  # noqa: ARG001
         last_ai_message, tool_messages = _fetch_last_ai_and_tool_messages(state["messages"])
 
+        # 1. Exit condition: All executed tools have return_direct=True
         if all(
             tool_node.tools_by_name[c["name"]].return_direct
             for c in last_ai_message.tool_calls
             if c["name"] in tool_node.tools_by_name
         ):
             return exit_node
 
+        # 2. Exit condition: A structured output tool was executed
         if any(t.name in structured_output_tools for t in tool_messages):
             return exit_node
 
+        # 3. Default: Continue the loop
+        # Tool execution completed successfully, route back to the model
+        # so it can process the tool results and decide the next action.
         return next_node
 
     return tools_to_model
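The routing helper above fans pending tool calls out with `Send`, one per call, so the tools node receives each call as its own payload. A stripped-down sketch of that dispatch pattern (the function and list shapes are illustrative):

```python
from typing import Any

from langgraph.types import Send


def dispatch(pending_tool_calls: list[dict[str, Any]]) -> list[Send]:
    # One Send per tool call lets the "tools" node process calls
    # individually (and potentially in parallel), mirroring the
    # `[Send("tools", [tool_call]) ...]` line in model_to_tools.
    return [Send("tools", [call]) for call in pending_tool_calls]
```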
@@ -960,7 +1114,6 @@ def _add_middleware_edge(
 
     Args:
         graph: The graph to add the edge to.
-        method: The method to call for the middleware node.
         name: The name of the middleware node.
         default_destination: The default destination for the edge.
         model_destination: The destination for the edge to the model.
@@ -984,3 +1137,8 @@ def jump_edge(state: dict[str, Any]) -> str:
 
     else:
         graph.add_edge(name, default_destination)
+
+
+__all__ = [
+    "create_agent",
+]

libs/langchain_v1/langchain/agents/middleware/__init__.py
Lines changed: 1 addition & 1 deletion

@@ -1,11 +1,11 @@
 """Middleware plugins for agents."""
 
-from .call_tracking import ModelCallLimitMiddleware
 from .context_editing import (
     ClearToolUsesEdit,
     ContextEditingMiddleware,
 )
 from .human_in_the_loop import HumanInTheLoopMiddleware
+from .model_call_limit import ModelCallLimitMiddleware
 from .model_fallback import ModelFallbackMiddleware
 from .pii import PIIDetectionError, PIIMiddleware
 from .planning import PlanningMiddleware

File renamed without changes.
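Because the re-export in `__init__.py` is updated in lockstep with the file rename, downstream imports are unaffected; only the internal module path changes:

```python
# Public import path is unchanged by the call_tracking -> model_call_limit rename
from langchain.agents.middleware import ModelCallLimitMiddleware
```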
