|
| 1 | +import streamlit as st |
| 2 | +from langchain_community.callbacks.streamlit import ( |
| 3 | + StreamlitCallbackHandler, |
| 4 | +) |
| 5 | +from langgraph.graph.state import CompiledStateGraph |
| 6 | + |
| 7 | +from template_langgraph.agents.demo_agents.parallel_rag_agent.agent import ParallelRagAgent |
| 8 | +from template_langgraph.agents.demo_agents.weather_agent import graph as weather_agent_graph |
| 9 | +from template_langgraph.llms.azure_openais import AzureOpenAiWrapper |
| 10 | +from template_langgraph.tools.common import get_default_tools |
| 11 | + |
| 12 | + |
| 13 | +# 追加: 各エージェントのグラフ生成ファクトリと tool call 対応フラグ |
| 14 | +def _make_parallel_rag_graph(tools): |
| 15 | + graph = ParallelRagAgent( |
| 16 | + llm=AzureOpenAiWrapper().chat_model, |
| 17 | + tools=tools, |
| 18 | + ).create_graph() |
| 19 | + |
| 20 | + # 追加: このエージェント用の入力ビルダー |
| 21 | + def build_input(prompt, callbacks): |
| 22 | + return {"query": prompt, "callbacks": callbacks} |
| 23 | + |
| 24 | + return {"graph": graph, "build_input": build_input} |
| 25 | + |
| 26 | + |
| 27 | +def _make_weather_graph(_tools=None): |
| 28 | + # weather_agent_graph が CompiledStateGraph を指している前提 |
| 29 | + graph = weather_agent_graph |
| 30 | + |
| 31 | + # 追加: このエージェント用の入力ビルダー(必要に応じてキー名を調整) |
| 32 | + def build_input(prompt, callbacks): |
| 33 | + return { |
| 34 | + "messages": [ |
| 35 | + prompt, |
| 36 | + ], |
| 37 | + "callbacks": callbacks, |
| 38 | + } |
| 39 | + |
| 40 | + return {"graph": graph, "build_input": build_input} |
| 41 | + |
| 42 | + |
| 43 | +agent_options = { |
| 44 | + "Parallel RAG Agent": { |
| 45 | + "supports_tools": True, |
| 46 | + "factory": _make_parallel_rag_graph, |
| 47 | + }, |
| 48 | + "Weather Agent": { |
| 49 | + "supports_tools": False, |
| 50 | + "factory": _make_weather_graph, |
| 51 | + }, |
| 52 | + # "Another Agent": {"supports_tools": True/False, "factory": your_factory}, |
| 53 | +} |
| 54 | + |
| 55 | + |
| 56 | +def create_graph() -> CompiledStateGraph: |
| 57 | + # ...existing code... |
| 58 | + cfg = agent_options.get(selected_agent_key) or next(iter(agent_options.values())) |
| 59 | + supports_tools = cfg.get("supports_tools", True) |
| 60 | + factory = cfg["factory"] |
| 61 | + result = factory(selected_tools if supports_tools else None) |
| 62 | + # 追加: 入力ビルダーを保存(無ければデフォルトにフォールバック) |
| 63 | + st.session_state["input_builder"] = result.get("build_input") or (lambda p, cbs: {"query": p, "callbacks": cbs}) |
| 64 | + return result["graph"] |
| 65 | + # ...existing code... |
| 66 | + |
| 67 | + |
| 68 | +# Sidebar: ツール選択とエージェントの構築 |
| 69 | +with st.sidebar: |
| 70 | + # 追加: エージェント選択 UI |
| 71 | + st.subheader("使用するエージェント") |
| 72 | + available_agent_keys = list(agent_options.keys()) |
| 73 | + if "selected_agent_key" not in st.session_state: |
| 74 | + st.session_state["selected_agent_key"] = available_agent_keys[0] |
| 75 | + selected_agent_key = st.selectbox( |
| 76 | + "実行するエージェントを選択", |
| 77 | + options=available_agent_keys, |
| 78 | + index=available_agent_keys.index(st.session_state["selected_agent_key"]), |
| 79 | + ) |
| 80 | + st.session_state["selected_agent_key"] = selected_agent_key |
| 81 | + |
| 82 | + # エージェントの tool call 対応フラグを取得 |
| 83 | + supports_tools = agent_options[selected_agent_key].get("supports_tools", True) |
| 84 | + |
| 85 | + # ツール選択 UI(supports_tools が True の時のみ表示) |
| 86 | + if supports_tools: |
| 87 | + st.subheader("使用するツール") |
| 88 | + # 利用可能なツール一覧を取得 |
| 89 | + available_tools = get_default_tools() |
| 90 | + tool_name_to_obj = {t.name: t for t in available_tools} |
| 91 | + tool_names = list(tool_name_to_obj.keys()) |
| 92 | + |
| 93 | + # 初期選択は全選択 |
| 94 | + if "selected_tool_names" not in st.session_state: |
| 95 | + st.session_state["selected_tool_names"] = tool_names |
| 96 | + |
| 97 | + selected_tool_names = st.multiselect( |
| 98 | + "有効化するツールを選択", |
| 99 | + options=tool_names, |
| 100 | + default=st.session_state["selected_tool_names"], |
| 101 | + ) |
| 102 | + st.session_state["selected_tool_names"] = selected_tool_names |
| 103 | + selected_tools = [tool_name_to_obj[name] for name in selected_tool_names] |
| 104 | + signature = (selected_agent_key, tuple(selected_tool_names)) |
| 105 | + else: |
| 106 | + # 非対応時はツール選択をスキップ |
| 107 | + selected_tool_names = [] |
| 108 | + selected_tools = [] |
| 109 | + signature = (selected_agent_key,) |
| 110 | + |
| 111 | + # 選択に応じてグラフを再構築 |
| 112 | + if "graph" not in st.session_state or st.session_state.get("graph_signature") != signature: |
| 113 | + st.session_state["graph"] = create_graph() |
| 114 | + st.session_state["graph_signature"] = signature |
| 115 | + |
| 116 | + # 選択中の表示 |
| 117 | + st.caption(f"選択中のエージェント: {selected_agent_key}") |
| 118 | + if supports_tools: |
| 119 | + st.caption("選択中のツール: " + (", ".join(selected_tool_names) if selected_tool_names else "なし")) |
| 120 | + else: |
| 121 | + st.caption("このエージェントはツール呼び出しをサポートしていません") |
| 122 | + |
| 123 | +if prompt := st.chat_input(): |
| 124 | + st.chat_message("user").write(prompt) |
| 125 | + with st.chat_message("assistant"): |
| 126 | + with st.spinner("処理中..."): |
| 127 | + # 変更: エージェントごとの入力ビルダーを使用 |
| 128 | + callbacks = [StreamlitCallbackHandler(st.container())] |
| 129 | + input_builder = st.session_state.get("input_builder") or (lambda p, cbs: {"query": p, "callbacks": cbs}) |
| 130 | + response = st.session_state["graph"].invoke(input=input_builder(prompt, callbacks)) |
| 131 | + st.write(response) |
0 commit comments