1010from template_langgraph .tools .common import get_default_tools
1111
1212
13- # 追加: 各エージェントのグラフ生成ファクトリと tool call 対応フラグ
1413def _make_parallel_rag_graph (tools ):
15- graph = ParallelRagAgent (
16- llm = AzureOpenAiWrapper ().chat_model ,
17- tools = tools ,
18- ).create_graph ()
19-
20- # 追加: このエージェント用の入力ビルダー
21- def build_input (prompt , callbacks ):
22- return {"query" : prompt , "callbacks" : callbacks }
14+ def build_input (prompt ):
15+ return {
16+ "query" : prompt ,
17+ }
2318
24- return {"graph" : graph , "build_input" : build_input }
19+ return {
20+ "graph" : ParallelRagAgent (
21+ llm = AzureOpenAiWrapper ().chat_model ,
22+ tools = tools ,
23+ ).create_graph (),
24+ "build_input" : build_input ,
25+ }
2526
2627
2728def _make_weather_graph (_tools = None ):
28- # weather_agent_graph が CompiledStateGraph を指している前提
29- graph = weather_agent_graph
30-
31- # 追加: このエージェント用の入力ビルダー(必要に応じてキー名を調整)
32- def build_input (prompt , callbacks ):
29+ def build_input (prompt ):
3330 return {
3431 "messages" : [
3532 prompt ,
3633 ],
37- "callbacks" : callbacks ,
3834 }
3935
40- return {"graph" : graph , "build_input" : build_input }
36+ return {
37+ "graph" : weather_agent_graph ,
38+ "build_input" : build_input ,
39+ }
4140
4241
4342agent_options = {
@@ -54,15 +53,12 @@ def build_input(prompt, callbacks):
5453
5554
5655def create_graph () -> CompiledStateGraph :
57- # ...existing code...
5856 cfg = agent_options .get (selected_agent_key ) or next (iter (agent_options .values ()))
5957 supports_tools = cfg .get ("supports_tools" , True )
6058 factory = cfg ["factory" ]
6159 result = factory (selected_tools if supports_tools else None )
62- # 追加: 入力ビルダーを保存(無ければデフォルトにフォールバック)
63- st .session_state ["input_builder" ] = result .get ("build_input" ) or (lambda p , cbs : {"query" : p , "callbacks" : cbs })
60+ st .session_state ["input_builder" ] = result .get ("build_input" ) or (lambda p : {"query" : p })
6461 return result ["graph" ]
65- # ...existing code...
6662
6763
6864# Sidebar: ツール選択とエージェントの構築
@@ -124,8 +120,11 @@ def create_graph() -> CompiledStateGraph:
124120 st .chat_message ("user" ).write (prompt )
125121 with st .chat_message ("assistant" ):
126122 with st .spinner ("処理中..." ):
127- # 変更: エージェントごとの入力ビルダーを使用
123+ # 変更: callbacks は config に渡す。input は入力のみ。
128124 callbacks = [StreamlitCallbackHandler (st .container ())]
129- input_builder = st .session_state .get ("input_builder" ) or (lambda p , cbs : {"query" : p , "callbacks" : cbs })
130- response = st .session_state ["graph" ].invoke (input = input_builder (prompt , callbacks ))
125+ input_builder = st .session_state .get ("input_builder" ) or (lambda p : {"query" : p })
126+ response = st .session_state ["graph" ].invoke (
127+ input = input_builder (prompt ),
128+ config = {"callbacks" : callbacks },
129+ )
131130 st .write (response )
0 commit comments