44import uuid
55from base64 import b64encode
66from dataclasses import dataclass
7+ from enum import Enum
78
89import streamlit as st
910from audio_recorder_streamlit import audio_recorder
2425from template_langgraph .speeches .tts import TtsWrapper
2526from template_langgraph .tools .common import get_default_tools
2627
27- checkpoint_type = "cosmosdb" # "cosmosdb" or "sqlite"
28+
29+ class CheckpointType (str , Enum ):
30+ SQLITE = "sqlite"
31+ COSMOSDB = "cosmosdb"
32+ MEMORY = "memory"
33+ NONE = "none"
34+
35+
36+ DEFAULT_CHECKPOINT_TYPE = CheckpointType .NONE
37+ CHECKPOINT_LABELS = {
38+ CheckpointType .COSMOSDB .value : "Cosmos DB" ,
39+ CheckpointType .SQLITE .value : "SQLite" ,
40+ CheckpointType .MEMORY .value : "メモリ" ,
41+ CheckpointType .NONE .value : "なし" ,
42+ }
43+
44+
2845store_conn = sqlite3 .connect ("store.sqlite" , check_same_thread = False )
2946thread_id = str (uuid .uuid4 ())
3047
@@ -72,15 +89,25 @@ def ensure_session_state_defaults(tool_names: list[str]) -> None:
7289 st .session_state .setdefault ("chat_history" , [])
7390 st .session_state .setdefault ("input_output_mode" , "テキスト" )
7491 st .session_state .setdefault ("selected_tool_names" , tool_names )
92+ st .session_state .setdefault ("checkpoint_type" , DEFAULT_CHECKPOINT_TYPE .value )
93+
94+
95+ def get_selected_checkpoint_type () -> CheckpointType :
96+ raw_value = st .session_state .get ("checkpoint_type" , DEFAULT_CHECKPOINT_TYPE .value )
97+ try :
98+ checkpoint = CheckpointType (raw_value )
99+ except ValueError :
100+ st .session_state ["checkpoint_type" ] = DEFAULT_CHECKPOINT_TYPE .value
101+ return DEFAULT_CHECKPOINT_TYPE
102+ return checkpoint
75103
76104
77105def get_checkpointer ():
78- if checkpoint_type == "sqlite" :
106+ checkpoint_type = get_selected_checkpoint_type ()
107+ if checkpoint_type is CheckpointType .SQLITE :
79108 conn = sqlite3 .connect ("checkpoints.sqlite" , check_same_thread = False )
80109 return SqliteSaver (conn = conn )
81- if checkpoint_type == "cosmosdb" :
82- import os
83-
110+ if checkpoint_type is CheckpointType .COSMOSDB :
84111 from template_langgraph .tools .cosmosdb_tool import get_cosmosdb_settings
85112
86113 settings = get_cosmosdb_settings ()
@@ -91,13 +118,13 @@ def get_checkpointer():
91118 database_name = settings .cosmosdb_database_name ,
92119 container_name = "checkpoints" ,
93120 )
94- if checkpoint_type == "memory" :
121+ if checkpoint_type is CheckpointType . MEMORY :
95122 return InMemorySaver ()
96123 return None
97124
98125
99126def ensure_agent_graph (selected_tools : list ) -> None :
100- signature = tuple (tool .name for tool in selected_tools )
127+ signature = ( tuple (tool .name for tool in selected_tools ), get_selected_checkpoint_type (). value )
101128 graph_signature = st .session_state .get ("graph_tools_signature" )
102129 if "graph" not in st .session_state or graph_signature != signature :
103130 st .session_state ["graph" ] = ChatWithToolsAgent (
@@ -133,6 +160,23 @@ def build_sidebar() -> tuple[str, AudioSettings | None]:
133160 if input_mode == "音声" :
134161 audio_settings = render_audio_controls ()
135162
163+ st .divider ()
164+ st .subheader ("チェックポイント" )
165+
166+ checkpoint_options = [checkpoint .value for checkpoint in CheckpointType ]
167+ current_checkpoint_value = st .session_state ["checkpoint_type" ]
168+ if current_checkpoint_value not in checkpoint_options :
169+ current_checkpoint_value = DEFAULT_CHECKPOINT_TYPE .value
170+ st .session_state ["checkpoint_type" ] = current_checkpoint_value
171+ checkpoint_index = checkpoint_options .index (current_checkpoint_value )
172+ selected_checkpoint_value = st .selectbox (
173+ "保存方法" ,
174+ options = checkpoint_options ,
175+ index = checkpoint_index ,
176+ format_func = lambda value : CHECKPOINT_LABELS .get (value , value ),
177+ )
178+ st .session_state ["checkpoint_type" ] = selected_checkpoint_value
179+
136180 st .divider ()
137181 st .subheader ("使用するツール" )
138182
0 commit comments