Skip to content

Commit d278b82

Browse files
committed
add cosmosdb checkpoint
1 parent 97e9269 commit d278b82

File tree

3 files changed

+44
-4
lines changed

3 files changed

+44
-4
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ dependencies = [
2727
"langchain-text-splitters>=0.3.9",
2828
"langfuse>=3.6.2",
2929
"langgraph>=0.6.2",
30+
"langgraph-checkpoint-cosmosdb>=0.2.4",
3031
"langgraph-checkpoint-sqlite>=2.0.11",
3132
"langgraph-supervisor>=0.0.29",
3233
"mlflow>=3.4.0",

template_langgraph/services/streamlits/pages/chat_with_tools_agent.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,10 @@
1111
StreamlitCallbackHandler,
1212
)
1313
from langfuse.langchain import CallbackHandler
14+
from langgraph.checkpoint.memory import InMemorySaver
1415
from langgraph.checkpoint.sqlite import SqliteSaver
1516
from langgraph.store.sqlite import SqliteStore
17+
from langgraph_checkpoint_cosmosdb import CosmosDBSaver
1618

1719
from template_langgraph.agents.chat_with_tools_agent.agent import (
1820
AgentState,
@@ -22,7 +24,7 @@
2224
from template_langgraph.speeches.tts import TtsWrapper
2325
from template_langgraph.tools.common import get_default_tools
2426

25-
checkpoints_conn = sqlite3.connect("checkpoints.sqlite", check_same_thread=False)
27+
checkpoint_type = "cosmosdb" # "cosmosdb" or "sqlite"
2628
store_conn = sqlite3.connect("store.sqlite", check_same_thread=False)
2729
thread_id = str(uuid.uuid4())
2830

@@ -72,15 +74,35 @@ def ensure_session_state_defaults(tool_names: list[str]) -> None:
7274
st.session_state.setdefault("selected_tool_names", tool_names)
7375

7476

77+
def get_checkpointer():
78+
if checkpoint_type == "sqlite":
79+
conn = sqlite3.connect("checkpoints.sqlite", check_same_thread=False)
80+
return SqliteSaver(conn=conn)
81+
if checkpoint_type == "cosmosdb":
82+
import os
83+
84+
from template_langgraph.tools.cosmosdb_tool import get_cosmosdb_settings
85+
86+
settings = get_cosmosdb_settings()
87+
os.environ["COSMOSDB_ENDPOINT"] = settings.cosmosdb_host
88+
os.environ["COSMOSDB_KEY"] = settings.cosmosdb_key
89+
90+
return CosmosDBSaver(
91+
database_name=settings.cosmosdb_database_name,
92+
container_name="checkpoints",
93+
)
94+
if checkpoint_type == "memory":
95+
return InMemorySaver()
96+
return None
97+
98+
7599
def ensure_agent_graph(selected_tools: list) -> None:
76100
signature = tuple(tool.name for tool in selected_tools)
77101
graph_signature = st.session_state.get("graph_tools_signature")
78102
if "graph" not in st.session_state or graph_signature != signature:
79103
st.session_state["graph"] = ChatWithToolsAgent(
80104
tools=selected_tools,
81-
checkpointer=SqliteSaver(
82-
conn=checkpoints_conn,
83-
),
105+
checkpointer=get_checkpointer(),
84106
store=SqliteStore(
85107
conn=store_conn,
86108
),

uv.lock

Lines changed: 17 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)