Skip to content

Commit 0492c6e

Browse files
committed
gen: (GPT-5-Codex) #file:chat_with_tools_agent.py に sidebar から checkpoint_type を選択できる UI を追加してください。現在サポートされている sqlite, cosmosdb, memory についても enum で型定義してください。
1 parent d278b82 commit 0492c6e

File tree

1 file changed

+51
-7
lines changed

1 file changed

+51
-7
lines changed

template_langgraph/services/streamlits/pages/chat_with_tools_agent.py

Lines changed: 51 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import uuid
55
from base64 import b64encode
66
from dataclasses import dataclass
7+
from enum import Enum
78

89
import streamlit as st
910
from audio_recorder_streamlit import audio_recorder
@@ -24,7 +25,23 @@
2425
from template_langgraph.speeches.tts import TtsWrapper
2526
from template_langgraph.tools.common import get_default_tools
2627

27-
checkpoint_type = "cosmosdb" # "cosmosdb" or "sqlite"
28+
29+
class CheckpointType(str, Enum):
30+
SQLITE = "sqlite"
31+
COSMOSDB = "cosmosdb"
32+
MEMORY = "memory"
33+
NONE = "none"
34+
35+
36+
DEFAULT_CHECKPOINT_TYPE = CheckpointType.NONE
37+
CHECKPOINT_LABELS = {
38+
CheckpointType.COSMOSDB.value: "Cosmos DB",
39+
CheckpointType.SQLITE.value: "SQLite",
40+
CheckpointType.MEMORY.value: "メモリ",
41+
CheckpointType.NONE.value: "なし",
42+
}
43+
44+
2845
store_conn = sqlite3.connect("store.sqlite", check_same_thread=False)
2946
thread_id = str(uuid.uuid4())
3047

@@ -72,15 +89,25 @@ def ensure_session_state_defaults(tool_names: list[str]) -> None:
7289
st.session_state.setdefault("chat_history", [])
7390
st.session_state.setdefault("input_output_mode", "テキスト")
7491
st.session_state.setdefault("selected_tool_names", tool_names)
92+
st.session_state.setdefault("checkpoint_type", DEFAULT_CHECKPOINT_TYPE.value)
93+
94+
95+
def get_selected_checkpoint_type() -> CheckpointType:
96+
raw_value = st.session_state.get("checkpoint_type", DEFAULT_CHECKPOINT_TYPE.value)
97+
try:
98+
checkpoint = CheckpointType(raw_value)
99+
except ValueError:
100+
st.session_state["checkpoint_type"] = DEFAULT_CHECKPOINT_TYPE.value
101+
return DEFAULT_CHECKPOINT_TYPE
102+
return checkpoint
75103

76104

77105
def get_checkpointer():
78-
if checkpoint_type == "sqlite":
106+
checkpoint_type = get_selected_checkpoint_type()
107+
if checkpoint_type is CheckpointType.SQLITE:
79108
conn = sqlite3.connect("checkpoints.sqlite", check_same_thread=False)
80109
return SqliteSaver(conn=conn)
81-
if checkpoint_type == "cosmosdb":
82-
import os
83-
110+
if checkpoint_type is CheckpointType.COSMOSDB:
84111
from template_langgraph.tools.cosmosdb_tool import get_cosmosdb_settings
85112

86113
settings = get_cosmosdb_settings()
@@ -91,13 +118,13 @@ def get_checkpointer():
91118
database_name=settings.cosmosdb_database_name,
92119
container_name="checkpoints",
93120
)
94-
if checkpoint_type == "memory":
121+
if checkpoint_type is CheckpointType.MEMORY:
95122
return InMemorySaver()
96123
return None
97124

98125

99126
def ensure_agent_graph(selected_tools: list) -> None:
100-
signature = tuple(tool.name for tool in selected_tools)
127+
signature = (tuple(tool.name for tool in selected_tools), get_selected_checkpoint_type().value)
101128
graph_signature = st.session_state.get("graph_tools_signature")
102129
if "graph" not in st.session_state or graph_signature != signature:
103130
st.session_state["graph"] = ChatWithToolsAgent(
@@ -133,6 +160,23 @@ def build_sidebar() -> tuple[str, AudioSettings | None]:
133160
if input_mode == "音声":
134161
audio_settings = render_audio_controls()
135162

163+
st.divider()
164+
st.subheader("チェックポイント")
165+
166+
checkpoint_options = [checkpoint.value for checkpoint in CheckpointType]
167+
current_checkpoint_value = st.session_state["checkpoint_type"]
168+
if current_checkpoint_value not in checkpoint_options:
169+
current_checkpoint_value = DEFAULT_CHECKPOINT_TYPE.value
170+
st.session_state["checkpoint_type"] = current_checkpoint_value
171+
checkpoint_index = checkpoint_options.index(current_checkpoint_value)
172+
selected_checkpoint_value = st.selectbox(
173+
"保存方法",
174+
options=checkpoint_options,
175+
index=checkpoint_index,
176+
format_func=lambda value: CHECKPOINT_LABELS.get(value, value),
177+
)
178+
st.session_state["checkpoint_type"] = selected_checkpoint_value
179+
136180
st.divider()
137181
st.subheader("使用するツール")
138182

0 commit comments

Comments
 (0)