Skip to content

Commit 43cd203

Browse files
committed
gen: 現状 streamlit のアプリ #file:chat_with_tools_agent.py 側でチャット履歴を保持してやりとりしているが、LangGraph の checkpoint を履歴としてチャットする仕様に変更して欲しい。チャット履歴は LangGraph Agent の messages array を参照して表示してください。
1 parent 6b61e47 commit 43cd203

File tree

1 file changed

+103
-35
lines changed

1 file changed

+103
-35
lines changed

template_langgraph/services/streamlits/pages/chat_with_tools_agent.py

Lines changed: 103 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,9 @@ class CheckpointType(str, Enum):
4747

4848

4949
store_conn = sqlite3.connect("store.sqlite", check_same_thread=False)
50-
thread_id = str(uuid.uuid4())
50+
# thread_id はセッション内で保持し、チェックポイント利用時に既存スレッドを再開できるようにする
51+
if "thread_id" not in st.session_state:
52+
st.session_state["thread_id"] = str(uuid.uuid4())
5153

5254

5355
def image_to_base64(image_bytes: bytes) -> str:
@@ -62,8 +64,9 @@ def load_stt_wrapper(model_size: str = "base"):
6264
return stt_wrapper
6365

6466

65-
if "chat_history" not in st.session_state:
66-
st.session_state["chat_history"] = []
67+
# 以前は Streamlit セッションに chat_history を保持していたが、
68+
# 仕様変更により LangGraph の state (messages) を直接参照する方式へ移行。
69+
# そのため chat_history の初期化は削除。
6770

6871

6972
@dataclass(slots=True)
@@ -90,7 +93,7 @@ def to_history_message(self) -> dict[str, object]:
9093

9194

9295
def ensure_session_state_defaults(tool_names: list[str]) -> None:
93-
st.session_state.setdefault("chat_history", [])
96+
# chat_history は利用せず(LangGraph 側の messages を利用)
9497
st.session_state.setdefault("input_output_mode", "テキスト")
9598
st.session_state.setdefault("selected_tool_names", tool_names)
9699
st.session_state.setdefault("checkpoint_type", DEFAULT_CHECKPOINT_TYPE.value)
@@ -149,6 +152,27 @@ def ensure_agent_graph(selected_tools: list) -> None:
149152
st.session_state["graph_tools_signature"] = signature
150153

151154

155+
def _list_existing_thread_ids() -> list[str]:
156+
"""チェックポインタに保存されている thread_id を列挙 (最大50件)。"""
157+
checkpointer = get_checkpointer()
158+
if not checkpointer:
159+
return []
160+
thread_ids: set[str] = set()
161+
try:
162+
for i, snapshot in enumerate(checkpointer.list(config=None)):
163+
if i > 1000: # 念のため無限増加防止
164+
break
165+
cfg = getattr(snapshot, "config", {}) or {}
166+
configurable = cfg.get("configurable", {}) if isinstance(cfg, dict) else {}
167+
tid = configurable.get("thread_id")
168+
if isinstance(tid, str):
169+
thread_ids.add(tid)
170+
except Exception as exc: # noqa: BLE001
171+
logger.debug(f"thread list 取得失敗: {exc}")
172+
# 直近利用を優先できる情報が無いので単純ソート
173+
return sorted(thread_ids)[:50]
174+
175+
152176
def build_sidebar() -> tuple[str, AudioSettings | None]:
153177
audio_settings: AudioSettings | None = None
154178

@@ -189,6 +213,27 @@ def build_sidebar() -> tuple[str, AudioSettings | None]:
189213
)
190214
st.session_state["checkpoint_type"] = selected_checkpoint_value
191215

216+
# スレッド選択 UI (チェックポイント有効時のみ)
217+
if get_selected_checkpoint_type() is not CheckpointType.NONE:
218+
existing_threads = _list_existing_thread_ids()
219+
st.subheader("スレッド")
220+
new_label = "<新規作成>"
221+
options = [new_label, *existing_threads]
222+
current_thread = st.session_state.get("thread_id")
223+
# 既存に一致するならその index、なければ 0 (新規)
224+
if current_thread in existing_threads:
225+
default_index = options.index(current_thread)
226+
else:
227+
default_index = 0
228+
selected = st.selectbox("既存スレッドを選択", options=options, index=default_index)
229+
if selected == new_label:
230+
if st.button("スレッドを生成", use_container_width=True):
231+
st.session_state["thread_id"] = str(uuid.uuid4())
232+
st.experimental_rerun()
233+
else:
234+
st.session_state["thread_id"] = selected
235+
st.caption(f"現在の thread_id: {st.session_state['thread_id']}")
236+
192237
st.divider()
193238
st.subheader("使用するツール")
194239

