Skip to content

Commit 6e923b1

Browse files
authored
Merge pull request #182 from ks6088ts-labs/cosmosdb-checkpointer
support Cosmosdb checkpointer
2 parents 97e9269 + 0492c6e commit 6e923b1

File tree

3 files changed

+89
-5
lines changed

3 files changed

+89
-5
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ dependencies = [
2727
"langchain-text-splitters>=0.3.9",
2828
"langfuse>=3.6.2",
2929
"langgraph>=0.6.2",
30+
"langgraph-checkpoint-cosmosdb>=0.2.4",
3031
"langgraph-checkpoint-sqlite>=2.0.11",
3132
"langgraph-supervisor>=0.0.29",
3233
"mlflow>=3.4.0",

template_langgraph/services/streamlits/pages/chat_with_tools_agent.py

Lines changed: 71 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,18 @@
44
import uuid
55
from base64 import b64encode
66
from dataclasses import dataclass
7+
from enum import Enum
78

89
import streamlit as st
910
from audio_recorder_streamlit import audio_recorder
1011
from langchain_community.callbacks.streamlit import (
1112
StreamlitCallbackHandler,
1213
)
1314
from langfuse.langchain import CallbackHandler
15+
from langgraph.checkpoint.memory import InMemorySaver
1416
from langgraph.checkpoint.sqlite import SqliteSaver
1517
from langgraph.store.sqlite import SqliteStore
18+
from langgraph_checkpoint_cosmosdb import CosmosDBSaver
1619

1720
from template_langgraph.agents.chat_with_tools_agent.agent import (
1821
AgentState,
@@ -22,7 +25,23 @@
2225
from template_langgraph.speeches.tts import TtsWrapper
2326
from template_langgraph.tools.common import get_default_tools
2427

25-
checkpoints_conn = sqlite3.connect("checkpoints.sqlite", check_same_thread=False)
28+
29+
class CheckpointType(str, Enum):
30+
SQLITE = "sqlite"
31+
COSMOSDB = "cosmosdb"
32+
MEMORY = "memory"
33+
NONE = "none"
34+
35+
36+
DEFAULT_CHECKPOINT_TYPE = CheckpointType.NONE
37+
CHECKPOINT_LABELS = {
38+
CheckpointType.COSMOSDB.value: "Cosmos DB",
39+
CheckpointType.SQLITE.value: "SQLite",
40+
CheckpointType.MEMORY.value: "メモリ",
41+
CheckpointType.NONE.value: "なし",
42+
}
43+
44+
2645
store_conn = sqlite3.connect("store.sqlite", check_same_thread=False)
2746
thread_id = str(uuid.uuid4())
2847

@@ -70,17 +89,47 @@ def ensure_session_state_defaults(tool_names: list[str]) -> None:
7089
st.session_state.setdefault("chat_history", [])
7190
st.session_state.setdefault("input_output_mode", "テキスト")
7291
st.session_state.setdefault("selected_tool_names", tool_names)
92+
st.session_state.setdefault("checkpoint_type", DEFAULT_CHECKPOINT_TYPE.value)
93+
94+
95+
def get_selected_checkpoint_type() -> CheckpointType:
96+
raw_value = st.session_state.get("checkpoint_type", DEFAULT_CHECKPOINT_TYPE.value)
97+
try:
98+
checkpoint = CheckpointType(raw_value)
99+
except ValueError:
100+
st.session_state["checkpoint_type"] = DEFAULT_CHECKPOINT_TYPE.value
101+
return DEFAULT_CHECKPOINT_TYPE
102+
return checkpoint
103+
104+
105+
def get_checkpointer():
106+
checkpoint_type = get_selected_checkpoint_type()
107+
if checkpoint_type is CheckpointType.SQLITE:
108+
conn = sqlite3.connect("checkpoints.sqlite", check_same_thread=False)
109+
return SqliteSaver(conn=conn)
110+
if checkpoint_type is CheckpointType.COSMOSDB:
111+
from template_langgraph.tools.cosmosdb_tool import get_cosmosdb_settings
112+
113+
settings = get_cosmosdb_settings()
114+
os.environ["COSMOSDB_ENDPOINT"] = settings.cosmosdb_host
115+
os.environ["COSMOSDB_KEY"] = settings.cosmosdb_key
116+
117+
return CosmosDBSaver(
118+
database_name=settings.cosmosdb_database_name,
119+
container_name="checkpoints",
120+
)
121+
if checkpoint_type is CheckpointType.MEMORY:
122+
return InMemorySaver()
123+
return None
73124

74125

75126
def ensure_agent_graph(selected_tools: list) -> None:
76-
signature = tuple(tool.name for tool in selected_tools)
127+
signature = (tuple(tool.name for tool in selected_tools), get_selected_checkpoint_type().value)
77128
graph_signature = st.session_state.get("graph_tools_signature")
78129
if "graph" not in st.session_state or graph_signature != signature:
79130
st.session_state["graph"] = ChatWithToolsAgent(
80131
tools=selected_tools,
81-
checkpointer=SqliteSaver(
82-
conn=checkpoints_conn,
83-
),
132+
checkpointer=get_checkpointer(),
84133
store=SqliteStore(
85134
conn=store_conn,
86135
),
@@ -111,6 +160,23 @@ def build_sidebar() -> tuple[str, AudioSettings | None]:
111160
if input_mode == "音声":
112161
audio_settings = render_audio_controls()
113162

163+
st.divider()
164+
st.subheader("チェックポイント")
165+
166+
checkpoint_options = [checkpoint.value for checkpoint in CheckpointType]
167+
current_checkpoint_value = st.session_state["checkpoint_type"]
168+
if current_checkpoint_value not in checkpoint_options:
169+
current_checkpoint_value = DEFAULT_CHECKPOINT_TYPE.value
170+
st.session_state["checkpoint_type"] = current_checkpoint_value
171+
checkpoint_index = checkpoint_options.index(current_checkpoint_value)
172+
selected_checkpoint_value = st.selectbox(
173+
"保存方法",
174+
options=checkpoint_options,
175+
index=checkpoint_index,
176+
format_func=lambda value: CHECKPOINT_LABELS.get(value, value),
177+
)
178+
st.session_state["checkpoint_type"] = selected_checkpoint_value
179+
114180
st.divider()
115181
st.subheader("使用するツール")
116182

uv.lock

Lines changed: 17 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)