33
44import pytest
55import pytest_asyncio
6- from pymongo import AsyncMongoClient , MongoClient
6+ from langchain_core .runnables import RunnableConfig
7+ from pymongo import MongoClient
78
8- from langgraph .checkpoint .mongodb import AsyncMongoDBSaver , MongoDBSaver
9+ from langgraph .checkpoint .mongodb import MongoDBSaver
910from langgraph .types import Interrupt
1011
1112MONGODB_URI = os .environ .get (
1718TTL : int = 60 * 60
1819
1920
20- @pytest_asyncio .fixture (params = [ "run_in_executor" , "aio" ] )
21+ @pytest_asyncio .fixture ()
2122async def async_saver (request : pytest .FixtureRequest ) -> AsyncGenerator :
22- if request .param == "aio" :
23- # Use async client and checkpointer
24- aclient : AsyncMongoClient = AsyncMongoClient (MONGODB_URI )
25- adb = aclient [DB_NAME ]
26- for clxn in await adb .list_collection_names ():
27- await adb .drop_collection (clxn )
28- async with AsyncMongoDBSaver .from_conn_string (
29- MONGODB_URI , DB_NAME , COLLECTION_NAME , WRITES_COLLECTION_NAME , TTL
30- ) as checkpointer :
31- yield checkpointer
32- await aclient .close ()
33- else :
34- # Use sync client and checkpointer with async methods run in executor
35- client : MongoClient = MongoClient (MONGODB_URI )
36- db = client [DB_NAME ]
37- for clxn in db .list_collection_names ():
38- db .drop_collection (clxn )
39- with MongoDBSaver .from_conn_string (
40- MONGODB_URI , DB_NAME , COLLECTION_NAME , WRITES_COLLECTION_NAME , TTL
41- ) as checkpointer :
42- yield checkpointer
43- client .close ()
44-
45-
46- def test_put_writes_on_interrupt (async_saver : MongoDBSaver ):
23+ # Use sync client and checkpointer with async methods run in executor
24+ client : MongoClient = MongoClient (MONGODB_URI )
25+ db = client [DB_NAME ]
26+ for clxn in db .list_collection_names ():
27+ db .drop_collection (clxn )
28+ with MongoDBSaver .from_conn_string (
29+ MONGODB_URI , DB_NAME , COLLECTION_NAME , WRITES_COLLECTION_NAME , TTL
30+ ) as checkpointer :
31+ yield checkpointer
32+ client .close ()
33+
34+
35+ async def test_put_writes_on_interrupt (async_saver : MongoDBSaver ) -> None :
4736 """Test that no error is raised when interrupted workflow updates writes."""
48- config = {
37+ config : RunnableConfig = {
4938 "configurable" : {
5039 "checkpoint_id" : "check1" ,
5140 "thread_id" : "thread1" ,
@@ -61,13 +50,11 @@ def test_put_writes_on_interrupt(async_saver: MongoDBSaver):
6150 (
6251 Interrupt (
6352 value = "please provide input" ,
64- resumable = True ,
65- ns = ["human_feedback:1b798da3" ],
6653 ),
6754 ),
6855 )
6956 ]
70- async_saver .aput_writes (config , writes1 , task_id , task_path )
57+ await async_saver .aput_writes (config , writes1 , task_id , task_path )
7158
7259 writes2 = [("__interrupt__" , (Interrupt (value = "please provide another input" ),))]
73- async_saver .aput_writes (config , writes2 , task_id , task_path )
60+ await async_saver .aput_writes (config , writes2 , task_id , task_path )
0 commit comments