@@ -47,7 +47,9 @@ class CheckpointType(str, Enum):
4747
4848
4949store_conn = sqlite3 .connect ("store.sqlite" , check_same_thread = False )
50- thread_id = str (uuid .uuid4 ())
50+ # thread_id はセッション内で保持し、チェックポイント利用時に既存スレッドを再開できるようにする
51+ if "thread_id" not in st .session_state :
52+ st .session_state ["thread_id" ] = str (uuid .uuid4 ())
5153
5254
5355def image_to_base64 (image_bytes : bytes ) -> str :
@@ -62,8 +64,9 @@ def load_stt_wrapper(model_size: str = "base"):
6264 return stt_wrapper
6365
6466
65- if "chat_history" not in st .session_state :
66- st .session_state ["chat_history" ] = []
67+ # 以前は Streamlit セッションに chat_history を保持していたが、
68+ # 仕様変更により LangGraph の state (messages) を直接参照する方式へ移行。
69+ # そのため chat_history の初期化は削除。
6770
6871
6972@dataclass (slots = True )
@@ -90,7 +93,7 @@ def to_history_message(self) -> dict[str, object]:
9093
9194
9295def ensure_session_state_defaults (tool_names : list [str ]) -> None :
93- st . session_state . setdefault ( " chat_history" , [])
96+ # chat_history は利用せず(LangGraph 側の messages を利用)
9497 st .session_state .setdefault ("input_output_mode" , "テキスト" )
9598 st .session_state .setdefault ("selected_tool_names" , tool_names )
9699 st .session_state .setdefault ("checkpoint_type" , DEFAULT_CHECKPOINT_TYPE .value )
@@ -149,6 +152,27 @@ def ensure_agent_graph(selected_tools: list) -> None:
149152 st .session_state ["graph_tools_signature" ] = signature
150153
151154
155+ def _list_existing_thread_ids () -> list [str ]:
156+ """チェックポインタに保存されている thread_id を列挙 (最大50件)。"""
157+ checkpointer = get_checkpointer ()
158+ if not checkpointer :
159+ return []
160+ thread_ids : set [str ] = set ()
161+ try :
162+ for i , snapshot in enumerate (checkpointer .list (config = None )):
163+ if i > 1000 : # 念のため無限増加防止
164+ break
165+ cfg = getattr (snapshot , "config" , {}) or {}
166+ configurable = cfg .get ("configurable" , {}) if isinstance (cfg , dict ) else {}
167+ tid = configurable .get ("thread_id" )
168+ if isinstance (tid , str ):
169+ thread_ids .add (tid )
170+ except Exception as exc : # noqa: BLE001
171+ logger .debug (f"thread list 取得失敗: { exc } " )
172+ # 直近利用を優先できる情報が無いので単純ソート
173+ return sorted (thread_ids )[:50 ]
174+
175+
152176def build_sidebar () -> tuple [str , AudioSettings | None ]:
153177 audio_settings : AudioSettings | None = None
154178
@@ -189,6 +213,27 @@ def build_sidebar() -> tuple[str, AudioSettings | None]:
189213 )
190214 st .session_state ["checkpoint_type" ] = selected_checkpoint_value
191215
216+ # スレッド選択 UI (チェックポイント有効時のみ)
217+ if get_selected_checkpoint_type () is not CheckpointType .NONE :
218+ existing_threads = _list_existing_thread_ids ()
219+ st .subheader ("スレッド" )
220+ new_label = "<新規作成>"
221+ options = [new_label , * existing_threads ]
222+ current_thread = st .session_state .get ("thread_id" )
223+ # 既存に一致するならその index、なければ 0 (新規)
224+ if current_thread in existing_threads :
225+ default_index = options .index (current_thread )
226+ else :
227+ default_index = 0
228+ selected = st .selectbox ("既存スレッドを選択" , options = options , index = default_index )
229+ if selected == new_label :
230+ if st .button ("スレッドを生成" , use_container_width = True ):
231+ st .session_state ["thread_id" ] = str (uuid .uuid4 ())
232+ st .experimental_rerun ()
233+ else :
234+ st .session_state ["thread_id" ] = selected
235+ st .caption (f"現在の thread_id: { st .session_state ['thread_id' ]} " )
236+
192237 st .divider ()
193238 st .subheader ("使用するツール" )
194239
@@ -249,17 +294,25 @@ def render_audio_controls() -> AudioSettings:
249294
250295
251296def render_chat_history () -> None :
252- for msg in st .session_state ["chat_history" ]:
297+ """LangGraph の state 保存されている messages を列挙して表示する。"""
298+ agent_messages = get_agent_messages ()
299+ for msg in agent_messages :
300+ role = "assistant"
301+ content = ""
302+ attachments = []
253303 if isinstance (msg , dict ):
254- attachments = msg .get ("attachments" , [])
255- with st .chat_message (msg ["role" ]):
256- if attachments :
257- for item in attachments :
258- render_attachment (item )
259- else :
260- st .write (msg ["content" ])
261- else :
262- st .chat_message ("assistant" ).write (msg .content )
304+ role = msg .get ("role" , role )
305+ content = msg .get ("content" , content )
306+ attachments = msg .get ("attachments" , []) or []
307+ else : # LangChain Message オブジェクト互換
308+ role = getattr (msg , "role" , role )
309+ content = getattr (msg , "content" , content )
310+ with st .chat_message (role ):
311+ if attachments :
312+ for item in attachments :
313+ render_attachment (item )
314+ if content :
315+ st .write (content )
263316
264317
265318def render_attachment (item : dict [str , object ]) -> None :
@@ -378,14 +431,31 @@ def render_user_submission(submission: UserSubmission) -> None:
378431 st .write (submission .content )
379432
380433
381- def build_graph_messages () -> list :
382- graph_messages = []
383- for msg in st .session_state ["chat_history" ]:
384- if isinstance (msg , dict ):
385- graph_messages .append ({"role" : msg ["role" ], "content" : msg ["content" ]})
386- else :
387- graph_messages .append (msg )
388- return graph_messages
434+ def get_agent_messages () -> list :
435+ """LangGraph の現在 state から messages を取得。エラー時は空配列。"""
436+ if "graph" not in st .session_state :
437+ return []
438+ try :
439+ state = st .session_state ["graph" ].get_state (
440+ {
441+ "configurable" : {
442+ "thread_id" : st .session_state .get ("thread_id" ),
443+ "user_id" : "user_1" ,
444+ },
445+ },
446+ )
447+ values = getattr (state , "values" , state )
448+ if isinstance (values , dict ):
449+ return list (values .get ("messages" , []) or [])
450+ return []
451+ except Exception as exc : # noqa: BLE001
452+ logger .debug (f"messages state の取得に失敗: { exc } " )
453+ return []
454+
455+
456+ def build_graph_messages_with_new_user (user_content : str ) -> list :
457+ """既存 messages に新しい user メッセージを追加したリストを返す。"""
458+ return [* get_agent_messages (), {"role" : "user" , "content" : user_content }]
389459
390460
391461def invoke_agent (graph_messages : list ) -> AgentState :
@@ -399,7 +469,7 @@ def invoke_agent(graph_messages: list) -> AgentState:
399469 CallbackHandler (),
400470 ],
401471 "configurable" : {
402- "thread_id" : thread_id ,
472+ "thread_id" : st . session_state . get ( " thread_id" ) ,
403473 "user_id" : "user_1" ,
404474 },
405475 },
@@ -431,19 +501,17 @@ def synthesize_audio_if_needed(response_content: str, mode: str, audio_settings:
431501submission = collect_user_submission (input_output_mode , audio_settings )
432502
433503if submission :
434- history_message = submission .to_history_message ()
435- st .session_state ["chat_history" ].append (history_message )
436-
437504 with st .chat_message ("user" ):
438505 render_user_submission (submission )
439506
440- graph_messages = build_graph_messages ()
441-
507+ updated_messages = build_graph_messages_with_new_user (submission .content )
442508 with st .chat_message ("assistant" ):
443- response = invoke_agent (graph_messages )
444- last_message = response ["messages" ][- 1 ]
445- st .session_state ["chat_history" ].append (last_message )
446-
447- response_content = last_message .content
448- st .write (response_content )
449- synthesize_audio_if_needed (response_content , input_output_mode , audio_settings )
509+ response = invoke_agent (updated_messages )
510+ latest_messages = response ["messages" ]
511+ last_message = latest_messages [- 1 ] if latest_messages else None
512+ if last_message is not None :
513+ response_content = getattr (last_message , "content" , None ) or (
514+ last_message .get ("content" ) if isinstance (last_message , dict ) else ""
515+ )
516+ st .write (response_content )
517+ synthesize_audio_if_needed (response_content , input_output_mode , audio_settings )
0 commit comments