@@ -249,17 +294,25 @@ def render_audio_controls() -> AudioSettings:
249294

250295

251296
def render_chat_history() -> None:
252-
for msg in st.session_state["chat_history"]:
297+
"""LangGraph の state 保存されている messages を列挙して表示する。"""
298+
agent_messages = get_agent_messages()
299+
for msg in agent_messages:
300+
role = "assistant"
301+
content = ""
302+
attachments = []
253303
if isinstance(msg, dict):
254-
attachments = msg.get("attachments", [])
255-
with st.chat_message(msg["role"]):
256-
if attachments:
257-
for item in attachments:
258-
render_attachment(item)
259-
else:
260-
st.write(msg["content"])
261-
else:
262-
st.chat_message("assistant").write(msg.content)
304+
role = msg.get("role", role)
305+
content = msg.get("content", content)
306+
attachments = msg.get("attachments", []) or []
307+
else: # LangChain Message オブジェクト互換
308+
role = getattr(msg, "role", role)
309+
content = getattr(msg, "content", content)
310+
with st.chat_message(role):
311+
if attachments:
312+
for item in attachments:
313+
render_attachment(item)
314+
if content:
315+
st.write(content)
263316

264317

265318
def render_attachment(item: dict[str, object]) -> None:
@@ -378,14 +431,31 @@ def render_user_submission(submission: UserSubmission) -> None:
378431
st.write(submission.content)
379432

380433

381-
def build_graph_messages() -> list:
382-
graph_messages = []
383-
for msg in st.session_state["chat_history"]:
384-
if isinstance(msg, dict):
385-
graph_messages.append({"role": msg["role"], "content": msg["content"]})
386-
else:
387-
graph_messages.append(msg)
388-
return graph_messages
434+
def get_agent_messages() -> list:
435+
"""LangGraph の現在 state から messages を取得。エラー時は空配列。"""
436+
if "graph" not in st.session_state:
437+
return []
438+
try:
439+
state = st.session_state["graph"].get_state(
440+
{
441+
"configurable": {
442+
"thread_id": st.session_state.get("thread_id"),
443+
"user_id": "user_1",
444+
},
445+
},
446+
)
447+
values = getattr(state, "values", state)
448+
if isinstance(values, dict):
449+
return list(values.get("messages", []) or [])
450+
return []
451+
except Exception as exc: # noqa: BLE001
452+
logger.debug(f"messages state の取得に失敗: {exc}")
453+
return []
454+
455+
456+
def build_graph_messages_with_new_user(user_content: str) -> list:
457+
"""既存 messages に新しい user メッセージを追加したリストを返す。"""
458+
return [*get_agent_messages(), {"role": "user", "content": user_content}]
389459

390460

391461
def invoke_agent(graph_messages: list) -> AgentState:
@@ -399,7 +469,7 @@ def invoke_agent(graph_messages: list) -> AgentState:
399469
CallbackHandler(),
400470
],
401471
"configurable": {
402-
"thread_id": thread_id,
472+
"thread_id": st.session_state.get("thread_id"),
403473
"user_id": "user_1",
404474
},
405475
},
@@ -431,19 +501,17 @@ def synthesize_audio_if_needed(response_content: str, mode: str, audio_settings:
431501
submission = collect_user_submission(input_output_mode, audio_settings)
432502

433503
if submission:
434-
history_message = submission.to_history_message()
435-
st.session_state["chat_history"].append(history_message)
436-
437504
with st.chat_message("user"):
438505
render_user_submission(submission)
439506

440-
graph_messages = build_graph_messages()
441-
507+
updated_messages = build_graph_messages_with_new_user(submission.content)
442508
with st.chat_message("assistant"):
443-
response = invoke_agent(graph_messages)
444-
last_message = response["messages"][-1]
445-
st.session_state["chat_history"].append(last_message)
446-
447-
response_content = last_message.content
448-
st.write(response_content)
449-
synthesize_audio_if_needed(response_content, input_output_mode, audio_settings)
509+
response = invoke_agent(updated_messages)
510+
latest_messages = response["messages"]
511+
last_message = latest_messages[-1] if latest_messages else None
512+
if last_message is not None:
513+
response_content = getattr(last_message, "content", None) or (
514+
last_message.get("content") if isinstance(last_message, dict) else ""
515+
)
516+
st.write(response_content)
517+
synthesize_audio_if_needed(response_content, input_output_mode, audio_settings)

0 commit comments

Comments
 (0)