Skip to content

Commit f125f1e

Browse files
authored
Merge pull request #108 from ks6088ts-labs/feature/issue-107_generic-agent-runner
add generic agent runner
2 parents 31e644d + e729973 commit f125f1e

File tree

1 file changed

+131
-0
lines changed

1 file changed

+131
-0
lines changed
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
import streamlit as st
2+
from langchain_community.callbacks.streamlit import (
3+
StreamlitCallbackHandler,
4+
)
5+
from langgraph.graph.state import CompiledStateGraph
6+
7+
from template_langgraph.agents.demo_agents.parallel_rag_agent.agent import ParallelRagAgent
8+
from template_langgraph.agents.demo_agents.weather_agent import graph as weather_agent_graph
9+
from template_langgraph.llms.azure_openais import AzureOpenAiWrapper
10+
from template_langgraph.tools.common import get_default_tools
11+
12+
13+
# 追加: 各エージェントのグラフ生成ファクトリと tool call 対応フラグ
14+
def _make_parallel_rag_graph(tools):
15+
graph = ParallelRagAgent(
16+
llm=AzureOpenAiWrapper().chat_model,
17+
tools=tools,
18+
).create_graph()
19+
20+
# 追加: このエージェント用の入力ビルダー
21+
def build_input(prompt, callbacks):
22+
return {"query": prompt, "callbacks": callbacks}
23+
24+
return {"graph": graph, "build_input": build_input}
25+
26+
27+
def _make_weather_graph(_tools=None):
28+
# weather_agent_graph が CompiledStateGraph を指している前提
29+
graph = weather_agent_graph
30+
31+
# 追加: このエージェント用の入力ビルダー(必要に応じてキー名を調整)
32+
def build_input(prompt, callbacks):
33+
return {
34+
"messages": [
35+
prompt,
36+
],
37+
"callbacks": callbacks,
38+
}
39+
40+
return {"graph": graph, "build_input": build_input}
41+
42+
43+
agent_options = {
44+
"Parallel RAG Agent": {
45+
"supports_tools": True,
46+
"factory": _make_parallel_rag_graph,
47+
},
48+
"Weather Agent": {
49+
"supports_tools": False,
50+
"factory": _make_weather_graph,
51+
},
52+
# "Another Agent": {"supports_tools": True/False, "factory": your_factory},
53+
}
54+
55+
56+
def create_graph() -> CompiledStateGraph:
57+
# ...existing code...
58+
cfg = agent_options.get(selected_agent_key) or next(iter(agent_options.values()))
59+
supports_tools = cfg.get("supports_tools", True)
60+
factory = cfg["factory"]
61+
result = factory(selected_tools if supports_tools else None)
62+
# 追加: 入力ビルダーを保存(無ければデフォルトにフォールバック)
63+
st.session_state["input_builder"] = result.get("build_input") or (lambda p, cbs: {"query": p, "callbacks": cbs})
64+
return result["graph"]
65+
# ...existing code...
66+
67+
68+
# Sidebar: ツール選択とエージェントの構築
69+
with st.sidebar:
70+
# 追加: エージェント選択 UI
71+
st.subheader("使用するエージェント")
72+
available_agent_keys = list(agent_options.keys())
73+
if "selected_agent_key" not in st.session_state:
74+
st.session_state["selected_agent_key"] = available_agent_keys[0]
75+
selected_agent_key = st.selectbox(
76+
"実行するエージェントを選択",
77+
options=available_agent_keys,
78+
index=available_agent_keys.index(st.session_state["selected_agent_key"]),
79+
)
80+
st.session_state["selected_agent_key"] = selected_agent_key
81+
82+
# エージェントの tool call 対応フラグを取得
83+
supports_tools = agent_options[selected_agent_key].get("supports_tools", True)
84+
85+
# ツール選択 UI(supports_tools が True の時のみ表示)
86+
if supports_tools:
87+
st.subheader("使用するツール")
88+
# 利用可能なツール一覧を取得
89+
available_tools = get_default_tools()
90+
tool_name_to_obj = {t.name: t for t in available_tools}
91+
tool_names = list(tool_name_to_obj.keys())
92+
93+
# 初期選択は全選択
94+
if "selected_tool_names" not in st.session_state:
95+
st.session_state["selected_tool_names"] = tool_names
96+
97+
selected_tool_names = st.multiselect(
98+
"有効化するツールを選択",
99+
options=tool_names,
100+
default=st.session_state["selected_tool_names"],
101+
)
102+
st.session_state["selected_tool_names"] = selected_tool_names
103+
selected_tools = [tool_name_to_obj[name] for name in selected_tool_names]
104+
signature = (selected_agent_key, tuple(selected_tool_names))
105+
else:
106+
# 非対応時はツール選択をスキップ
107+
selected_tool_names = []
108+
selected_tools = []
109+
signature = (selected_agent_key,)
110+
111+
# 選択に応じてグラフを再構築
112+
if "graph" not in st.session_state or st.session_state.get("graph_signature") != signature:
113+
st.session_state["graph"] = create_graph()
114+
st.session_state["graph_signature"] = signature
115+
116+
# 選択中の表示
117+
st.caption(f"選択中のエージェント: {selected_agent_key}")
118+
if supports_tools:
119+
st.caption("選択中のツール: " + (", ".join(selected_tool_names) if selected_tool_names else "なし"))
120+
else:
121+
st.caption("このエージェントはツール呼び出しをサポートしていません")
122+
123+
if prompt := st.chat_input():
124+
st.chat_message("user").write(prompt)
125+
with st.chat_message("assistant"):
126+
with st.spinner("処理中..."):
127+
# 変更: エージェントごとの入力ビルダーを使用
128+
callbacks = [StreamlitCallbackHandler(st.container())]
129+
input_builder = st.session_state.get("input_builder") or (lambda p, cbs: {"query": p, "callbacks": cbs})
130+
response = st.session_state["graph"].invoke(input=input_builder(prompt, callbacks))
131+
st.write(response)

0 commit comments

Comments
 (0)