Skip to content

Commit 120d772

Browse files
committed
cosmetic changes
1 parent f125f1e commit 120d772

File tree

1 file changed

+23
-24
lines changed

1 file changed

+23
-24
lines changed

template_langgraph/services/streamlits/pages/generic_agent_runner.py

Lines changed: 23 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -10,34 +10,33 @@
1010
from template_langgraph.tools.common import get_default_tools
1111

1212

13-
# 追加: 各エージェントのグラフ生成ファクトリと tool call 対応フラグ
1413
def _make_parallel_rag_graph(tools):
15-
graph = ParallelRagAgent(
16-
llm=AzureOpenAiWrapper().chat_model,
17-
tools=tools,
18-
).create_graph()
19-
20-
# 追加: このエージェント用の入力ビルダー
21-
def build_input(prompt, callbacks):
22-
return {"query": prompt, "callbacks": callbacks}
14+
def build_input(prompt):
15+
return {
16+
"query": prompt,
17+
}
2318

24-
return {"graph": graph, "build_input": build_input}
19+
return {
20+
"graph": ParallelRagAgent(
21+
llm=AzureOpenAiWrapper().chat_model,
22+
tools=tools,
23+
).create_graph(),
24+
"build_input": build_input,
25+
}
2526

2627

2728
def _make_weather_graph(_tools=None):
28-
# weather_agent_graph が CompiledStateGraph を指している前提
29-
graph = weather_agent_graph
30-
31-
# 追加: このエージェント用の入力ビルダー(必要に応じてキー名を調整)
32-
def build_input(prompt, callbacks):
29+
def build_input(prompt):
3330
return {
3431
"messages": [
3532
prompt,
3633
],
37-
"callbacks": callbacks,
3834
}
3935

40-
return {"graph": graph, "build_input": build_input}
36+
return {
37+
"graph": weather_agent_graph,
38+
"build_input": build_input,
39+
}
4140

4241

4342
agent_options = {
@@ -54,15 +53,12 @@ def build_input(prompt, callbacks):
5453

5554

5655
def create_graph() -> CompiledStateGraph:
57-
# ...existing code...
5856
cfg = agent_options.get(selected_agent_key) or next(iter(agent_options.values()))
5957
supports_tools = cfg.get("supports_tools", True)
6058
factory = cfg["factory"]
6159
result = factory(selected_tools if supports_tools else None)
62-
# 追加: 入力ビルダーを保存(無ければデフォルトにフォールバック)
63-
st.session_state["input_builder"] = result.get("build_input") or (lambda p, cbs: {"query": p, "callbacks": cbs})
60+
st.session_state["input_builder"] = result.get("build_input") or (lambda p: {"query": p})
6461
return result["graph"]
65-
# ...existing code...
6662

6763

6864
# Sidebar: ツール選択とエージェントの構築
@@ -124,8 +120,11 @@ def create_graph() -> CompiledStateGraph:
124120
st.chat_message("user").write(prompt)
125121
with st.chat_message("assistant"):
126122
with st.spinner("処理中..."):
127-
# 変更: エージェントごとの入力ビルダーを使用
123+
# 変更: callbacks は config に渡す。input は入力のみ。
128124
callbacks = [StreamlitCallbackHandler(st.container())]
129-
input_builder = st.session_state.get("input_builder") or (lambda p, cbs: {"query": p, "callbacks": cbs})
130-
response = st.session_state["graph"].invoke(input=input_builder(prompt, callbacks))
125+
input_builder = st.session_state.get("input_builder") or (lambda p: {"query": p})
126+
response = st.session_state["graph"].invoke(
127+
input=input_builder(prompt),
128+
config={"callbacks": callbacks},
129+
)
131130
st.write(response)

0 commit comments

Comments
 (0)