Skip to content

Commit 93edf33

Browse files
committed
Adds async methods to MongoDBSaver
1 parent 7604eda commit 93edf33

File tree

3 files changed

+235
-92
lines changed

3 files changed

+235
-92
lines changed

libs/langgraph-checkpoint-mongodb/langgraph/checkpoint/mongodb/saver.py

Lines changed: 120 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from collections.abc import Iterator, Sequence
1+
import asyncio
2+
from collections.abc import AsyncIterator, Iterator, Sequence
23
from contextlib import contextmanager
34
from datetime import datetime
45
from importlib.metadata import version
@@ -7,7 +8,7 @@
78
Optional,
89
)
910

10-
from langchain_core.runnables import RunnableConfig
11+
from langchain_core.runnables import RunnableConfig, run_in_executor
1112
from pymongo import ASCENDING, MongoClient, UpdateOne
1213
from pymongo.database import Database as MongoDatabase
1314
from pymongo.driver_info import DriverInfo
@@ -468,3 +469,120 @@ def delete_thread(
468469

469470
# Delete all writes associated with the thread ID
470471
self.writes_collection.delete_many({"thread_id": thread_id})
472+
473+
async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
474+
"""Asynchronously fetch a checkpoint tuple using the given configuration.
475+
476+
Asynchronously wraps the blocking `self.get_tuple` method.
477+
478+
Args:
479+
config: Configuration specifying which checkpoint to retrieve.
480+
481+
Returns:
482+
Optional[CheckpointTuple]: The requested checkpoint tuple, or None if not found.
483+
484+
"""
485+
return await run_in_executor(None, self.get_tuple, config)
486+
487+
async def alist(
488+
self,
489+
config: Optional[RunnableConfig],
490+
*,
491+
filter: Optional[dict[str, Any]] = None,
492+
before: Optional[RunnableConfig] = None,
493+
limit: Optional[int] = None,
494+
) -> AsyncIterator[CheckpointTuple]:
495+
"""Asynchronously list checkpoints that match the given criteria.
496+
497+
Asynchronously wraps the blocking `self.list` generator.
498+
499+
Runs `self.list(...)` in a background thread and yields its items
500+
asynchronously from an asyncio.Queue. This allows integration of
501+
synchronous iterators into async code.
502+
503+
Args:
504+
config: Configuration object passed to `self.list`.
505+
filter: Optional filter dictionary.
506+
before: Optional parameter to limit results before a given checkpoint.
507+
limit: Optional maximum number of results to yield.
508+
509+
Yields:
510+
AsyncIterator[CheckpointTuple]: An iterator of checkpoint tuples.
511+
"""
512+
loop = asyncio.get_running_loop()
513+
queue: asyncio.Queue[CheckpointTuple] = asyncio.Queue()
514+
sentinel = object()
515+
516+
def run() -> None:
517+
try:
518+
for item in self.list(
519+
config, filter=filter, before=before, limit=limit
520+
):
521+
loop.call_soon_threadsafe(queue.put_nowait, item)
522+
finally:
523+
loop.call_soon_threadsafe(queue.put_nowait, sentinel) # type: ignore
524+
525+
await run_in_executor(None, run)
526+
while True:
527+
item = await queue.get()
528+
if item is sentinel:
529+
break
530+
yield item
531+
532+
async def aput(
533+
self,
534+
config: RunnableConfig,
535+
checkpoint: Checkpoint,
536+
metadata: CheckpointMetadata,
537+
new_versions: ChannelVersions,
538+
) -> RunnableConfig:
539+
"""Asynchronously store a checkpoint with its configuration and metadata.
540+
541+
Asynchronously wraps the blocking `self.put` method.
542+
543+
Args:
544+
config: Configuration for the checkpoint.
545+
checkpoint: The checkpoint to store.
546+
metadata: Additional metadata for the checkpoint.
547+
new_versions: New channel versions as of this write.
548+
549+
Returns:
550+
RunnableConfig: Updated configuration after storing the checkpoint.
551+
"""
552+
return await run_in_executor(
553+
None, self.put, config, checkpoint, metadata, new_versions
554+
)
555+
556+
async def aput_writes(
557+
self,
558+
config: RunnableConfig,
559+
writes: Sequence[tuple[str, Any]],
560+
task_id: str,
561+
task_path: str = "",
562+
) -> None:
563+
"""Asynchronously store intermediate writes linked to a checkpoint.
564+
565+
Asynchronously wraps the blocking `self.put_writes` method.
566+
567+
Args:
568+
config: Configuration of the related checkpoint.
569+
writes: List of writes to store.
570+
task_id: Identifier for the task creating the writes.
571+
task_path: Path of the task creating the writes.
572+
"""
573+
return await run_in_executor(
574+
None, self.put_writes, config, writes, task_id, task_path
575+
)
576+
577+
async def adelete_thread(
578+
self,
579+
thread_id: str,
580+
) -> None:
581+
"""Delete all checkpoints and writes associated with a specific thread ID.
582+
583+
Asynchronously wraps the blocking `self.delete_thread` method.
584+
585+
Args:
586+
thread_id: The thread ID whose checkpoints should be deleted.
587+
"""
588+
return await run_in_executor(None, self.delete_thread, thread_id)

libs/langgraph-checkpoint-mongodb/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ addopts = "--strict-markers --strict-config --durations=5 -vv"
4040
markers = [
4141
"requires: mark tests as requiring a specific library",
4242
"compile: mark placeholder test used to compile integration tests without running them",
43+
"asyncio: mark a test as asyncio",
4344
]
4445
asyncio_mode = "auto"
4546

libs/langgraph-checkpoint-mongodb/tests/unit_tests/test_async.py

Lines changed: 114 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -1,107 +1,131 @@
11
import os
2-
from typing import Any
2+
from collections.abc import AsyncGenerator
3+
from typing import Any, Union
34

45
import pytest
6+
import pytest_asyncio
57
from bson.errors import InvalidDocument
6-
from pymongo import AsyncMongoClient
8+
from pymongo import AsyncMongoClient, MongoClient
79

8-
from langgraph.checkpoint.mongodb.aio import AsyncMongoDBSaver
10+
from langgraph.checkpoint.mongodb import AsyncMongoDBSaver, MongoDBSaver
911

10-
MONGODB_URI = os.environ.get("MONGODB_URI", "mongodb://localhost:27017")
12+
MONGODB_URI = os.environ.get(
13+
"MONGODB_URI", "mongodb://localhost:27017/?directConnection=true"
14+
)
1115
DB_NAME = os.environ.get("DB_NAME", "langgraph-test")
1216
COLLECTION_NAME = "sync_checkpoints_aio"
1317

1418

15-
async def test_asearch(input_data: dict[str, Any]) -> None:
16-
# Clear collections if they exist
17-
client: AsyncMongoClient = AsyncMongoClient(MONGODB_URI)
18-
db = client[DB_NAME]
19-
20-
for clxn in await db.list_collection_names():
21-
await db.drop_collection(clxn)
22-
23-
async with AsyncMongoDBSaver.from_conn_string(
24-
MONGODB_URI, DB_NAME, COLLECTION_NAME
25-
) as saver:
26-
# save checkpoints
27-
await saver.aput(
28-
input_data["config_1"],
29-
input_data["chkpnt_1"],
30-
input_data["metadata_1"],
31-
{},
32-
)
33-
await saver.aput(
34-
input_data["config_2"],
35-
input_data["chkpnt_2"],
36-
input_data["metadata_2"],
37-
{},
38-
)
39-
await saver.aput(
40-
input_data["config_3"],
41-
input_data["chkpnt_3"],
42-
input_data["metadata_3"],
43-
{},
44-
)
45-
46-
# call method / assertions
47-
query_1 = {"source": "input"} # search by 1 key
48-
query_2 = {
49-
"step": 1,
50-
"writes": {"foo": "bar"},
51-
} # search by multiple keys
52-
query_3: dict[str, Any] = {} # search by no keys, return all checkpoints
53-
query_4 = {"source": "update", "step": 1} # no match
54-
55-
search_results_1 = [c async for c in saver.alist(None, filter=query_1)]
56-
assert len(search_results_1) == 1
57-
assert search_results_1[0].metadata == input_data["metadata_1"]
58-
59-
search_results_2 = [c async for c in saver.alist(None, filter=query_2)]
60-
assert len(search_results_2) == 1
61-
assert search_results_2[0].metadata == input_data["metadata_2"]
62-
63-
search_results_3 = [c async for c in saver.alist(None, filter=query_3)]
64-
assert len(search_results_3) == 3
65-
66-
search_results_4 = [c async for c in saver.alist(None, filter=query_4)]
67-
assert len(search_results_4) == 0
68-
69-
# search by config (defaults to checkpoints across all namespaces)
70-
search_results_5 = [
71-
c async for c in saver.alist({"configurable": {"thread_id": "thread-2"}})
72-
]
73-
assert len(search_results_5) == 2
74-
assert {
75-
search_results_5[0].config["configurable"]["checkpoint_ns"],
76-
search_results_5[1].config["configurable"]["checkpoint_ns"],
77-
} == {"", "inner"}
78-
79-
80-
async def test_null_chars(input_data: dict[str, Any]) -> None:
19+
@pytest_asyncio.fixture(params=["run_in_executor", "aio"])
20+
async def async_saver(request: pytest.FixtureRequest) -> AsyncGenerator:
21+
if request.param == "aio":
22+
# Use async client and checkpointer
23+
aclient: AsyncMongoClient = AsyncMongoClient(MONGODB_URI)
24+
adb = aclient[DB_NAME]
25+
for clxn in await adb.list_collection_names():
26+
await adb.drop_collection(clxn)
27+
async with AsyncMongoDBSaver.from_conn_string(
28+
MONGODB_URI, DB_NAME, COLLECTION_NAME
29+
) as checkpointer:
30+
yield checkpointer
31+
await aclient.close()
32+
else:
33+
# Use sync client and checkpointer with async methods run in executor
34+
client: MongoClient = MongoClient(MONGODB_URI)
35+
db = client[DB_NAME]
36+
for clxn in db.list_collection_names():
37+
db.drop_collection(clxn)
38+
with MongoDBSaver.from_conn_string(
39+
MONGODB_URI, DB_NAME, COLLECTION_NAME
40+
) as checkpointer:
41+
yield checkpointer
42+
client.close()
43+
44+
45+
@pytest.mark.asyncio
46+
async def test_asearch(
47+
input_data: dict[str, Any], async_saver: Union[AsyncMongoDBSaver, MongoDBSaver]
48+
) -> None:
49+
# save checkpoints
50+
await async_saver.aput(
51+
input_data["config_1"],
52+
input_data["chkpnt_1"],
53+
input_data["metadata_1"],
54+
{},
55+
)
56+
await async_saver.aput(
57+
input_data["config_2"],
58+
input_data["chkpnt_2"],
59+
input_data["metadata_2"],
60+
{},
61+
)
62+
await async_saver.aput(
63+
input_data["config_3"],
64+
input_data["chkpnt_3"],
65+
input_data["metadata_3"],
66+
{},
67+
)
68+
69+
# call method / assertions
70+
query_1 = {"source": "input"} # search by 1 key
71+
query_2 = {
72+
"step": 1,
73+
"writes": {"foo": "bar"},
74+
} # search by multiple keys
75+
query_3: dict[str, Any] = {} # search by no keys, return all checkpoints
76+
query_4 = {"source": "update", "step": 1} # no match
77+
78+
search_results_1 = [c async for c in async_saver.alist(None, filter=query_1)]
79+
assert len(search_results_1) == 1
80+
assert search_results_1[0].metadata == input_data["metadata_1"]
81+
82+
search_results_2 = [c async for c in async_saver.alist(None, filter=query_2)]
83+
assert len(search_results_2) == 1
84+
assert search_results_2[0].metadata == input_data["metadata_2"]
85+
86+
search_results_3 = [c async for c in async_saver.alist(None, filter=query_3)]
87+
assert len(search_results_3) == 3
88+
89+
search_results_4 = [c async for c in async_saver.alist(None, filter=query_4)]
90+
assert len(search_results_4) == 0
91+
92+
# search by config (defaults to checkpoints across all namespaces)
93+
search_results_5 = [
94+
c async for c in async_saver.alist({"configurable": {"thread_id": "thread-2"}})
95+
]
96+
assert len(search_results_5) == 2
97+
assert {
98+
search_results_5[0].config["configurable"]["checkpoint_ns"],
99+
search_results_5[1].config["configurable"]["checkpoint_ns"],
100+
} == {"", "inner"}
101+
102+
103+
@pytest.mark.asyncio
104+
async def test_null_chars(
105+
input_data: dict[str, Any], async_saver: Union[AsyncMongoDBSaver, MongoDBSaver]
106+
) -> None:
81107
"""In MongoDB string *values* can be any valid UTF-8 including nulls.
82108
*Field names*, however, cannot contain nulls characters."""
83-
async with AsyncMongoDBSaver.from_conn_string(
84-
MONGODB_URI, DB_NAME, COLLECTION_NAME
85-
) as saver:
86-
null_str = "\x00abc" # string containing null character
87109

88-
# 1. null string in field *value*
89-
null_value_cfg = await saver.aput(
110+
null_str = "\x00abc" # string containing null character
111+
112+
# 1. null string in field *value*
113+
null_value_cfg = await async_saver.aput(
114+
input_data["config_1"],
115+
input_data["chkpnt_1"],
116+
{"my_key": null_str},
117+
{},
118+
)
119+
null_tuple = await async_saver.aget_tuple(null_value_cfg)
120+
assert null_tuple.metadata["my_key"] == null_str # type: ignore
121+
cps = [c async for c in async_saver.alist(None, filter={"my_key": null_str})]
122+
assert cps[0].metadata["my_key"] == null_str
123+
124+
# 2. null string in field *name*
125+
with pytest.raises(InvalidDocument):
126+
await async_saver.aput(
90127
input_data["config_1"],
91128
input_data["chkpnt_1"],
92-
{"my_key": null_str},
129+
{null_str: "my_value"}, # type: ignore
93130
{},
94131
)
95-
null_tuple = await saver.aget_tuple(null_value_cfg)
96-
assert null_tuple.metadata["my_key"] == null_str # type: ignore
97-
cps = [c async for c in saver.alist(None, filter={"my_key": null_str})]
98-
assert cps[0].metadata["my_key"] == null_str
99-
100-
# 2. null string in field *name*
101-
with pytest.raises(InvalidDocument):
102-
await saver.aput(
103-
input_data["config_1"],
104-
input_data["chkpnt_1"],
105-
{null_str: "my_value"}, # type: ignore
106-
{},
107-
)

0 commit comments

Comments
 (0)