10
10
from template_langgraph .tools .common import get_default_tools
11
11
12
12
13
- # 追加: 各エージェントのグラフ生成ファクトリと tool call 対応フラグ
14
13
def _make_parallel_rag_graph (tools ):
15
- graph = ParallelRagAgent (
16
- llm = AzureOpenAiWrapper ().chat_model ,
17
- tools = tools ,
18
- ).create_graph ()
19
-
20
- # 追加: このエージェント用の入力ビルダー
21
- def build_input (prompt , callbacks ):
22
- return {"query" : prompt , "callbacks" : callbacks }
14
+ def build_input (prompt ):
15
+ return {
16
+ "query" : prompt ,
17
+ }
23
18
24
- return {"graph" : graph , "build_input" : build_input }
19
+ return {
20
+ "graph" : ParallelRagAgent (
21
+ llm = AzureOpenAiWrapper ().chat_model ,
22
+ tools = tools ,
23
+ ).create_graph (),
24
+ "build_input" : build_input ,
25
+ }
25
26
26
27
27
28
def _make_weather_graph (_tools = None ):
28
- # weather_agent_graph が CompiledStateGraph を指している前提
29
- graph = weather_agent_graph
30
-
31
- # 追加: このエージェント用の入力ビルダー(必要に応じてキー名を調整)
32
- def build_input (prompt , callbacks ):
29
+ def build_input (prompt ):
33
30
return {
34
31
"messages" : [
35
32
prompt ,
36
33
],
37
- "callbacks" : callbacks ,
38
34
}
39
35
40
- return {"graph" : graph , "build_input" : build_input }
36
+ return {
37
+ "graph" : weather_agent_graph ,
38
+ "build_input" : build_input ,
39
+ }
41
40
42
41
43
42
agent_options = {
@@ -54,15 +53,12 @@ def build_input(prompt, callbacks):
54
53
55
54
56
55
def create_graph () -> CompiledStateGraph :
57
- # ...existing code...
58
56
cfg = agent_options .get (selected_agent_key ) or next (iter (agent_options .values ()))
59
57
supports_tools = cfg .get ("supports_tools" , True )
60
58
factory = cfg ["factory" ]
61
59
result = factory (selected_tools if supports_tools else None )
62
- # 追加: 入力ビルダーを保存(無ければデフォルトにフォールバック)
63
- st .session_state ["input_builder" ] = result .get ("build_input" ) or (lambda p , cbs : {"query" : p , "callbacks" : cbs })
60
+ st .session_state ["input_builder" ] = result .get ("build_input" ) or (lambda p : {"query" : p })
64
61
return result ["graph" ]
65
- # ...existing code...
66
62
67
63
68
64
# Sidebar: ツール選択とエージェントの構築
@@ -124,8 +120,11 @@ def create_graph() -> CompiledStateGraph:
124
120
st .chat_message ("user" ).write (prompt )
125
121
with st .chat_message ("assistant" ):
126
122
with st .spinner ("処理中..." ):
127
- # 変更: エージェントごとの入力ビルダーを使用
123
+ # 変更: callbacks は config に渡す。input は入力のみ。
128
124
callbacks = [StreamlitCallbackHandler (st .container ())]
129
- input_builder = st .session_state .get ("input_builder" ) or (lambda p , cbs : {"query" : p , "callbacks" : cbs })
130
- response = st .session_state ["graph" ].invoke (input = input_builder (prompt , callbacks ))
125
+ input_builder = st .session_state .get ("input_builder" ) or (lambda p : {"query" : p })
126
+ response = st .session_state ["graph" ].invoke (
127
+ input = input_builder (prompt ),
128
+ config = {"callbacks" : callbacks },
129
+ )
131
130
st .write (response )
0 commit comments