- """Middleware agent implementation."""
+ """Agent factory for creating agents with middleware support."""
+
+ from __future__ import annotations

import itertools
- from collections.abc import Callable, Sequence
- from typing import Annotated, Any, cast, get_args, get_origin, get_type_hints
+ from typing import (
+     TYPE_CHECKING,
+     Annotated,
+     Any,
+     cast,
+     get_args,
+     get_origin,
+     get_type_hints,
+ )

from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, AnyMessage, SystemMessage, ToolMessage
- from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from langgraph._internal._runnable import RunnableCallable
from langgraph.constants import END, START
from langgraph.graph.state import StateGraph
- from langgraph.runtime import Runtime
+ from langgraph.runtime import Runtime  # noqa: TC002
from langgraph.types import Send
- from langgraph.typing import ContextT
+ from langgraph.typing import ContextT  # noqa: TC002
from typing_extensions import NotRequired, Required, TypedDict, TypeVar

from langchain.agents.middleware.types import (
    ...  # unchanged import names collapsed in the diff view
)
from langchain.chat_models import init_chat_model
from langchain.tools import ToolNode

+ if TYPE_CHECKING:
+     from collections.abc import Callable, Sequence
+
+     from langchain_core.runnables import Runnable
+     from langgraph.cache.base import BaseCache
+     from langgraph.graph.state import CompiledStateGraph
+     from langgraph.store.base import BaseStore
+     from langgraph.types import Checkpointer
+
STRUCTURED_OUTPUT_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes."

ResponseT = TypeVar("ResponseT")
@@ -176,17 +193,92 @@ def _handle_structured_output_error(


def create_agent(  # noqa: PLR0915
-     *,
    model: str | BaseChatModel,
    tools: Sequence[BaseTool | Callable | dict[str, Any]] | ToolNode | None = None,
+     *,
    system_prompt: str | None = None,
    middleware: Sequence[AgentMiddleware[AgentState[ResponseT], ContextT]] = (),
    response_format: ResponseFormat[ResponseT] | type[ResponseT] | None = None,
    context_schema: type[ContextT] | None = None,
- ) -> StateGraph[
+     checkpointer: Checkpointer | None = None,
+     store: BaseStore | None = None,
+     interrupt_before: list[str] | None = None,
+     interrupt_after: list[str] | None = None,
+     debug: bool = False,
+     name: str | None = None,
+     cache: BaseCache | None = None,
+ ) -> CompiledStateGraph[
    AgentState[ResponseT], ContextT, PublicAgentState[ResponseT], PublicAgentState[ResponseT]
]:
-     """Create a middleware agent graph."""
+     """Creates an agent graph that calls tools in a loop until a stopping condition is met.
+
+     For more details on using ``create_agent``,
+     see the [Agents](https://docs.langchain.com/oss/python/langchain/agents) documentation.
+
+     Args:
+         model: The language model for the agent. Can be a string identifier
+             (e.g., ``"openai:gpt-4"``) or a chat model instance (e.g., ``ChatOpenAI()``).
+         tools: A list of tools or a ToolNode instance. If ``None`` or an empty list,
+             the agent will consist of a single LLM node without tool calling.
+         system_prompt: An optional system prompt for the LLM. If provided as a string,
+             it will be converted to a SystemMessage and added to the beginning
+             of the message list.
+         middleware: A sequence of middleware instances to apply to the agent.
+             Middleware can intercept and modify agent behavior at various stages.
+         response_format: An optional configuration for structured responses.
+             Can be a ToolStrategy, ProviderStrategy, or a Pydantic model class.
+             If provided, the agent will handle structured output during the
+             conversation flow. Raw schemas will be wrapped in an appropriate strategy
+             based on model capabilities.
+         context_schema: An optional schema for runtime context.
+         checkpointer: An optional checkpoint saver object. This is used for persisting
+             the state of the graph (e.g., as chat memory) for a single thread
+             (e.g., a single conversation).
+         store: An optional store object. This is used for persisting data
+             across multiple threads (e.g., multiple conversations / users).
+         interrupt_before: An optional list of node names to interrupt before.
+             This is useful if you want to add a user confirmation or other interrupt
+             before taking an action.
+         interrupt_after: An optional list of node names to interrupt after.
+             This is useful if you want to return directly or run additional processing
+             on an output.
+         debug: A flag indicating whether to enable debug mode.
+         name: An optional name for the CompiledStateGraph.
+             This name will be automatically used when adding the agent graph to
+             another graph as a subgraph node - particularly useful for building
+             multi-agent systems.
+         cache: An optional BaseCache instance to enable caching of graph execution.
+
+     Returns:
+         A compiled StateGraph that can be used for chat interactions.
+
+     The agent node calls the language model with the messages list (after applying
+     the system prompt). If the resulting AIMessage contains ``tool_calls``, the graph will
+     then call the tools. The tools node executes the tools and adds the responses
+     to the messages list as ``ToolMessage`` objects. The agent node then calls the
+     language model again. The process repeats until no more ``tool_calls`` are
+     present in the response. The agent then returns the full list of messages.
+
+     Example:
+         ```python
+         from langchain.agents import create_agent
+
+
+         def check_weather(location: str) -> str:
+             '''Return the weather forecast for the specified location.'''
+             return f"It's always sunny in {location}"
+
+
+         graph = create_agent(
+             model="anthropic:claude-3-7-sonnet-latest",
+             tools=[check_weather],
+             system_prompt="You are a helpful assistant",
+         )
+         inputs = {"messages": [{"role": "user", "content": "what is the weather in sf"}]}
+         for chunk in graph.stream(inputs, stream_mode="updates"):
+             print(chunk)
+         ```
+     """
    # init chat model
    if isinstance(model, str):
        model = init_chat_model(model)
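With the return type now `CompiledStateGraph`, persistence and debugging options are passed straight into `create_agent` rather than to a separate `.compile()` call. A minimal sketch of the new calling convention, assuming LangGraph's `InMemorySaver` checkpointer and its standard `thread_id` config key; `check_weather` is the toy tool from the docstring example:

```python
from langgraph.checkpoint.memory import InMemorySaver

from langchain.agents import create_agent


def check_weather(location: str) -> str:
    """Return the weather forecast for the specified location."""
    return f"It's always sunny in {location}"


# create_agent now returns an already-compiled graph; no .compile() needed.
agent = create_agent(
    model="anthropic:claude-3-7-sonnet-latest",
    tools=[check_weather],
    checkpointer=InMemorySaver(),  # persists chat state per thread
)

# thread_id scopes the checkpointer's saved state to one conversation.
config = {"configurable": {"thread_id": "conversation-1"}}
agent.invoke(
    {"messages": [{"role": "user", "content": "what is the weather in sf?"}]},
    config,
)
```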
@@ -230,7 +322,7 @@ def create_agent( # noqa: PLR0915
    # Setup tools
    tool_node: ToolNode | None = None
    if isinstance(tools, list):
-         # Extract built-in provider tools (dict format) and regular tools (BaseTool)
+         # Extract built-in provider tools (dict format) and regular tools (BaseTool/callables)
        built_in_tools = [t for t in tools if isinstance(t, dict)]
        regular_tools = [t for t in tools if not isinstance(t, dict)]

@@ -241,9 +333,13 @@ def create_agent( # noqa: PLR0915
        tool_node = ToolNode(tools=available_tools) if available_tools else None

        # Default tools for ModelRequest initialization
-         # Include built-ins and regular tools (can be changed dynamically by middleware)
+         # Use converted BaseTool instances from ToolNode (not raw callables)
+         # Include built-ins and converted tools (can be changed dynamically by middleware)
        # Structured tools are NOT included - they're added dynamically based on response_format
-         default_tools = regular_tools + middleware_tools + built_in_tools
+         if tool_node:
+             default_tools = list(tool_node.tools_by_name.values()) + built_in_tools
+         else:
+             default_tools = list(built_in_tools)
    elif isinstance(tools, ToolNode):
        tool_node = tools
        if tool_node:
@@ -252,10 +348,15 @@ def create_agent( # noqa: PLR0915
            tool_node = ToolNode(available_tools)

            # default_tools includes all client-side tools (no built-ins or structured tools)
-             default_tools = available_tools
+             default_tools = list(tool_node.tools_by_name.values())
+         else:
+             default_tools = []
+     # No tools provided, only middleware_tools available
+     elif middleware_tools:
+         tool_node = ToolNode(middleware_tools)
+         default_tools = list(tool_node.tools_by_name.values())
    else:
-         # No tools provided, only middleware_tools available
-         default_tools = middleware_tools
+         default_tools = []

    # validate middleware
    assert len({m.name for m in middleware}) == len(middleware), (  # noqa: S101
@@ -425,16 +526,19 @@ def _get_bound_model(request: ModelRequest) -> tuple[Runnable, ResponseFormat |
            is the actual strategy used (may differ from initial if auto-detected).
        """
        # Validate ONLY client-side tools that need to exist in tool_node
-         # Build map of available client-side tools (regular_tools + middleware_tools)
-         available_tools_by_name = {t.name: t for t in default_tools if isinstance(t, BaseTool)}
+         # Build map of available client-side tools from the ToolNode
+         # (which has already converted callables)
+         available_tools_by_name = {}
+         if tool_node:
+             available_tools_by_name = tool_node.tools_by_name.copy()

        # Check if any requested tools are unknown CLIENT-SIDE tools
        unknown_tool_names = []
        for t in request.tools:
            # Only validate BaseTool instances (skip built-in dict tools)
            if isinstance(t, dict):
                continue
-             if t.name not in available_tools_by_name:
+             if isinstance(t, BaseTool) and t.name not in available_tools_by_name:
                unknown_tool_names.append(t.name)

        if unknown_tool_names:
@@ -468,7 +572,8 @@ def _get_bound_model(request: ModelRequest) -> tuple[Runnable, ResponseFormat |
        effective_response_format = request.response_format

        # Build final tools list including structured output tools
-         # request.tools already contains both BaseTool and dict (built-in) tools
+         # request.tools now only contains BaseTool instances (converted from callables)
+         # and dicts (built-ins)
        final_tools = list(request.tools)
        if isinstance(effective_response_format, ToolStrategy):
            # Add structured output tools to final tools list
@@ -767,6 +872,12 @@ async def amodel_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[
            ),
            [loop_entry_node, "tools", exit_node],
        )
+     elif len(structured_output_tools) > 0:
+         graph.add_conditional_edges(
+             loop_exit_node,
+             _make_model_to_model_edge(loop_entry_node, exit_node),
+             [loop_entry_node, exit_node],
+         )
    elif loop_exit_node == "model_request":
        # If no tools and no after_model, go directly to exit_node
        graph.add_edge(loop_exit_node, exit_node)
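The new `elif` branch covers the tool-free structured-output case: the model node loops back to itself until a `structured_response` lands in state. For readers unfamiliar with `add_conditional_edges`, here is a self-contained sketch of the same pattern on a toy graph (node names and state keys are illustrative, not the agent's real ones):

```python
from typing import TypedDict

from langgraph.graph import END, START, StateGraph


class State(TypedDict):
    attempts: int
    done: bool


def model(state: State) -> dict:
    # Stand-in for the model node: pretend structured output
    # succeeds on the second attempt.
    attempts = state["attempts"] + 1
    return {"attempts": attempts, "done": attempts >= 2}


def route(state: State) -> str:
    # Mirrors _make_model_to_model_edge: loop back until a terminal
    # condition appears in the state, then exit.
    return END if state["done"] else "model"


g = StateGraph(State)
g.add_node("model", model)
g.add_edge(START, "model")
# The list argument enumerates the possible destinations, as in the diff.
g.add_conditional_edges("model", route, ["model", END])
print(g.compile().invoke({"attempts": 0, "done": False}))
# -> {'attempts': 2, 'done': True}
```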
@@ -857,7 +968,15 @@ async def amodel_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[
            can_jump_to=_get_can_jump_to(middleware_w_after_agent[0], "after_agent"),
        )

-     return graph
+     return graph.compile(
+         checkpointer=checkpointer,
+         store=store,
+         interrupt_before=interrupt_before,
+         interrupt_after=interrupt_after,
+         debug=debug,
+         name=name,
+         cache=cache,
+     )


def _resolve_jump(jump_to: JumpTo | None, first_node: str) -> str | None:
@@ -891,8 +1010,10 @@ def _make_model_to_tools_edge(
    structured_output_tools: dict[str, OutputToolBinding],
    tool_node: ToolNode,
    exit_node: str,
- ) -> Callable[[dict[str, Any]], str | list[Send] | None]:
-     def model_to_tools(state: dict[str, Any]) -> str | list[Send] | None:
+ ) -> Callable[[dict[str, Any], Runtime[ContextT]], str | list[Send] | None]:
+     def model_to_tools(
+         state: dict[str, Any], runtime: Runtime[ContextT]
+     ) -> str | list[Send] | None:
        # 1. if there's an explicit jump_to in the state, use it
        if jump_to := state.get("jump_to"):
            return _resolve_jump(jump_to, first_node)
@@ -914,36 +1035,69 @@ def model_to_tools(state: dict[str, Any]) -> str | list[Send] | None:
        # 3. if there are pending tool calls, jump to the tool node
        if pending_tool_calls:
            pending_tool_calls = [
-                 tool_node.inject_tool_args(call, state, None) for call in pending_tool_calls
+                 tool_node.inject_tool_args(call, state, runtime.store)
+                 for call in pending_tool_calls
            ]
            return [Send("tools", [tool_call]) for tool_call in pending_tool_calls]

-         # 4. AIMessage has tool calls, but there are no pending tool calls
+         # 4. if there is a structured response, exit the loop
+         if "structured_response" in state:
+             return exit_node
+
+         # 5. AIMessage has tool calls, but there are no pending tool calls
        # which suggests the injection of artificial tool messages. jump to the first node
        return first_node

    return model_to_tools


+ def _make_model_to_model_edge(
+     first_node: str,
+     exit_node: str,
+ ) -> Callable[[dict[str, Any], Runtime[ContextT]], str | list[Send] | None]:
+     def model_to_model(
+         state: dict[str, Any],
+         runtime: Runtime[ContextT],  # noqa: ARG001
+     ) -> str | list[Send] | None:
+         # 1. Priority: Check for explicit jump_to directive from middleware
+         if jump_to := state.get("jump_to"):
+             return _resolve_jump(jump_to, first_node)
+
+         # 2. Exit condition: A structured response was generated
+         if "structured_response" in state:
+             return exit_node
+
+         # 3. Default: Continue the loop; there may have been an issue
+         # with structured output generation, so we need to retry
+         return first_node
+
+     return model_to_model
+
+
def _make_tools_to_model_edge(
    tool_node: ToolNode,
    next_node: str,
    structured_output_tools: dict[str, OutputToolBinding],
    exit_node: str,
- ) -> Callable[[dict[str, Any]], str | None]:
-     def tools_to_model(state: dict[str, Any]) -> str | None:
+ ) -> Callable[[dict[str, Any], Runtime[ContextT]], str | None]:
+     def tools_to_model(state: dict[str, Any], runtime: Runtime[ContextT]) -> str | None:  # noqa: ARG001
        last_ai_message, tool_messages = _fetch_last_ai_and_tool_messages(state["messages"])

+         # 1. Exit condition: All executed tools have return_direct=True
        if all(
            tool_node.tools_by_name[c["name"]].return_direct
            for c in last_ai_message.tool_calls
            if c["name"] in tool_node.tools_by_name
        ):
            return exit_node

+         # 2. Exit condition: A structured output tool was executed
        if any(t.name in structured_output_tools for t in tool_messages):
            return exit_node

+         # 3. Default: Continue the loop
+         # Tool execution completed successfully; route back to the model
+         # so it can process the tool results and decide the next action.
        return next_node

    return tools_to_model
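Note the fan-out in `model_to_tools`: returning a list of `Send` objects makes LangGraph run the `"tools"` node once per tool call, in parallel, with each invocation seeing only its own payload. A minimal sketch of that mechanic on a toy graph (names and state keys are illustrative):

```python
import operator
from typing import Annotated, TypedDict

from langgraph.graph import END, START, StateGraph
from langgraph.types import Send


class State(TypedDict):
    items: list[int]
    results: Annotated[list[int], operator.add]  # reducer merges parallel writes


def fan_out(state: State) -> list[Send]:
    # One Send per item, like one Send per pending tool call above.
    return [Send("worker", {"items": [i], "results": []}) for i in state["items"]]


def worker(state: State) -> dict:
    # Each parallel invocation sees only the payload of its Send.
    return {"results": [state["items"][0] * 2]}


g = StateGraph(State)
g.add_node("worker", worker)
g.add_conditional_edges(START, fan_out, ["worker"])
g.add_edge("worker", END)
print(g.compile().invoke({"items": [1, 2, 3], "results": []}))
# -> {'items': [1, 2, 3], 'results': [2, 4, 6]}
```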
@@ -960,7 +1114,6 @@ def _add_middleware_edge(

    Args:
        graph: The graph to add the edge to.
-         method: The method to call for the middleware node.
        name: The name of the middleware node.
        default_destination: The default destination for the edge.
        model_destination: The destination for the edge to the model.
@@ -984,3 +1137,8 @@ def jump_edge(state: dict[str, Any]) -> str:

    else:
        graph.add_edge(name, default_destination)
+
+
+ __all__ = [
+     "create_agent",
+ ]