66from typing import Any , List
77from typing_extensions import Literal
88from langchain_openai import ChatOpenAI
9- from langchain_core .messages import SystemMessage
9+ from langchain_core .messages import SystemMessage , BaseMessage
1010from langchain_core .runnables import RunnableConfig
1111from langchain .tools import tool
1212from langgraph .graph import StateGraph , END
@@ -39,11 +39,15 @@ def get_weather(location: str):
3939# print(f"Your tool logic here")
4040# return "Your tool response here."
4141
42- tools = [
42+ backend_tools = [
4343 get_weather
4444 # your_tool_here
4545]
4646
47+ # Extract tool names from backend_tools for comparison
48+ backend_tool_names = [tool .name for tool in backend_tools ]
49+
50+
4751async def chat_node (state : AgentState , config : RunnableConfig ) -> Command [Literal ["tool_node" , "__end__" ]]:
4852 """
4953 Standard chat node based on the ReAct design pattern. It handles:
@@ -63,7 +67,7 @@ async def chat_node(state: AgentState, config: RunnableConfig) -> Command[Litera
6367 model_with_tools = model .bind_tools (
6468 [
6569 * state .get ("tools" , []), # bind tools defined by ag-ui
66- get_weather ,
70+ * backend_tools ,
6771 # your_tool_here
6872 ],
6973
@@ -84,18 +88,41 @@ async def chat_node(state: AgentState, config: RunnableConfig) -> Command[Litera
8488 * state ["messages" ],
8589 ], config )
8690
91+ # only route to tool node if tool is not in the tools list
92+ if route_to_tool_node (response ):
93+ print ("routing to tool node" )
94+ return Command (
95+ goto = "tool_node" ,
96+ update = {
97+ "messages" : [response ],
98+ }
99+ )
100+
87101 # 5. We've handled all tool calls, so we can end the graph.
88102 return Command (
89103 goto = END ,
90104 update = {
91- "messages" : response
105+ "messages" : [ response ],
92106 }
93107 )
94108
109+ def route_to_tool_node (response : BaseMessage ):
110+ """
111+ Route to tool node if any tool call in the response matches a backend tool name.
112+ """
113+ tool_calls = getattr (response , "tool_calls" , None )
114+ if not tool_calls :
115+ return False
116+
117+ for tool_call in tool_calls :
118+ if tool_call .get ("name" ) in backend_tool_names :
119+ return True
120+ return False
121+
95122# Define the workflow graph
96123workflow = StateGraph (AgentState )
97124workflow .add_node ("chat_node" , chat_node )
98- workflow .add_node ("tool_node" , ToolNode (tools = tools ))
125+ workflow .add_node ("tool_node" , ToolNode (tools = backend_tools ))
99126workflow .add_edge ("tool_node" , "chat_node" )
100127workflow .set_entry_point ("chat_node" )
101128
0 commit comments