Skip to content

Commit 56cf475

Browse files
committed
add checkpoint operator
1 parent eb5f05b commit 56cf475

File tree

1 file changed

+128
-0
lines changed

1 file changed

+128
-0
lines changed

scripts/checkpoint_operator.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
import logging
2+
import os
3+
import sqlite3
4+
from enum import Enum
5+
6+
import typer
7+
from dotenv import load_dotenv
8+
from langgraph.checkpoint.memory import InMemorySaver
9+
from langgraph.checkpoint.sqlite import SqliteSaver
10+
from langgraph_checkpoint_cosmosdb import CosmosDBSaver
11+
12+
from template_langgraph.loggers import get_logger
13+
14+
# Initialize the Typer application
15+
app = typer.Typer(
16+
add_completion=False,
17+
help="Checkpoint Operator CLI",
18+
)
19+
20+
# Set up logging
21+
logger = get_logger(__name__)
22+
23+
24+
class CheckpointType(str, Enum):
25+
SQLITE = "sqlite"
26+
COSMOSDB = "cosmosdb"
27+
MEMORY = "memory"
28+
NONE = "none"
29+
30+
31+
DEFAULT_CHECKPOINT_TYPE = CheckpointType.NONE
32+
CHECKPOINT_LABELS = {
33+
CheckpointType.COSMOSDB.value: "Cosmos DB",
34+
CheckpointType.SQLITE.value: "SQLite",
35+
CheckpointType.MEMORY.value: "メモリ",
36+
CheckpointType.NONE.value: "なし",
37+
}
38+
39+
40+
def get_selected_checkpoint_type(raw_value: str) -> CheckpointType:
41+
try:
42+
checkpoint = CheckpointType(raw_value)
43+
except ValueError:
44+
return DEFAULT_CHECKPOINT_TYPE
45+
return checkpoint
46+
47+
48+
def get_checkpointer(raw_value: str):
49+
checkpoint_type = get_selected_checkpoint_type(
50+
raw_value=raw_value,
51+
)
52+
if checkpoint_type is CheckpointType.SQLITE:
53+
conn = sqlite3.connect("checkpoints.sqlite", check_same_thread=False)
54+
return SqliteSaver(conn=conn)
55+
if checkpoint_type is CheckpointType.COSMOSDB:
56+
from template_langgraph.tools.cosmosdb_tool import get_cosmosdb_settings
57+
58+
settings = get_cosmosdb_settings()
59+
os.environ["COSMOSDB_ENDPOINT"] = settings.cosmosdb_host
60+
os.environ["COSMOSDB_KEY"] = settings.cosmosdb_key
61+
62+
return CosmosDBSaver(
63+
database_name=settings.cosmosdb_database_name,
64+
container_name="checkpoints",
65+
)
66+
if checkpoint_type is CheckpointType.MEMORY:
67+
return InMemorySaver()
68+
return None
69+
70+
71+
@app.command()
72+
def list_checkpoints(
73+
checkpoint_type: str = typer.Option(
74+
DEFAULT_CHECKPOINT_TYPE.value,
75+
"--type",
76+
"-t",
77+
case_sensitive=False,
78+
help=f"Type of checkpoint to list. Options: {', '.join([f'{key} ({value})' for key, value in CHECKPOINT_LABELS.items()])}. Default is '{DEFAULT_CHECKPOINT_TYPE.value}'.", # noqa: E501
79+
),
80+
verbose: bool = typer.Option(
81+
False,
82+
"--verbose",
83+
"-v",
84+
help="Enable verbose output",
85+
),
86+
):
87+
# Set up logging
88+
if verbose:
89+
logger.setLevel(logging.DEBUG)
90+
91+
checkpointer = get_checkpointer(raw_value=checkpoint_type)
92+
if checkpointer is None:
93+
logger.info("No checkpointing is configured.")
94+
return
95+
checkpoints = checkpointer.list(
96+
config=None,
97+
)
98+
for checkpoint in checkpoints:
99+
logger.info(f"Thread ID: {checkpoint.config['configurable'].get('thread_id')}")
100+
logger.info(f"{checkpoint.checkpoint['channel_values']}")
101+
messages = checkpoint.checkpoint["channel_values"].get("messages") or []
102+
for message in messages:
103+
if message is not None:
104+
logger.info(f" - {message}")
105+
else:
106+
logger.info(" - None")
107+
108+
109+
@app.command()
110+
def list_messages(
111+
verbose: bool = typer.Option(
112+
False,
113+
"--verbose",
114+
"-v",
115+
help="Enable verbose output",
116+
),
117+
):
118+
# Set up logging
119+
if verbose:
120+
logger.setLevel(logging.DEBUG)
121+
122+
123+
if __name__ == "__main__":
124+
load_dotenv(
125+
override=True,
126+
verbose=True,
127+
)
128+
app()

0 commit comments

Comments
 (0)