|
11 | 11 | StreamlitCallbackHandler, |
12 | 12 | ) |
13 | 13 | from langfuse.langchain import CallbackHandler |
| 14 | +from langgraph.checkpoint.memory import InMemorySaver |
14 | 15 | from langgraph.checkpoint.sqlite import SqliteSaver |
15 | 16 | from langgraph.store.sqlite import SqliteStore |
| 17 | +from langgraph_checkpoint_cosmosdb import CosmosDBSaver |
16 | 18 |
|
17 | 19 | from template_langgraph.agents.chat_with_tools_agent.agent import ( |
18 | 20 | AgentState, |
|
22 | 24 | from template_langgraph.speeches.tts import TtsWrapper |
23 | 25 | from template_langgraph.tools.common import get_default_tools |
24 | 26 |
|
25 | | -checkpoints_conn = sqlite3.connect("checkpoints.sqlite", check_same_thread=False) |
| 27 | +checkpoint_type = "cosmosdb" # "cosmosdb" or "sqlite" |
26 | 28 | store_conn = sqlite3.connect("store.sqlite", check_same_thread=False) |
27 | 29 | thread_id = str(uuid.uuid4()) |
28 | 30 |
|
@@ -72,15 +74,35 @@ def ensure_session_state_defaults(tool_names: list[str]) -> None: |
72 | 74 | st.session_state.setdefault("selected_tool_names", tool_names) |
73 | 75 |
|
74 | 76 |
|
| 77 | +def get_checkpointer(): |
| 78 | + if checkpoint_type == "sqlite": |
| 79 | + conn = sqlite3.connect("checkpoints.sqlite", check_same_thread=False) |
| 80 | + return SqliteSaver(conn=conn) |
| 81 | + if checkpoint_type == "cosmosdb": |
| 82 | + import os |
| 83 | + |
| 84 | + from template_langgraph.tools.cosmosdb_tool import get_cosmosdb_settings |
| 85 | + |
| 86 | + settings = get_cosmosdb_settings() |
| 87 | + os.environ["COSMOSDB_ENDPOINT"] = settings.cosmosdb_host |
| 88 | + os.environ["COSMOSDB_KEY"] = settings.cosmosdb_key |
| 89 | + |
| 90 | + return CosmosDBSaver( |
| 91 | + database_name=settings.cosmosdb_database_name, |
| 92 | + container_name="checkpoints", |
| 93 | + ) |
| 94 | + if checkpoint_type == "memory": |
| 95 | + return InMemorySaver() |
| 96 | + return None |
| 97 | + |
| 98 | + |
75 | 99 | def ensure_agent_graph(selected_tools: list) -> None: |
76 | 100 | signature = tuple(tool.name for tool in selected_tools) |
77 | 101 | graph_signature = st.session_state.get("graph_tools_signature") |
78 | 102 | if "graph" not in st.session_state or graph_signature != signature: |
79 | 103 | st.session_state["graph"] = ChatWithToolsAgent( |
80 | 104 | tools=selected_tools, |
81 | | - checkpointer=SqliteSaver( |
82 | | - conn=checkpoints_conn, |
83 | | - ), |
| 105 | + checkpointer=get_checkpointer(), |
84 | 106 | store=SqliteStore( |
85 | 107 | conn=store_conn, |
86 | 108 | ), |
|
0 commit comments