|
4 | 4 | import uuid |
5 | 5 | from base64 import b64encode |
6 | 6 | from dataclasses import dataclass |
| 7 | +from enum import Enum |
7 | 8 |
|
8 | 9 | import streamlit as st |
9 | 10 | from audio_recorder_streamlit import audio_recorder |
10 | 11 | from langchain_community.callbacks.streamlit import ( |
11 | 12 | StreamlitCallbackHandler, |
12 | 13 | ) |
13 | 14 | from langfuse.langchain import CallbackHandler |
| 15 | +from langgraph.checkpoint.memory import InMemorySaver |
14 | 16 | from langgraph.checkpoint.sqlite import SqliteSaver |
15 | 17 | from langgraph.store.sqlite import SqliteStore |
| 18 | +from langgraph_checkpoint_cosmosdb import CosmosDBSaver |
16 | 19 |
|
17 | 20 | from template_langgraph.agents.chat_with_tools_agent.agent import ( |
18 | 21 | AgentState, |
|
22 | 25 | from template_langgraph.speeches.tts import TtsWrapper |
23 | 26 | from template_langgraph.tools.common import get_default_tools |
24 | 27 |
|
25 | | -checkpoints_conn = sqlite3.connect("checkpoints.sqlite", check_same_thread=False) |
| 28 | + |
| 29 | +class CheckpointType(str, Enum): |
| 30 | + SQLITE = "sqlite" |
| 31 | + COSMOSDB = "cosmosdb" |
| 32 | + MEMORY = "memory" |
| 33 | + NONE = "none" |
| 34 | + |
| 35 | + |
| 36 | +DEFAULT_CHECKPOINT_TYPE = CheckpointType.NONE |
| 37 | +CHECKPOINT_LABELS = { |
| 38 | + CheckpointType.COSMOSDB.value: "Cosmos DB", |
| 39 | + CheckpointType.SQLITE.value: "SQLite", |
| 40 | + CheckpointType.MEMORY.value: "メモリ", |
| 41 | + CheckpointType.NONE.value: "なし", |
| 42 | +} |
| 43 | + |
| 44 | + |
26 | 45 | store_conn = sqlite3.connect("store.sqlite", check_same_thread=False) |
27 | 46 | thread_id = str(uuid.uuid4()) |
28 | 47 |
|
@@ -70,17 +89,47 @@ def ensure_session_state_defaults(tool_names: list[str]) -> None: |
70 | 89 | st.session_state.setdefault("chat_history", []) |
71 | 90 | st.session_state.setdefault("input_output_mode", "テキスト") |
72 | 91 | st.session_state.setdefault("selected_tool_names", tool_names) |
| 92 | + st.session_state.setdefault("checkpoint_type", DEFAULT_CHECKPOINT_TYPE.value) |
| 93 | + |
| 94 | + |
| 95 | +def get_selected_checkpoint_type() -> CheckpointType: |
| 96 | + raw_value = st.session_state.get("checkpoint_type", DEFAULT_CHECKPOINT_TYPE.value) |
| 97 | + try: |
| 98 | + checkpoint = CheckpointType(raw_value) |
| 99 | + except ValueError: |
| 100 | + st.session_state["checkpoint_type"] = DEFAULT_CHECKPOINT_TYPE.value |
| 101 | + return DEFAULT_CHECKPOINT_TYPE |
| 102 | + return checkpoint |
| 103 | + |
| 104 | + |
| 105 | +def get_checkpointer(): |
| 106 | + checkpoint_type = get_selected_checkpoint_type() |
| 107 | + if checkpoint_type is CheckpointType.SQLITE: |
| 108 | + conn = sqlite3.connect("checkpoints.sqlite", check_same_thread=False) |
| 109 | + return SqliteSaver(conn=conn) |
| 110 | + if checkpoint_type is CheckpointType.COSMOSDB: |
| 111 | + from template_langgraph.tools.cosmosdb_tool import get_cosmosdb_settings |
| 112 | + |
| 113 | + settings = get_cosmosdb_settings() |
| 114 | + os.environ["COSMOSDB_ENDPOINT"] = settings.cosmosdb_host |
| 115 | + os.environ["COSMOSDB_KEY"] = settings.cosmosdb_key |
| 116 | + |
| 117 | + return CosmosDBSaver( |
| 118 | + database_name=settings.cosmosdb_database_name, |
| 119 | + container_name="checkpoints", |
| 120 | + ) |
| 121 | + if checkpoint_type is CheckpointType.MEMORY: |
| 122 | + return InMemorySaver() |
| 123 | + return None |
73 | 124 |
|
74 | 125 |
|
75 | 126 | def ensure_agent_graph(selected_tools: list) -> None: |
76 | | - signature = tuple(tool.name for tool in selected_tools) |
| 127 | + signature = (tuple(tool.name for tool in selected_tools), get_selected_checkpoint_type().value) |
77 | 128 | graph_signature = st.session_state.get("graph_tools_signature") |
78 | 129 | if "graph" not in st.session_state or graph_signature != signature: |
79 | 130 | st.session_state["graph"] = ChatWithToolsAgent( |
80 | 131 | tools=selected_tools, |
81 | | - checkpointer=SqliteSaver( |
82 | | - conn=checkpoints_conn, |
83 | | - ), |
| 132 | + checkpointer=get_checkpointer(), |
84 | 133 | store=SqliteStore( |
85 | 134 | conn=store_conn, |
86 | 135 | ), |
@@ -111,6 +160,23 @@ def build_sidebar() -> tuple[str, AudioSettings | None]: |
111 | 160 | if input_mode == "音声": |
112 | 161 | audio_settings = render_audio_controls() |
113 | 162 |
|
| 163 | + st.divider() |
| 164 | + st.subheader("チェックポイント") |
| 165 | + |
| 166 | + checkpoint_options = [checkpoint.value for checkpoint in CheckpointType] |
| 167 | + current_checkpoint_value = st.session_state["checkpoint_type"] |
| 168 | + if current_checkpoint_value not in checkpoint_options: |
| 169 | + current_checkpoint_value = DEFAULT_CHECKPOINT_TYPE.value |
| 170 | + st.session_state["checkpoint_type"] = current_checkpoint_value |
| 171 | + checkpoint_index = checkpoint_options.index(current_checkpoint_value) |
| 172 | + selected_checkpoint_value = st.selectbox( |
| 173 | + "保存方法", |
| 174 | + options=checkpoint_options, |
| 175 | + index=checkpoint_index, |
| 176 | + format_func=lambda value: CHECKPOINT_LABELS.get(value, value), |
| 177 | + ) |
| 178 | + st.session_state["checkpoint_type"] = selected_checkpoint_value |
| 179 | + |
114 | 180 | st.divider() |
115 | 181 | st.subheader("使用するツール") |
116 | 182 |
|
|
0 commit comments