Skip to content

Commit f20f981

Browse files
committed
Typing
1 parent d25325d commit f20f981

File tree

3 files changed

+23
-36
lines changed

3 files changed

+23
-36
lines changed

libs/langgraph-checkpoint-mongodb/langgraph/checkpoint/mongodb/aio.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -359,7 +359,7 @@ async def aput(
359359
type_, serialized_checkpoint = self.serde.dumps_typed(checkpoint)
360360
metadata = metadata.copy()
361361
metadata.update(config.get("metadata", {}))
362-
doc = {
362+
doc: dict[str, Any] = {
363363
"parent_checkpoint_id": config["configurable"].get("checkpoint_id"),
364364
"type": type_,
365365
"checkpoint": serialized_checkpoint,
@@ -423,7 +423,7 @@ async def aput_writes(
423423

424424
type_, serialized_value = self.serde.dumps_typed(value)
425425

426-
update_doc = {
426+
update_doc: dict[str, Any] = {
427427
"channel": channel,
428428
"type": type_,
429429
"value": serialized_value,

libs/langgraph-checkpoint-mongodb/langgraph/checkpoint/mongodb/saver.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -438,7 +438,7 @@ def put_writes(
438438

439439
type_, serialized_value = self.serde.dumps_typed(value)
440440

441-
update_doc = {
441+
update_doc: dict[str, Any] = {
442442
"channel": channel,
443443
"type": type_,
444444
"value": serialized_value,

libs/langgraph-checkpoint-mongodb/tests/unit_tests/test_interrupt.py

Lines changed: 20 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33

44
import pytest
55
import pytest_asyncio
6-
from pymongo import AsyncMongoClient, MongoClient
6+
from langchain_core.runnables import RunnableConfig
7+
from pymongo import MongoClient
78

8-
from langgraph.checkpoint.mongodb import AsyncMongoDBSaver, MongoDBSaver
9+
from langgraph.checkpoint.mongodb import MongoDBSaver
910
from langgraph.types import Interrupt
1011

1112
MONGODB_URI = os.environ.get(
@@ -17,35 +18,23 @@
1718
TTL: int = 60 * 60
1819

1920

20-
@pytest_asyncio.fixture(params=["run_in_executor", "aio"])
21+
@pytest_asyncio.fixture()
2122
async def async_saver(request: pytest.FixtureRequest) -> AsyncGenerator:
22-
if request.param == "aio":
23-
# Use async client and checkpointer
24-
aclient: AsyncMongoClient = AsyncMongoClient(MONGODB_URI)
25-
adb = aclient[DB_NAME]
26-
for clxn in await adb.list_collection_names():
27-
await adb.drop_collection(clxn)
28-
async with AsyncMongoDBSaver.from_conn_string(
29-
MONGODB_URI, DB_NAME, COLLECTION_NAME, WRITES_COLLECTION_NAME, TTL
30-
) as checkpointer:
31-
yield checkpointer
32-
await aclient.close()
33-
else:
34-
# Use sync client and checkpointer with async methods run in executor
35-
client: MongoClient = MongoClient(MONGODB_URI)
36-
db = client[DB_NAME]
37-
for clxn in db.list_collection_names():
38-
db.drop_collection(clxn)
39-
with MongoDBSaver.from_conn_string(
40-
MONGODB_URI, DB_NAME, COLLECTION_NAME, WRITES_COLLECTION_NAME, TTL
41-
) as checkpointer:
42-
yield checkpointer
43-
client.close()
44-
45-
46-
def test_put_writes_on_interrupt(async_saver: MongoDBSaver):
23+
# Use sync client and checkpointer with async methods run in executor
24+
client: MongoClient = MongoClient(MONGODB_URI)
25+
db = client[DB_NAME]
26+
for clxn in db.list_collection_names():
27+
db.drop_collection(clxn)
28+
with MongoDBSaver.from_conn_string(
29+
MONGODB_URI, DB_NAME, COLLECTION_NAME, WRITES_COLLECTION_NAME, TTL
30+
) as checkpointer:
31+
yield checkpointer
32+
client.close()
33+
34+
35+
async def test_put_writes_on_interrupt(async_saver: MongoDBSaver) -> None:
4736
"""Test that no error is raised when interrupted workflow updates writes."""
48-
config = {
37+
config: RunnableConfig = {
4938
"configurable": {
5039
"checkpoint_id": "check1",
5140
"thread_id": "thread1",
@@ -61,13 +50,11 @@ def test_put_writes_on_interrupt(async_saver: MongoDBSaver):
6150
(
6251
Interrupt(
6352
value="please provide input",
64-
resumable=True,
65-
ns=["human_feedback:1b798da3"],
6653
),
6754
),
6855
)
6956
]
70-
async_saver.aput_writes(config, writes1, task_id, task_path)
57+
await async_saver.aput_writes(config, writes1, task_id, task_path)
7158

7259
writes2 = [("__interrupt__", (Interrupt(value="please provide another input"),))]
73-
async_saver.aput_writes(config, writes2, task_id, task_path)
60+
await async_saver.aput_writes(config, writes2, task_id, task_path)

0 commit comments

Comments
 (0)