Skip to content

Commit d25325d

Browse files
committed
Added test of langgraph time travel to test low level checkpointer behavior via high level graph api.
1 parent e21428a commit d25325d

File tree

1 file changed

+163
-0
lines changed

1 file changed

+163
-0
lines changed
Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
import os
2+
from collections.abc import Generator
3+
from operator import add
4+
from typing import Annotated, Any, TypedDict
5+
6+
import pytest
7+
from langchain_core.runnables import RunnableConfig
8+
from pymongo import MongoClient
9+
from typing_extensions import NotRequired
10+
11+
from langgraph.checkpoint.mongodb import MongoDBSaver
12+
from langgraph.graph import END, START, StateGraph
13+
from langgraph.types import StateSnapshot
14+
15+
# Test configuration
16+
MONGODB_URI = os.environ.get(
17+
"MONGODB_URI", "mongodb://127.0.0.1:27017?directConnection=true"
18+
)
19+
20+
21+
class ExpenseState(TypedDict):
22+
amount: NotRequired[int]
23+
version: NotRequired[int]
24+
approved: NotRequired[bool]
25+
messages: Annotated[list[str], add]
26+
27+
28+
def add_expense_node(state: ExpenseState) -> dict[str, Any]:
29+
"""Node adds expense and a message"""
30+
return dict(amount=100, version=1, approved=False, messages=["Added new expense"])
31+
32+
33+
def validate_expense_node(state: ExpenseState) -> dict[str, Any]:
34+
"""Node that processes data based on current state"""
35+
if state.get("amount") == 200:
36+
return dict(approved=True, messages=["expense approved"])
37+
else:
38+
return dict(approved=False, messages=["expense denied"])
39+
40+
41+
@pytest.fixture(
42+
params=[None, 60 * 60],
43+
ids=["ttl_none", "ttl_3600"],
44+
)
45+
def checkpointer(request: Any) -> Generator[MongoDBSaver]:
46+
db_name = "langgraph_timetravel_db"
47+
checkpoint_collection_name = "checkpoints"
48+
writes_collection_name = "checkpoint_writes"
49+
50+
# Initialize MongoDB checkpointer
51+
client: MongoClient = MongoClient(MONGODB_URI)
52+
53+
# Clean up any existing test data.
54+
client.drop_database(db_name)
55+
56+
saver = MongoDBSaver(
57+
client=client,
58+
db_name=db_name,
59+
collection_name=checkpoint_collection_name,
60+
WRITES_COLLECTION_NAME=writes_collection_name,
61+
ttl=request.param,
62+
)
63+
64+
# Can use this to compare
65+
# saver = InMemorySaver()
66+
67+
yield saver
68+
69+
client[db_name].drop_collection(checkpoint_collection_name)
70+
client[db_name].drop_collection(writes_collection_name)
71+
client.close()
72+
73+
74+
def test(checkpointer: MongoDBSaver) -> None:
75+
"""Test ability to use checkpointer to update exact state of graph.
76+
77+
In this simple example, we assume an initial state has been set incorrectly.
78+
To fix this, instead of rerunning from start,
79+
we find the incorrect node, update_state, and continue (by passing None to invoke or stream).
80+
81+
This example does not use interrupt/resume as one might, for example,
82+
in an expense report approval workflow.
83+
"""
84+
initial_state: ExpenseState = dict(
85+
amount=0, version=0, approved=False, messages=["Initial state"]
86+
)
87+
config: RunnableConfig = dict(configurable=dict(thread_id="test-time-travel"))
88+
89+
# Create the graph, which should be a 2-step procedure
90+
workflow = StateGraph(ExpenseState)
91+
workflow.add_node("add_expense", add_expense_node)
92+
workflow.add_node("validate_expense", validate_expense_node)
93+
workflow.add_edge(START, "add_expense")
94+
workflow.add_edge("validate_expense", END)
95+
workflow.add_edge("add_expense", "validate_expense")
96+
graph = workflow.compile(checkpointer=checkpointer)
97+
98+
# Run the graph
99+
graph.invoke(input=initial_state, config=config) # type:ignore[arg-type]
100+
101+
# Check to see whether the final state is approved
102+
final_state = graph.get_state(config=config)
103+
104+
# It is not approved.
105+
assert not final_state.values["approved"]
106+
107+
# Let's use time-travel to find the checkpoint before "add_expense"
108+
checkpoints: list[StateSnapshot] = list(graph.get_state_history(config))
109+
# checkpoints: list[CheckpointTuple] = list(checkpointer.list(config))
110+
print(f"\nFound {len(checkpoints)} checkpoints")
111+
112+
target_checkpoint = None
113+
for checkpoint in checkpoints:
114+
# Look for checkpoint after increment but before final processing
115+
if (
116+
checkpoint.metadata and checkpoint.metadata.get("step") == 1
117+
): # Before validate node
118+
target_checkpoint = checkpoint
119+
break
120+
121+
for state in checkpoints:
122+
if state.metadata:
123+
print(f"\n{state.metadata["step"]=}")
124+
print(f"{state.next=}")
125+
print(f"{state.config["configurable"]["checkpoint_id"]=}")
126+
print(f"{state.values=}")
127+
128+
# Get state at that checkpoint
129+
assert target_checkpoint
130+
past_state = graph.get_state(target_checkpoint.config)
131+
132+
# Update the expense amount to 200 that validate amounts
133+
updated_state = dict(**past_state.values)
134+
# updated_state = {}
135+
updated_state["amount"] = 200
136+
updated_state["version"] = 2
137+
updated_state["messages"] += ["Updated state"]
138+
139+
updated_config = graph.update_state(
140+
config=target_checkpoint.config, values=updated_state
141+
)
142+
143+
# Continue from the checkpoint
144+
print("\nContinuing execution with stream(None, config)...")
145+
final_step = None
146+
for step in graph.stream(None, updated_config):
147+
print(f"Continuation step: {step}")
148+
final_step = step
149+
150+
# Verify the final result
151+
assert isinstance(final_step, dict)
152+
assert final_step["validate_expense"]["approved"]
153+
# Note that all values are not in the final step
154+
assert "amount" not in final_step["validate_expense"]
155+
# They ARE available from graph.get_state
156+
final_state = graph.get_state(updated_config)
157+
assert final_state.values["amount"] == 200
158+
assert set(final_state.values.keys()) == {
159+
"amount",
160+
"version",
161+
"messages",
162+
"approved",
163+
}

0 commit comments

Comments
 (0)