131 changes: 131 additions & 0 deletions template_langgraph/services/streamlits/pages/generic_agent_runner.py
@@ -0,0 +1,131 @@
import streamlit as st
from langchain_community.callbacks.streamlit import (
StreamlitCallbackHandler,
)
from langgraph.graph.state import CompiledStateGraph

from template_langgraph.agents.demo_agents.parallel_rag_agent.agent import ParallelRagAgent
from template_langgraph.agents.demo_agents.weather_agent import graph as weather_agent_graph
from template_langgraph.llms.azure_openais import AzureOpenAiWrapper
from template_langgraph.tools.common import get_default_tools


# Added: per-agent graph factories and tool-call support flags.
# Each factory returns a dict holding the compiled graph and an input builder for that agent.
def _make_parallel_rag_graph(tools):
graph = ParallelRagAgent(
llm=AzureOpenAiWrapper().chat_model,
tools=tools,
).create_graph()

    # Added: input builder for this agent
def build_input(prompt, callbacks):
return {"query": prompt, "callbacks": callbacks}

return {"graph": graph, "build_input": build_input}


def _make_weather_graph(_tools=None):
    # Assumes weather_agent_graph refers to a CompiledStateGraph
graph = weather_agent_graph

    # Added: input builder for this agent (adjust the key names as needed)
def build_input(prompt, callbacks):
return {
"messages": [
prompt,
],
"callbacks": callbacks,
}

return {"graph": graph, "build_input": build_input}


agent_options = {
"Parallel RAG Agent": {
"supports_tools": True,
"factory": _make_parallel_rag_graph,
},
"Weather Agent": {
"supports_tools": False,
"factory": _make_weather_graph,
},
# "Another Agent": {"supports_tools": True/False, "factory": your_factory},
}


def create_graph() -> CompiledStateGraph:
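    # Looks up the factory for the agent selected in the sidebar (falling back to the
    # first registered agent), builds its graph, and caches the matching input builder
    # in st.session_state for the chat handler below.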
# ...existing code...
cfg = agent_options.get(selected_agent_key) or next(iter(agent_options.values()))
supports_tools = cfg.get("supports_tools", True)
factory = cfg["factory"]
result = factory(selected_tools if supports_tools else None)
    # Added: store the input builder (fall back to a default if it is missing)
st.session_state["input_builder"] = result.get("build_input") or (lambda p, cbs: {"query": p, "callbacks": cbs})
return result["graph"]
# ...existing code...


# Sidebar: tool selection and agent construction
with st.sidebar:
    # Added: agent selection UI
    st.subheader("Agent to use")
available_agent_keys = list(agent_options.keys())
if "selected_agent_key" not in st.session_state:
st.session_state["selected_agent_key"] = available_agent_keys[0]
selected_agent_key = st.selectbox(
"実行するエージェントを選択",
options=available_agent_keys,
index=available_agent_keys.index(st.session_state["selected_agent_key"]),
)
st.session_state["selected_agent_key"] = selected_agent_key

    # Get the agent's tool-call support flag
supports_tools = agent_options[selected_agent_key].get("supports_tools", True)

    # Tool selection UI (shown only when supports_tools is True)
if supports_tools:
st.subheader("使用するツール")
# 利用可能なツール一覧を取得
available_tools = get_default_tools()
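        # Map tool names to their objects so the multiselect can work with display names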
tool_name_to_obj = {t.name: t for t in available_tools}
tool_names = list(tool_name_to_obj.keys())

        # Select all tools by default
if "selected_tool_names" not in st.session_state:
st.session_state["selected_tool_names"] = tool_names

selected_tool_names = st.multiselect(
"有効化するツールを選択",
options=tool_names,
default=st.session_state["selected_tool_names"],
)
st.session_state["selected_tool_names"] = selected_tool_names
selected_tools = [tool_name_to_obj[name] for name in selected_tool_names]
signature = (selected_agent_key, tuple(selected_tool_names))
else:
        # Skip tool selection when the agent does not support tool calls
selected_tool_names = []
selected_tools = []
signature = (selected_agent_key,)

    # Rebuild the graph only when the selected agent or tools change (keyed by the signature)
if "graph" not in st.session_state or st.session_state.get("graph_signature") != signature:
st.session_state["graph"] = create_graph()
st.session_state["graph_signature"] = signature

    # Show the current selection
    st.caption(f"Selected agent: {selected_agent_key}")
    if supports_tools:
        st.caption("Selected tools: " + (", ".join(selected_tool_names) if selected_tool_names else "none"))
    else:
        st.caption("This agent does not support tool calls")

if prompt := st.chat_input():
st.chat_message("user").write(prompt)
with st.chat_message("assistant"):
        with st.spinner("Processing..."):
            # Changed: use the per-agent input builder
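            # StreamlitCallbackHandler streams the agent's intermediate steps into this container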
callbacks = [StreamlitCallbackHandler(st.container())]
input_builder = st.session_state.get("input_builder") or (lambda p, cbs: {"query": p, "callbacks": cbs})
response = st.session_state["graph"].invoke(input=input_builder(prompt, callbacks))
st.write(response)
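

# --- Illustrative sketch (not part of this diff): registering another agent ---
# Assumes a hypothetical module exposes `my_agent_graph` as a CompiledStateGraph.
# The factory follows the same contract as the ones above: return the compiled
# graph together with an input builder shaped for that agent's state.
#
# def _make_my_agent_graph(_tools=None):
#     def build_input(prompt, callbacks):
#         return {"messages": [prompt], "callbacks": callbacks}
#
#     return {"graph": my_agent_graph, "build_input": build_input}
#
# agent_options["My Agent"] = {"supports_tools": False, "factory": _make_my_agent_graph}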