Skip to content

Commit c16d361

Browse files
authored
Merge pull request #185 from ks6088ts-labs/feature/issue-184_retrieve-chat-history-from-checkpoint
retrieve chat history from checkpoint
2 parents 6e923b1 + d9d3900 commit c16d361

File tree

4 files changed

+466
-164
lines changed

4 files changed

+466
-164
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,8 @@ environment = { python-version = "3.10" }
9696
unknown-argument = "ignore"
9797
invalid-parameter-default = "ignore"
9898
non-subscriptable = "ignore"
99-
possibly-unbound-attribute = "ignore"
10099
unresolved-attribute = "ignore"
101100
invalid-argument-type = "ignore"
102101
invalid-type-form = "ignore"
103102
invalid-assignment = "ignore"
103+
possibly-missing-attribute = "ignore"

scripts/checkpoint_operator.py

Lines changed: 261 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,261 @@
1+
import logging
2+
import os
3+
import sqlite3
4+
from enum import Enum
5+
6+
import typer
7+
from dotenv import load_dotenv
8+
from langgraph.checkpoint.memory import InMemorySaver
9+
from langgraph.checkpoint.sqlite import SqliteSaver
10+
from langgraph_checkpoint_cosmosdb import CosmosDBSaver
11+
12+
from template_langgraph.loggers import get_logger
13+
14+
# Initialize the Typer application
15+
app = typer.Typer(
16+
add_completion=False,
17+
help="Checkpoint Operator CLI",
18+
)
19+
20+
# Set up logging
21+
logger = get_logger(__name__)
22+
23+
24+
class CheckpointType(str, Enum):
25+
SQLITE = "sqlite"
26+
COSMOSDB = "cosmosdb"
27+
MEMORY = "memory"
28+
NONE = "none"
29+
30+
31+
DEFAULT_CHECKPOINT_TYPE = CheckpointType.NONE
32+
CHECKPOINT_LABELS = {
33+
CheckpointType.COSMOSDB.value: "Cosmos DB",
34+
CheckpointType.SQLITE.value: "SQLite",
35+
CheckpointType.MEMORY.value: "メモリ",
36+
CheckpointType.NONE.value: "なし",
37+
}
38+
39+
40+
def get_selected_checkpoint_type(raw_value: str) -> CheckpointType:
41+
try:
42+
checkpoint = CheckpointType(raw_value)
43+
except ValueError:
44+
return DEFAULT_CHECKPOINT_TYPE
45+
return checkpoint
46+
47+
48+
def get_checkpointer(raw_value: str):
49+
checkpoint_type = get_selected_checkpoint_type(
50+
raw_value=raw_value,
51+
)
52+
if checkpoint_type is CheckpointType.SQLITE:
53+
conn = sqlite3.connect("checkpoints.sqlite", check_same_thread=False)
54+
return SqliteSaver(conn=conn)
55+
if checkpoint_type is CheckpointType.COSMOSDB:
56+
from template_langgraph.tools.cosmosdb_tool import get_cosmosdb_settings
57+
58+
settings = get_cosmosdb_settings()
59+
os.environ["COSMOSDB_ENDPOINT"] = settings.cosmosdb_host
60+
os.environ["COSMOSDB_KEY"] = settings.cosmosdb_key
61+
62+
return CosmosDBSaver(
63+
database_name=settings.cosmosdb_database_name,
64+
container_name="checkpoints",
65+
)
66+
if checkpoint_type is CheckpointType.MEMORY:
67+
return InMemorySaver()
68+
return None
69+
70+
71+
@app.command()
72+
def list_checkpoints(
73+
checkpoint_type: str = typer.Option(
74+
DEFAULT_CHECKPOINT_TYPE.value,
75+
"--type",
76+
"-t",
77+
case_sensitive=False,
78+
help=f"Type of checkpoint to list. Options: {', '.join([f'{key} ({value})' for key, value in CHECKPOINT_LABELS.items()])}. Default is '{DEFAULT_CHECKPOINT_TYPE.value}'.", # noqa: E501
79+
),
80+
verbose: bool = typer.Option(
81+
False,
82+
"--verbose",
83+
"-v",
84+
help="Enable verbose output",
85+
),
86+
):
87+
"""List all available checkpoints with their thread IDs and basic information."""
88+
# Set up logging
89+
if verbose:
90+
logger.setLevel(logging.DEBUG)
91+
92+
logger.info(f"Using checkpoint type: {CHECKPOINT_LABELS.get(checkpoint_type, checkpoint_type)}")
93+
94+
checkpointer = get_checkpointer(raw_value=checkpoint_type)
95+
if checkpointer is None:
96+
logger.info("No checkpointing is configured.")
97+
return
98+
99+
try:
100+
checkpoints = list(checkpointer.list(config=None))
101+
102+
if not checkpoints:
103+
logger.info("No checkpoints found.")
104+
return
105+
106+
logger.info(f"Found {len(checkpoints)} checkpoint(s):")
107+
logger.info("-" * 60)
108+
109+
for i, checkpoint in enumerate(checkpoints, 1):
110+
logger.debug(f"Checkpoint raw data: {checkpoint}")
111+
thread_id = checkpoint.config["configurable"].get("thread_id", "Unknown")
112+
checkpoint_id = checkpoint.config["configurable"].get("checkpoint_id", "Unknown")
113+
114+
logger.info(f"{i}.")
115+
logger.info(f" Thread ID: {thread_id}")
116+
logger.info(f" Checkpoint ID: {checkpoint_id}")
117+
118+
# Count messages in this checkpoint
119+
messages = checkpoint.checkpoint["channel_values"].get("messages") or []
120+
non_null_messages = [msg for msg in messages if msg is not None]
121+
logger.info(f" Messages: {len(non_null_messages)} total")
122+
123+
if verbose and non_null_messages:
124+
logger.info(" Recent messages:")
125+
# Show last 2 messages for brevity
126+
for msg in non_null_messages[-2:]:
127+
if hasattr(msg, "content"):
128+
content = str(msg.content)
129+
content_preview = content[:100] + "..." if len(content) > 100 else content
130+
msg_type = type(msg).__name__
131+
logger.info(f" [{msg_type}] {content_preview}")
132+
133+
logger.info("-" * 60)
134+
135+
except Exception as e:
136+
logger.error(f"Error listing checkpoints: {str(e)}")
137+
if verbose:
138+
import traceback
139+
140+
logger.debug(traceback.format_exc())
141+
142+
143+
@app.command()
144+
def list_messages(
145+
thread_id: str = typer.Option(
146+
...,
147+
"--thread-id",
148+
"-i",
149+
help="Thread ID of the checkpoint to list messages from",
150+
),
151+
checkpoint_type: str = typer.Option(
152+
DEFAULT_CHECKPOINT_TYPE.value,
153+
"--type",
154+
"-t",
155+
case_sensitive=False,
156+
help=f"Type of checkpoint to use. Options: {', '.join([f'{key} ({value})' for key, value in CHECKPOINT_LABELS.items()])}. Default is '{DEFAULT_CHECKPOINT_TYPE.value}'.", # noqa: E501
157+
),
158+
limit: int = typer.Option(None, "--limit", "-l", help="Maximum number of messages to display (default: all)"),
159+
verbose: bool = typer.Option(
160+
False,
161+
"--verbose",
162+
"-v",
163+
help="Enable verbose output",
164+
),
165+
):
166+
"""List messages from a specific checkpoint thread."""
167+
# Set up logging
168+
if verbose:
169+
logger.setLevel(logging.DEBUG)
170+
171+
logger.info(f"Using checkpoint type: {CHECKPOINT_LABELS.get(checkpoint_type, checkpoint_type)}")
172+
logger.info(f"Retrieving messages for thread ID: {thread_id}")
173+
174+
checkpointer = get_checkpointer(raw_value=checkpoint_type)
175+
if checkpointer is None:
176+
logger.info("No checkpointing is configured.")
177+
return
178+
179+
try:
180+
# Search for the specific thread
181+
checkpoints = list(checkpointer.list(config=None))
182+
target_checkpoint = None
183+
184+
for checkpoint in checkpoints:
185+
if checkpoint.config["configurable"].get("thread_id") == thread_id:
186+
target_checkpoint = checkpoint
187+
break
188+
189+
if target_checkpoint is None:
190+
logger.error(f"Thread ID '{thread_id}' not found.")
191+
logger.info("Available thread IDs:")
192+
for checkpoint in checkpoints:
193+
available_thread_id = checkpoint.config["configurable"].get("thread_id")
194+
logger.info(f" - {available_thread_id}")
195+
return
196+
197+
# Extract messages
198+
messages = target_checkpoint.checkpoint["channel_values"].get("messages") or []
199+
non_null_messages = [msg for msg in messages if msg is not None]
200+
201+
if not non_null_messages:
202+
logger.info("No messages found in this thread.")
203+
return
204+
205+
# Apply limit if specified
206+
if limit and limit > 0:
207+
if limit < len(non_null_messages):
208+
logger.info(f"Showing last {limit} of {len(non_null_messages)} messages:")
209+
non_null_messages = non_null_messages[-limit:]
210+
else:
211+
logger.info(f"Showing all {len(non_null_messages)} messages:")
212+
else:
213+
logger.info(f"Showing all {len(non_null_messages)} messages:")
214+
215+
logger.info("=" * 80)
216+
217+
for i, msg in enumerate(non_null_messages, 1):
218+
msg_type = type(msg).__name__
219+
logger.info(f"Message {i} [{msg_type}]:")
220+
221+
# Handle different message types
222+
if hasattr(msg, "content"):
223+
logger.info(f" Content: {msg.content}")
224+
225+
if hasattr(msg, "role"):
226+
logger.info(f" Role: {msg.role}")
227+
228+
if hasattr(msg, "name"):
229+
logger.info(f" Name: {msg.name}")
230+
231+
if hasattr(msg, "tool_calls") and msg.tool_calls:
232+
logger.info(f" Tool calls: {len(msg.tool_calls)}")
233+
if verbose:
234+
for j, tool_call in enumerate(msg.tool_calls, 1):
235+
logger.info(f" {j}. {tool_call}")
236+
237+
if hasattr(msg, "additional_kwargs") and msg.additional_kwargs and verbose:
238+
logger.info(f" Additional kwargs: {msg.additional_kwargs}")
239+
240+
# Show raw message in verbose mode
241+
if verbose:
242+
logger.info(f" Raw: {msg}")
243+
244+
logger.info("-" * 40)
245+
246+
logger.info("=" * 80)
247+
248+
except Exception as e:
249+
logger.error(f"Error retrieving messages: {str(e)}")
250+
if verbose:
251+
import traceback
252+
253+
logger.debug(traceback.format_exc())
254+
255+
256+
if __name__ == "__main__":
257+
load_dotenv(
258+
override=True,
259+
verbose=True,
260+
)
261+
app()

0 commit comments

Comments
 (0)