Skip to content

Commit e9acfd3

Browse files
committed
Fixed update behavior of ttl index, exposed in interrupt workflows.
1 parent 6ae79bd commit e9acfd3

File tree

3 files changed

+100
-20
lines changed

3 files changed

+100
-20
lines changed

libs/langgraph-checkpoint-mongodb/langgraph/checkpoint/mongodb/aio.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -410,6 +410,7 @@ async def aput_writes(
410410
"$set" if all(w[0] in WRITES_IDX_MAP for w in writes) else "$setOnInsert"
411411
)
412412
operations = []
413+
now = datetime.now()
413414
for idx, (channel, value) in enumerate(writes):
414415
upsert_query = {
415416
"thread_id": thread_id,
@@ -419,19 +420,22 @@ async def aput_writes(
419420
"task_path": task_path,
420421
"idx": WRITES_IDX_MAP.get(channel, idx),
421422
}
422-
if self.ttl:
423-
upsert_query["created_at"] = datetime.now()
423+
424424
type_, serialized_value = self.serde.dumps_typed(value)
425+
426+
update_doc = {
427+
"channel": channel,
428+
"type": type_,
429+
"value": serialized_value,
430+
}
431+
432+
if self.ttl:
433+
update_doc["created_at"] = now
434+
425435
operations.append(
426436
UpdateOne(
427-
upsert_query,
428-
{
429-
set_method: {
430-
"channel": channel,
431-
"type": type_,
432-
"value": serialized_value,
433-
}
434-
},
437+
filter=upsert_query,
438+
update={set_method: update_doc},
435439
upsert=True,
436440
)
437441
)

libs/langgraph-checkpoint-mongodb/langgraph/checkpoint/mongodb/saver.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,7 @@ def put_writes(
425425
"$set" if all(w[0] in WRITES_IDX_MAP for w in writes) else "$setOnInsert"
426426
)
427427
operations = []
428+
now = datetime.now()
428429
for idx, (channel, value) in enumerate(writes):
429430
upsert_query = {
430431
"thread_id": thread_id,
@@ -434,20 +435,22 @@ def put_writes(
434435
"task_path": task_path,
435436
"idx": WRITES_IDX_MAP.get(channel, idx),
436437
}
437-
if self.ttl:
438-
upsert_query["created_at"] = datetime.now()
439438

440439
type_, serialized_value = self.serde.dumps_typed(value)
440+
441+
update_doc = {
442+
"channel": channel,
443+
"type": type_,
444+
"value": serialized_value,
445+
}
446+
447+
if self.ttl:
448+
update_doc["created_at"] = now
449+
441450
operations.append(
442451
UpdateOne(
443-
upsert_query,
444-
{
445-
set_method: {
446-
"channel": channel,
447-
"type": type_,
448-
"value": serialized_value,
449-
}
450-
},
452+
filter=upsert_query,
453+
update={set_method: update_doc},
451454
upsert=True,
452455
)
453456
)
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
import os
2+
from collections.abc import AsyncGenerator
3+
4+
import pytest
5+
import pytest_asyncio
6+
from pymongo import AsyncMongoClient, MongoClient
7+
8+
from langgraph.checkpoint.mongodb import AsyncMongoDBSaver, MongoDBSaver
9+
from langgraph.types import Interrupt
10+
11+
MONGODB_URI = os.environ.get(
12+
"MONGODB_URI", "mongodb://127.0.0.1:27017?directConnection=true"
13+
)
14+
DB_NAME: str = "test_langgraph_db"
15+
COLLECTION_NAME: str = "checkpoints_interrupts"
16+
WRITES_COLLECTION_NAME: str = "writes_interrupts"
17+
TTL: int = 60 * 60
18+
19+
20+
@pytest_asyncio.fixture(params=["run_in_executor", "aio"])
21+
async def async_saver(request: pytest.FixtureRequest) -> AsyncGenerator:
22+
if request.param == "aio":
23+
# Use async client and checkpointer
24+
aclient: AsyncMongoClient = AsyncMongoClient(MONGODB_URI)
25+
adb = aclient[DB_NAME]
26+
for clxn in await adb.list_collection_names():
27+
await adb.drop_collection(clxn)
28+
async with AsyncMongoDBSaver.from_conn_string(
29+
MONGODB_URI, DB_NAME, COLLECTION_NAME, WRITES_COLLECTION_NAME, TTL
30+
) as checkpointer:
31+
yield checkpointer
32+
await aclient.close()
33+
else:
34+
# Use sync client and checkpointer with async methods run in executor
35+
client: MongoClient = MongoClient(MONGODB_URI)
36+
db = client[DB_NAME]
37+
for clxn in db.list_collection_names():
38+
db.drop_collection(clxn)
39+
with MongoDBSaver.from_conn_string(
40+
MONGODB_URI, DB_NAME, COLLECTION_NAME, WRITES_COLLECTION_NAME, TTL
41+
) as checkpointer:
42+
yield checkpointer
43+
client.close()
44+
45+
46+
def test_put_writes_on_interrupt(async_saver: MongoDBSaver):
47+
"""Test that no error is raised when interrupted workflow updates writes."""
48+
config = {
49+
"configurable": {
50+
"checkpoint_id": "check1",
51+
"thread_id": "thread1",
52+
"checkpoint_ns": "",
53+
}
54+
}
55+
task_id = "task_id"
56+
task_path = "~__pregel_pull, human_feedback"
57+
58+
writes1 = [
59+
(
60+
"__interrupt__",
61+
(
62+
Interrupt(
63+
value="please provide input",
64+
resumable=True,
65+
ns=["human_feedback:1b798da3"],
66+
),
67+
),
68+
)
69+
]
70+
async_saver.aput_writes(config, writes1, task_id, task_path)
71+
72+
writes2 = [("__interrupt__", (Interrupt(value="please provide another input"),))]
73+
async_saver.aput_writes(config, writes2, task_id, task_path)

0 commit comments

Comments
 (0)