Skip to content

Commit 478e3cd

Browse files
committed
refactor chat_with_tools agent
1 parent 9324c9d commit 478e3cd

File tree

4 files changed

+21
-52
lines changed

4 files changed

+21
-52
lines changed

template_langgraph/agents/chat_with_tools_agent/agent.py

Lines changed: 7 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@ def create_graph(self):
4646
workflow = StateGraph(AgentState)
4747

4848
# Create nodes
49-
workflow.add_node("initialize", self.initialize)
5049
workflow.add_node("chat_with_tools", self.chat_with_tools)
5150
workflow.add_node(
5251
"tools",
@@ -57,35 +56,24 @@ def create_graph(self):
5756
]
5857
),
5958
)
60-
workflow.add_node("finalize", self.finalize)
6159

6260
# Create edges
63-
workflow.set_entry_point("initialize")
64-
workflow.add_edge("initialize", "chat_with_tools")
61+
workflow.set_entry_point("chat_with_tools")
6562
workflow.add_conditional_edges(
66-
"chat_with_tools",
67-
self.route_tools,
68-
# The following dictionary lets you tell the graph to interpret the condition's outputs as a specific node
69-
# It defaults to the identity function, but if you
70-
# want to use a node named something else apart from "tools",
71-
# You can update the value of the dictionary to something else
72-
# e.g., "tools": "my_tools"
73-
{"tools": "tools", END: "finalize"},
63+
source="chat_with_tools",
64+
path=self.route_tools,
65+
path_map={
66+
"tools": "tools",
67+
END: END,
68+
},
7469
)
7570
workflow.add_edge("tools", "chat_with_tools")
76-
workflow.add_edge("finalize", END)
7771

7872
# Compile the graph
7973
return workflow.compile(
8074
name=ChatWithToolsAgent.__name__,
8175
)
8276

83-
def initialize(self, state: AgentState) -> AgentState:
84-
"""Initialize the agent with the given state."""
85-
logger.info(f"Initializing ChatWithToolsAgent with state: {state}")
86-
# Here you can add any initialization logic if needed
87-
return state
88-
8977
def chat_with_tools(self, state: AgentState) -> AgentState:
9078
"""Chat with tools using the state."""
9179
logger.info(f"Chatting with tools using state: {state}")
@@ -119,15 +107,5 @@ def route_tools(
119107
return "tools"
120108
return END
121109

122-
def finalize(self, state: AgentState) -> AgentState:
123-
"""Finalize the agent's work and prepare the output."""
124-
logger.info(f"Finalizing ChatWithToolsAgent with state: {state}")
125-
# Here you can add any finalization logic if needed
126-
return state
127-
128-
def draw_mermaid_png(self) -> bytes:
129-
"""Draw the graph in Mermaid format."""
130-
return self.create_graph().get_graph().draw_mermaid_png()
131-
132110

133111
graph = ChatWithToolsAgent().create_graph()

template_langgraph/agents/chat_with_tools_agent/models.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,6 @@
88
BaseMessage,
99
)
1010
from langgraph.graph.message import add_messages
11-
from pydantic import BaseModel, Field
12-
13-
14-
class AgentInput(BaseModel):
15-
request: str = Field(..., description="ユーザーからのリクエスト")
16-
17-
18-
class AgentOutput(BaseModel):
19-
response: str = Field(..., description="エージェントの応答")
2011

2112

2213
class AgentState(TypedDict):
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import sys
22

3-
from template_langgraph.agents.chat_with_tools_agent.agent import ChatWithToolsAgent
3+
from template_langgraph.agents.chat_with_tools_agent.agent import graph
44

55
if __name__ == "__main__":
66
png_path = "data/chat_with_tools_agent.png"
77
if len(sys.argv) > 1:
88
png_path = sys.argv[1]
99

1010
with open(png_path, "wb") as f:
11-
f.write(ChatWithToolsAgent().draw_mermaid_png())
11+
f.write(graph.get_graph().draw_mermaid_png())

template_langgraph/tasks/run_chat_with_tools_agent.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,15 @@ def stream_graph_updates(
1919

2020
if __name__ == "__main__":
2121
user_input = input("User: ")
22-
state = AgentState(
23-
messages=[
24-
{
25-
"role": "user",
26-
"content": user_input,
27-
}
28-
],
29-
profile=None,
30-
)
31-
last_event = stream_graph_updates(state)
32-
for value in last_event.values():
33-
logger.info(f"Final state: {value['messages'][-1].content}") # noqa: E501
22+
for event in chat_with_tools_agent_graph.stream(
23+
input=AgentState(
24+
messages=[
25+
{
26+
"role": "user",
27+
"content": user_input,
28+
}
29+
],
30+
)
31+
):
32+
logger.info("-" * 20)
33+
logger.info(f"Event: {event}")

0 commit comments

Comments
 (0)