Skip to content

Commit 0999e36

Browse files
INTPYTHON-748 Fixed update behavior of ttl index, exposed in interrupt workflows. (#207)
Investigate DuplicateKeyError when graph interrupted and TTL is defined Should fix langchain-ai/langgraph#6040, "E11000 duplicate key error collection : checkpoint_writes_aio index" @spsingh559 Would you please confirm that this fixes your issue?
1 parent 9f581b5 commit 0999e36

File tree

4 files changed

+252
-22
lines changed

4 files changed

+252
-22
lines changed

libs/langgraph-checkpoint-mongodb/langgraph/checkpoint/mongodb/aio.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -359,7 +359,7 @@ async def aput(
359359
type_, serialized_checkpoint = self.serde.dumps_typed(checkpoint)
360360
metadata = metadata.copy()
361361
metadata.update(config.get("metadata", {}))
362-
doc = {
362+
doc: dict[str, Any] = {
363363
"parent_checkpoint_id": config["configurable"].get("checkpoint_id"),
364364
"type": type_,
365365
"checkpoint": serialized_checkpoint,
@@ -410,6 +410,7 @@ async def aput_writes(
410410
"$set" if all(w[0] in WRITES_IDX_MAP for w in writes) else "$setOnInsert"
411411
)
412412
operations = []
413+
now = datetime.now()
413414
for idx, (channel, value) in enumerate(writes):
414415
upsert_query = {
415416
"thread_id": thread_id,
@@ -419,19 +420,22 @@ async def aput_writes(
419420
"task_path": task_path,
420421
"idx": WRITES_IDX_MAP.get(channel, idx),
421422
}
422-
if self.ttl:
423-
upsert_query["created_at"] = datetime.now()
423+
424424
type_, serialized_value = self.serde.dumps_typed(value)
425+
426+
update_doc: dict[str, Any] = {
427+
"channel": channel,
428+
"type": type_,
429+
"value": serialized_value,
430+
}
431+
432+
if self.ttl:
433+
update_doc["created_at"] = now
434+
425435
operations.append(
426436
UpdateOne(
427-
upsert_query,
428-
{
429-
set_method: {
430-
"channel": channel,
431-
"type": type_,
432-
"value": serialized_value,
433-
}
434-
},
437+
filter=upsert_query,
438+
update={set_method: update_doc},
435439
upsert=True,
436440
)
437441
)

libs/langgraph-checkpoint-mongodb/langgraph/checkpoint/mongodb/saver.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,7 @@ def put(
390390
"checkpoint_id": checkpoint_id,
391391
}
392392
if self.ttl:
393-
upsert_query["created_at"] = datetime.now()
393+
doc["created_at"] = datetime.now()
394394

395395
self.checkpoint_collection.update_one(upsert_query, {"$set": doc}, upsert=True)
396396
return {
@@ -425,6 +425,7 @@ def put_writes(
425425
"$set" if all(w[0] in WRITES_IDX_MAP for w in writes) else "$setOnInsert"
426426
)
427427
operations = []
428+
now = datetime.now()
428429
for idx, (channel, value) in enumerate(writes):
429430
upsert_query = {
430431
"thread_id": thread_id,
@@ -434,20 +435,22 @@ def put_writes(
434435
"task_path": task_path,
435436
"idx": WRITES_IDX_MAP.get(channel, idx),
436437
}
437-
if self.ttl:
438-
upsert_query["created_at"] = datetime.now()
439438

440439
type_, serialized_value = self.serde.dumps_typed(value)
440+
441+
update_doc: dict[str, Any] = {
442+
"channel": channel,
443+
"type": type_,
444+
"value": serialized_value,
445+
}
446+
447+
if self.ttl:
448+
update_doc["created_at"] = now
449+
441450
operations.append(
442451
UpdateOne(
443-
upsert_query,
444-
{
445-
set_method: {
446-
"channel": channel,
447-
"type": type_,
448-
"value": serialized_value,
449-
}
450-
},
452+
filter=upsert_query,
453+
update={set_method: update_doc},
451454
upsert=True,
452455
)
453456
)
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import os
2+
from collections.abc import AsyncGenerator
3+
4+
import pytest
5+
import pytest_asyncio
6+
from langchain_core.runnables import RunnableConfig
7+
from pymongo import MongoClient
8+
9+
from langgraph.checkpoint.mongodb import MongoDBSaver
10+
from langgraph.types import Interrupt
11+
12+
MONGODB_URI = os.environ.get(
13+
"MONGODB_URI", "mongodb://127.0.0.1:27017?directConnection=true"
14+
)
15+
DB_NAME: str = "test_langgraph_db"
16+
COLLECTION_NAME: str = "checkpoints_interrupts"
17+
WRITES_COLLECTION_NAME: str = "writes_interrupts"
18+
TTL: int = 60 * 60
19+
20+
21+
@pytest_asyncio.fixture()
22+
async def async_saver(request: pytest.FixtureRequest) -> AsyncGenerator:
23+
# Use sync client and checkpointer with async methods run in executor
24+
client: MongoClient = MongoClient(MONGODB_URI)
25+
db = client[DB_NAME]
26+
for clxn in db.list_collection_names():
27+
db.drop_collection(clxn)
28+
with MongoDBSaver.from_conn_string(
29+
MONGODB_URI, DB_NAME, COLLECTION_NAME, WRITES_COLLECTION_NAME, TTL
30+
) as checkpointer:
31+
yield checkpointer
32+
client.close()
33+
34+
35+
async def test_put_writes_on_interrupt(async_saver: MongoDBSaver) -> None:
36+
"""Test that no error is raised when interrupted workflow updates writes."""
37+
config: RunnableConfig = {
38+
"configurable": {
39+
"checkpoint_id": "check1",
40+
"thread_id": "thread1",
41+
"checkpoint_ns": "",
42+
}
43+
}
44+
task_id = "task_id"
45+
task_path = "~__pregel_pull, human_feedback"
46+
47+
writes1 = [
48+
(
49+
"__interrupt__",
50+
(
51+
Interrupt(
52+
value="please provide input",
53+
),
54+
),
55+
)
56+
]
57+
await async_saver.aput_writes(config, writes1, task_id, task_path)
58+
59+
writes2 = [("__interrupt__", (Interrupt(value="please provide another input"),))]
60+
await async_saver.aput_writes(config, writes2, task_id, task_path)
Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
import os
2+
from collections.abc import Generator
3+
from operator import add
4+
from typing import Annotated, Any, TypedDict
5+
6+
import pytest
7+
from langchain_core.runnables import RunnableConfig
8+
from pymongo import MongoClient
9+
from typing_extensions import NotRequired
10+
11+
from langgraph.checkpoint.mongodb import MongoDBSaver
12+
from langgraph.graph import END, START, StateGraph
13+
from langgraph.types import StateSnapshot
14+
15+
# Test configuration
16+
MONGODB_URI = os.environ.get(
17+
"MONGODB_URI", "mongodb://127.0.0.1:27017?directConnection=true"
18+
)
19+
20+
21+
class ExpenseState(TypedDict):
22+
amount: NotRequired[int]
23+
version: NotRequired[int]
24+
approved: NotRequired[bool]
25+
messages: Annotated[list[str], add]
26+
27+
28+
def add_expense_node(state: ExpenseState) -> dict[str, Any]:
29+
"""Node adds expense and a message"""
30+
return dict(amount=100, version=1, approved=False, messages=["Added new expense"])
31+
32+
33+
def validate_expense_node(state: ExpenseState) -> dict[str, Any]:
34+
"""Node that processes data based on current state"""
35+
if state.get("amount") == 200:
36+
return dict(approved=True, messages=["expense approved"])
37+
else:
38+
return dict(approved=False, messages=["expense denied"])
39+
40+
41+
@pytest.fixture(
42+
params=[None, 60 * 60],
43+
ids=["ttl_none", "ttl_3600"],
44+
)
45+
def checkpointer(request: Any) -> Generator[MongoDBSaver]:
46+
db_name = "langgraph_timetravel_db"
47+
checkpoint_collection_name = "checkpoints"
48+
writes_collection_name = "checkpoint_writes"
49+
50+
# Initialize MongoDB checkpointer
51+
client: MongoClient = MongoClient(MONGODB_URI)
52+
53+
# Clean up any existing test data.
54+
client.drop_database(db_name)
55+
56+
saver = MongoDBSaver(
57+
client=client,
58+
db_name=db_name,
59+
collection_name=checkpoint_collection_name,
60+
WRITES_COLLECTION_NAME=writes_collection_name,
61+
ttl=request.param,
62+
)
63+
64+
# Can use this to compare
65+
# saver = InMemorySaver()
66+
67+
yield saver
68+
69+
client[db_name].drop_collection(checkpoint_collection_name)
70+
client[db_name].drop_collection(writes_collection_name)
71+
client.close()
72+
73+
74+
def test(checkpointer: MongoDBSaver) -> None:
75+
"""Test ability to use checkpointer to update exact state of graph.
76+
77+
In this simple example, we assume an initial state has been set incorrectly.
78+
To fix this, instead of rerunning from start,
79+
we find the incorrect node, update_state, and continue (by passing None to invoke or stream).
80+
81+
This example does not use interrupt/resume as one might, for example,
82+
in an expense report approval workflow.
83+
"""
84+
initial_state: ExpenseState = dict(
85+
amount=0, version=0, approved=False, messages=["Initial state"]
86+
)
87+
config: RunnableConfig = dict(configurable=dict(thread_id="test-time-travel"))
88+
89+
# Create the graph, which should be a 2-step procedure
90+
workflow = StateGraph(ExpenseState)
91+
workflow.add_node("add_expense", add_expense_node)
92+
workflow.add_node("validate_expense", validate_expense_node)
93+
workflow.add_edge(START, "add_expense")
94+
workflow.add_edge("validate_expense", END)
95+
workflow.add_edge("add_expense", "validate_expense")
96+
graph = workflow.compile(checkpointer=checkpointer)
97+
98+
# Run the graph
99+
graph.invoke(input=initial_state, config=config) # type:ignore[arg-type]
100+
101+
# Check to see whether the final state is approved
102+
final_state = graph.get_state(config=config)
103+
104+
# It is not approved.
105+
assert not final_state.values["approved"]
106+
107+
# Let's use time-travel to find the checkpoint before "add_expense"
108+
checkpoints: list[StateSnapshot] = list(graph.get_state_history(config))
109+
# checkpoints: list[CheckpointTuple] = list(checkpointer.list(config))
110+
print(f"\nFound {len(checkpoints)} checkpoints")
111+
112+
target_checkpoint = None
113+
for checkpoint in checkpoints:
114+
# Look for checkpoint after increment but before final processing
115+
if (
116+
checkpoint.metadata and checkpoint.metadata.get("step") == 1
117+
): # Before validate node
118+
target_checkpoint = checkpoint
119+
break
120+
121+
for state in checkpoints:
122+
if state.metadata:
123+
print(f"\nstep: {state.metadata['step']}")
124+
print(f"next: {state.next}")
125+
print(f"checkpoint_id: {state.config['configurable']['checkpoint_id']}")
126+
print(f"values: {state.values}")
127+
128+
# Get state at that checkpoint
129+
assert target_checkpoint
130+
past_state = graph.get_state(target_checkpoint.config)
131+
132+
# Update the expense amount to 200 that validate amounts
133+
updated_state = dict(**past_state.values)
134+
# updated_state = {}
135+
updated_state["amount"] = 200
136+
updated_state["version"] = 2
137+
updated_state["messages"] += ["Updated state"]
138+
139+
updated_config = graph.update_state(
140+
config=target_checkpoint.config, values=updated_state
141+
)
142+
143+
# Continue from the checkpoint
144+
print("\nContinuing execution with stream(None, config)...")
145+
final_step = None
146+
for step in graph.stream(None, updated_config):
147+
print(f"Continuation step: {step}")
148+
final_step = step
149+
150+
# Verify the final result
151+
assert isinstance(final_step, dict)
152+
assert final_step["validate_expense"]["approved"]
153+
# Note that all values are not in the final step
154+
assert "amount" not in final_step["validate_expense"]
155+
# They ARE available from graph.get_state
156+
final_state = graph.get_state(updated_config)
157+
assert final_state.values["amount"] == 200
158+
assert set(final_state.values.keys()) == {
159+
"amount",
160+
"version",
161+
"messages",
162+
"approved",
163+
}

0 commit comments

Comments
 (0)