1515import operator
1616import os
1717import time
18- from collections .abc import Generator
19- from typing import Annotated , TypedDict
18+ from collections .abc import AsyncGenerator , Generator
19+ from typing import Annotated
2020
2121import pytest
22+ from langchain_core .runnables import RunnableConfig
23+ from typing_extensions import TypedDict
2224
2325from langgraph .checkpoint .base import BaseCheckpointSaver
2426from langgraph .checkpoint .memory import InMemorySaver
@@ -55,21 +57,21 @@ class JokeState(JokeInput, JokeOutput): ...
5557
5658def fanout_to_subgraph () -> StateGraph :
5759 # Subgraph nodes create a joke.
58- def edit (state : JokeInput ):
60+ def edit (state : JokeInput ) -> dict [ str , str ] :
5961 return {"subject" : f"{ state ["subject" ]} , and cats" }
6062
61- def generate (state : JokeInput ):
63+ def generate (state : JokeInput ) -> dict [ str , list [ str ]] :
6264 return {"jokes" : [f"Joke about the year { state ['subject' ]} " ]}
6365
64- def bump (state : JokeOutput ):
66+ def bump (state : JokeOutput ) -> dict [ str , list [ str ]] :
6567 return {"jokes" : [state ["jokes" ][0 ] + " and the year before" ]}
6668
67- def bump_loop (state : JokeOutput ):
69+ def bump_loop (state : JokeOutput ) -> JokeOutput :
6870 return (
6971 END if state ["jokes" ][0 ].endswith (" and the year before" * 10 ) else "bump"
7072 )
7173
72- subgraph = StateGraph (JokeState , joke_subjects = JokeInput , output = JokeOutput )
74+ subgraph = StateGraph (JokeState )
7375 subgraph .add_node ("edit" , edit )
7476 subgraph .add_node ("generate" , generate )
7577 subgraph .add_node ("bump" , bump )
@@ -82,18 +84,18 @@ def bump_loop(state: JokeOutput):
8284 subgraphc = subgraph .compile ()
8385
8486 # Parent graph maps the joke-generating subgraph.
85- def fanout (state : OverallState ):
87+ def fanout (state : OverallState ) -> list :
8688 return [Send ("generate_joke" , {"subject" : s }) for s in state ["subjects" ]]
8789
8890 parentgraph = StateGraph (OverallState )
89- parentgraph .add_node ("generate_joke" , subgraphc )
91+ parentgraph .add_node ("generate_joke" , subgraphc ) # type: ignore[arg-type]
9092 parentgraph .add_conditional_edges (START , fanout )
9193 parentgraph .add_edge ("generate_joke" , END )
9294 return parentgraph
9395
9496
9597@pytest .fixture
96- def joke_subjects ():
98+ def joke_subjects () -> OverallState :
9799 years = [str (2025 - 10 * i ) for i in range (N_SUBJECTS )]
98100 return {"subjects" : years }
99101
@@ -119,29 +121,32 @@ def checkpointer_mongodb() -> Generator[MongoDBSaver, None, None]:
119121
120122
121123@pytest .fixture (scope = "function" )
122- async def checkpointer_mongodb_async () -> Generator [AsyncMongoDBSaver , None , None ]:
124+ async def checkpointer_mongodb_async () -> AsyncGenerator [AsyncMongoDBSaver , None ]:
123125 async with AsyncMongoDBSaver .from_conn_string (
124126 MONGODB_URI ,
125127 db_name = DB_NAME ,
126128 checkpoint_collection_name = CHECKPOINT_CLXN_NAME + "_async" ,
127129 writes_collection_name = WRITES_CLXN_NAME + "_async" ,
128130 ) as checkpointer :
129- checkpointer .checkpoint_collection .delete_many ({})
130- checkpointer .writes_collection .delete_many ({})
131+ await checkpointer .checkpoint_collection .delete_many ({})
132+ await checkpointer .writes_collection .delete_many ({})
131133 yield checkpointer
132- checkpointer .checkpoint_collection .drop ()
133- checkpointer .writes_collection .drop ()
134+ await checkpointer .checkpoint_collection .drop ()
135+ await checkpointer .writes_collection .drop ()
134136
135137
136138@pytest .fixture (autouse = True )
137- def disable_langsmith ():
139+ def disable_langsmith () -> None :
138140 """Disable LangSmith tracing for all tests"""
139141 os .environ ["LANGCHAIN_TRACING_V2" ] = "false"
140142 os .environ ["LANGCHAIN_API_KEY" ] = ""
141143
142144
143145async def test_fanout (
144- joke_subjects , checkpointer_mongodb , checkpointer_mongodb_async , checkpointer_memory
146+ joke_subjects : OverallState ,
147+ checkpointer_mongodb : MongoDBSaver ,
148+ checkpointer_mongodb_async : AsyncMongoDBSaver ,
149+ checkpointer_memory : InMemorySaver ,
145150) -> None :
146151 checkpointers = {
147152 "mongodb" : checkpointer_mongodb ,
@@ -154,12 +159,12 @@ async def test_fanout(
154159 assert isinstance (checkpointer , BaseCheckpointSaver )
155160 print (f"\n \n Begin test of { cname } " )
156161 graphc = (fanout_to_subgraph ()).compile (checkpointer = checkpointer )
157- config = {"configurable" : {"thread_id" : cname }}
162+ config : RunnableConfig = {"configurable" : {"thread_id" : cname }}
158163 start = time .monotonic ()
159164 if "async" in cname :
160- out = [c async for c in graphc .astream (joke_subjects , config = config )]
165+ out = [c async for c in graphc .astream (joke_subjects , config = config )] # type: ignore[arg-type]
161166 else :
162- out = [c for c in graphc .stream (joke_subjects , config = config )]
167+ out = [c for c in graphc .stream (joke_subjects , config = config )] # type: ignore[arg-type]
163168 assert len (out ) == N_SUBJECTS
164169 assert isinstance (out [0 ], dict )
165170 assert out [0 ].keys () == {"generate_joke" }
0 commit comments