1
- from langchain_core .messages import AIMessage
1
+ import json
2
+
3
+ from langchain_core .messages import AIMessage , ToolMessage
2
4
from langgraph .graph import END , START , StateGraph
3
5
4
6
from template_langgraph .agents .basic_workflow_agent .models import AgentInput , AgentOutput , AgentState , Profile
5
7
from template_langgraph .llms .azure_openais import AzureOpenAiWrapper
6
8
from template_langgraph .loggers import get_logger
9
+ from template_langgraph .tools .elasticsearch_tool import search_elasticsearch
10
+ from template_langgraph .tools .qdrants import search_qdrant
7
11
8
12
logger = get_logger (__name__ )
9
13
10
14
15
+ class BasicToolNode :
16
+ """A node that runs the tools requested in the last AIMessage."""
17
+
18
+ def __init__ (self , tools : list ) -> None :
19
+ self .tools_by_name = {tool .name : tool for tool in tools }
20
+
21
+ def __call__ (self , inputs : dict ):
22
+ if messages := inputs .get ("messages" , []):
23
+ message = messages [- 1 ]
24
+ else :
25
+ raise ValueError ("No message found in input" )
26
+ outputs = []
27
+ for tool_call in message .tool_calls :
28
+ tool_result = self .tools_by_name [tool_call ["name" ]].invoke (tool_call ["args" ])
29
+ outputs .append (
30
+ ToolMessage (
31
+ content = json .dumps (tool_result .__str__ (), ensure_ascii = False ),
32
+ name = tool_call ["name" ],
33
+ tool_call_id = tool_call ["id" ],
34
+ )
35
+ )
36
+ return {"messages" : outputs }
37
+
38
+
11
39
class BasicWorkflowAgent :
12
40
def __init__ (self ):
13
41
self .llm = AzureOpenAiWrapper ().chat_model
@@ -21,13 +49,34 @@ def create_graph(self):
21
49
workflow .add_node ("initialize" , self .initialize )
22
50
workflow .add_node ("do_something" , self .do_something )
23
51
workflow .add_node ("extract_profile" , self .extract_profile )
52
+ workflow .add_node ("chat_with_tools" , self .chat_with_tools )
53
+ workflow .add_node (
54
+ "tools" ,
55
+ BasicToolNode (
56
+ tools = [
57
+ search_qdrant ,
58
+ search_elasticsearch ,
59
+ ]
60
+ ),
61
+ )
24
62
workflow .add_node ("finalize" , self .finalize )
25
63
26
64
# Create edges
27
65
workflow .add_edge (START , "initialize" )
28
66
workflow .add_edge ("initialize" , "do_something" )
29
67
workflow .add_edge ("do_something" , "extract_profile" )
30
- workflow .add_edge ("extract_profile" , "finalize" )
68
+ workflow .add_edge ("extract_profile" , "chat_with_tools" )
69
+ workflow .add_conditional_edges (
70
+ "chat_with_tools" ,
71
+ self .route_tools ,
72
+ # The following dictionary lets you tell the graph to interpret the condition's outputs as a specific node
73
+ # It defaults to the identity function, but if you
74
+ # want to use a node named something else apart from "tools",
75
+ # You can update the value of the dictionary to something else
76
+ # e.g., "tools": "my_tools"
77
+ {"tools" : "tools" , END : "finalize" },
78
+ )
79
+ workflow .add_edge ("tools" , "chat_with_tools" )
31
80
workflow .add_edge ("finalize" , END )
32
81
33
82
# Compile the graph
@@ -66,6 +115,39 @@ def extract_profile(self, state: AgentState) -> AgentState:
66
115
state ["profile" ] = profile
67
116
return state
68
117
118
+ def chat_with_tools (self , state : AgentState ) -> AgentState :
119
+ """Chat with tools using the state."""
120
+ logger .info (f"Chatting with tools using state: { state } " )
121
+ llm_with_tools = self .llm .bind_tools (
122
+ tools = [
123
+ search_qdrant ,
124
+ search_elasticsearch ,
125
+ ],
126
+ )
127
+ return {
128
+ "messages" : [
129
+ llm_with_tools .invoke (state ["messages" ]),
130
+ ]
131
+ }
132
+
133
+ def route_tools (
134
+ self ,
135
+ state : AgentState ,
136
+ ):
137
+ """
138
+ Use in the conditional_edge to route to the ToolNode if the last message
139
+ has tool calls. Otherwise, route to the end.
140
+ """
141
+ if isinstance (state , list ):
142
+ ai_message = state [- 1 ]
143
+ elif messages := state .get ("messages" , []):
144
+ ai_message = messages [- 1 ]
145
+ else :
146
+ raise ValueError (f"No messages found in input state to tool_edge: { state } " )
147
+ if hasattr (ai_message , "tool_calls" ) and len (ai_message .tool_calls ) > 0 :
148
+ return "tools"
149
+ return END
150
+
69
151
def finalize (self , state : AgentState ) -> AgentState :
70
152
"""Finalize the agent's work and prepare the output."""
71
153
logger .info (f"Finalizing BasicWorkflowAgent with state: { state } " )
0 commit comments