diff --git a/docs/index.md b/docs/index.md index 82b3f2b..e82f700 100644 --- a/docs/index.md +++ b/docs/index.md @@ -27,12 +27,18 @@ uv run python -m template_langgraph.tasks.run_kabuto_helpdesk_agent "KABUTOの # BasicWorkflowAgent uv run python -m template_langgraph.tasks.draw_basic_workflow_agent_mermaid_png "data/basic_workflow_agent.png" -uv run python -m template_langgraph.tasks.run_basic_workflow_agent "KABUTOの起動時に、画面全体が紫色に点滅し、システムがフリーズします。" -uv run python -m template_langgraph.tasks.run_basic_workflow_agent "私の名前はフグ田 サザエ。東京都世田谷区桜新町あさひが丘3丁目に住んでいる 24 歳の主婦です。夫のノリスケと子供のタラちゃんがいます。" +uv run python -m template_langgraph.tasks.run_basic_workflow_agent +# 私の名前はフグ田 サザエ。東京都世田谷区桜新町あさひが丘3丁目に住んでいる 24 歳の主婦です。夫のノリスケと子供のタラちゃんがいます +# KABUTOの起動時に、画面全体が紫色に点滅し、システムがフリーズします。KABUTO のマニュアルから、関連する情報を取得したり過去のシステムのトラブルシュート事例が蓄積されたデータベースから、関連する情報を取得して質問に答えてください +# 天狗のいたずら という現象について KABUTO のマニュアルから、関連する情報を取得したり過去のシステムのトラブルシュート事例が蓄積されたデータベースから、関連する情報を取得して質問に答えてください ``` ## References +### LangGraph + +- [Build a custom workflow](https://langchain-ai.github.io/langgraph/concepts/why-langgraph/) + ### Sample Codes - [「現場で活用するためのAIエージェント実践入門」リポジトリ](https://github.com/masamasa59/genai-agent-advanced-book) diff --git a/template_langgraph/agents/basic_workflow_agent/agent.py b/template_langgraph/agents/basic_workflow_agent/agent.py index cfa93b7..8ca90a2 100644 --- a/template_langgraph/agents/basic_workflow_agent/agent.py +++ b/template_langgraph/agents/basic_workflow_agent/agent.py @@ -1,13 +1,41 @@ -from langchain_core.messages import AIMessage +import json + +from langchain_core.messages import AIMessage, ToolMessage from langgraph.graph import END, START, StateGraph from template_langgraph.agents.basic_workflow_agent.models import AgentInput, AgentOutput, AgentState, Profile from template_langgraph.llms.azure_openais import AzureOpenAiWrapper from template_langgraph.loggers import get_logger +from template_langgraph.tools.elasticsearch_tool import search_elasticsearch +from template_langgraph.tools.qdrants import search_qdrant logger = get_logger(__name__) +class BasicToolNode: + """A node that runs the tools requested in the last AIMessage.""" + + def __init__(self, tools: list) -> None: + self.tools_by_name = {tool.name: tool for tool in tools} + + def __call__(self, inputs: dict): + if messages := inputs.get("messages", []): + message = messages[-1] + else: + raise ValueError("No message found in input") + outputs = [] + for tool_call in message.tool_calls: + tool_result = self.tools_by_name[tool_call["name"]].invoke(tool_call["args"]) + outputs.append( + ToolMessage( + content=json.dumps(tool_result.__str__(), ensure_ascii=False), + name=tool_call["name"], + tool_call_id=tool_call["id"], + ) + ) + return {"messages": outputs} + + class BasicWorkflowAgent: def __init__(self): self.llm = AzureOpenAiWrapper().chat_model @@ -21,13 +49,34 @@ def create_graph(self): workflow.add_node("initialize", self.initialize) workflow.add_node("do_something", self.do_something) workflow.add_node("extract_profile", self.extract_profile) + workflow.add_node("chat_with_tools", self.chat_with_tools) + workflow.add_node( + "tools", + BasicToolNode( + tools=[ + search_qdrant, + search_elasticsearch, + ] + ), + ) workflow.add_node("finalize", self.finalize) # Create edges workflow.add_edge(START, "initialize") workflow.add_edge("initialize", "do_something") workflow.add_edge("do_something", "extract_profile") - workflow.add_edge("extract_profile", "finalize") + workflow.add_edge("extract_profile", "chat_with_tools") + workflow.add_conditional_edges( + "chat_with_tools", + self.route_tools, + # The following dictionary lets you tell the graph to interpret the condition's outputs as a specific node + # It defaults to the identity function, but if you + # want to use a node named something else apart from "tools", + # You can update the value of the dictionary to something else + # e.g., "tools": "my_tools" + {"tools": "tools", END: "finalize"}, + ) + workflow.add_edge("tools", "chat_with_tools") workflow.add_edge("finalize", END) # Compile the graph @@ -66,6 +115,39 @@ def extract_profile(self, state: AgentState) -> AgentState: state["profile"] = profile return state + def chat_with_tools(self, state: AgentState) -> AgentState: + """Chat with tools using the state.""" + logger.info(f"Chatting with tools using state: {state}") + llm_with_tools = self.llm.bind_tools( + tools=[ + search_qdrant, + search_elasticsearch, + ], + ) + return { + "messages": [ + llm_with_tools.invoke(state["messages"]), + ] + } + + def route_tools( + self, + state: AgentState, + ): + """ + Use in the conditional_edge to route to the ToolNode if the last message + has tool calls. Otherwise, route to the end. + """ + if isinstance(state, list): + ai_message = state[-1] + elif messages := state.get("messages", []): + ai_message = messages[-1] + else: + raise ValueError(f"No messages found in input state to tool_edge: {state}") + if hasattr(ai_message, "tool_calls") and len(ai_message.tool_calls) > 0: + return "tools" + return END + def finalize(self, state: AgentState) -> AgentState: """Finalize the agent's work and prepare the output.""" logger.info(f"Finalizing BasicWorkflowAgent with state: {state}") diff --git a/template_langgraph/tasks/run_basic_workflow_agent.py b/template_langgraph/tasks/run_basic_workflow_agent.py index e67e091..b59e760 100644 --- a/template_langgraph/tasks/run_basic_workflow_agent.py +++ b/template_langgraph/tasks/run_basic_workflow_agent.py @@ -1,27 +1,33 @@ import logging -import sys -from template_langgraph.agents.basic_workflow_agent.agent import AgentInput, BasicWorkflowAgent +from template_langgraph.agents.basic_workflow_agent.agent import AgentState +from template_langgraph.agents.basic_workflow_agent.agent import graph as basic_workflow_agent_graph from template_langgraph.loggers import get_logger logger = get_logger(__name__) logger.setLevel(logging.INFO) -if __name__ == "__main__": - question = "「鬼灯」を実行すると、KABUTOが急に停止します。原因と対策を教えてください。" - if len(sys.argv) > 1: - # sys.argv[1] が最初の引数 - question = sys.argv[1] - # Agentのインスタンス化 - agent = BasicWorkflowAgent() +def stream_graph_updates( + state: AgentState, +) -> dict: + for event in basic_workflow_agent_graph.stream(input=state): + logger.info("-" * 20) + logger.info(f"Event: {event}") + return event - # AgentInputの作成 - agent_input = AgentInput( - request=question, - ) - # エージェントの実行 - logger.info(f"Running BasicWorkflowAgent with input: {agent_input.model_dump_json(indent=2)}") - agent_output = agent.run_agent(input=agent_input) - logger.info(f"Agent output: {agent_output.model_dump_json(indent=2)}") +if __name__ == "__main__": + user_input = input("User: ") + state = AgentState( + messages=[ + { + "role": "user", + "content": user_input, + } + ], + profile=None, + ) + last_event = stream_graph_updates(state) + for value in last_event.values(): + logger.info(f"Final state: {value['messages'][-1].content}") # noqa: E501