@@ -54,6 +54,43 @@ class JokeOutput(TypedDict):
5454class JokeState (JokeInput , JokeOutput ): ...
5555
5656
57+ def fanout_to_subgraph () -> StateGraph :
58+ # Subgraph nodes create a joke.
59+ def edit (state : JokeInput ):
60+ return {"subject" : f"{ state ["subject" ]} , and cats" }
61+
62+ def generate (state : JokeInput ):
63+ return {"jokes" : [f"Joke about the year { state ['subject' ]} " ]}
64+
65+ def bump (state : JokeOutput ):
66+ return {"jokes" : [state ["jokes" ][0 ] + " and another" ]}
67+
68+ def bump_loop (state : JokeOutput ):
69+ return END if state ["jokes" ][0 ].endswith (" and another" * 10 ) else "bump"
70+
71+ subgraph = StateGraph (JokeState , joke_subjects = JokeInput , output = JokeOutput )
72+ subgraph .add_node ("edit" , edit )
73+ subgraph .add_node ("generate" , generate )
74+ subgraph .add_node ("bump" , bump )
75+ subgraph .set_entry_point ("edit" )
76+ subgraph .add_edge ("edit" , "generate" )
77+ subgraph .add_edge ("generate" , "bump" )
78+ subgraph .add_node ("bump_loop" , bump_loop )
79+ subgraph .add_conditional_edges ("bump" , bump_loop )
80+ subgraph .set_finish_point ("generate" )
81+ subgraphc = subgraph .compile ()
82+
83+ # Parent graph maps the joke-generating subgraph.
84+ def fanout (state : OverallState ):
85+ return [Send ("generate_joke" , {"subject" : s }) for s in state ["subjects" ]]
86+
87+ parentgraph = StateGraph (OverallState )
88+ parentgraph .add_node ("generate_joke" , subgraphc )
89+ parentgraph .add_conditional_edges (START , fanout )
90+ parentgraph .add_edge ("generate_joke" , END )
91+ return parentgraph
92+
93+
5794@pytest .fixture
5895def joke_subjects ():
5996 years = [str (2025 - 10 * i ) for i in range (N_SUBJECTS )]
@@ -112,42 +149,6 @@ def test_sync(
112149 "in_memory" : checkpointer_memory ,
113150 }
114151
115- def fanout_to_subgraph () -> StateGraph :
116- # Subgraph nodes create a joke
117- def edit (state : JokeInput ):
118- return {"subject" : f"{ state ["subject" ]} , and cats" }
119-
120- def generate (state : JokeInput ):
121- return {"jokes" : [f"Joke about the year { state ['subject' ]} " ]}
122-
123- def bump (state : JokeOutput ):
124- return {"jokes" : [state ["jokes" ][0 ] + " and another" ]}
125-
126- def bump_loop (state : JokeOutput ):
127- return END if state ["jokes" ][0 ].endswith (" and another" * 10 ) else "bump"
128-
129- subgraph = StateGraph (JokeState , joke_subjects = JokeInput , output = JokeOutput )
130- subgraph .add_node ("edit" , edit )
131- subgraph .add_node ("generate" , generate )
132- subgraph .add_node ("bump" , bump )
133- subgraph .set_entry_point ("edit" )
134- subgraph .add_edge ("edit" , "generate" )
135- subgraph .add_edge ("generate" , "bump" )
136- subgraph .add_node ("bump_loop" , bump_loop )
137- subgraph .add_conditional_edges ("bump" , bump_loop )
138- subgraph .set_finish_point ("generate" )
139- subgraphc = subgraph .compile ()
140-
141- # parent graph maps the joke-generating subgraph
142- def fanout (state : OverallState ):
143- return [Send ("generate_joke" , {"subject" : s }) for s in state ["subjects" ]]
144-
145- parentgraph = StateGraph (OverallState )
146- parentgraph .add_node ("generate_joke" , subgraphc )
147- parentgraph .add_conditional_edges (START , fanout )
148- parentgraph .add_edge ("generate_joke" , END )
149- return parentgraph
150-
151152 print ("\n \n Begin test_sync" )
152153 for cname , checkpointer in checkpointers .items ():
153154 assert isinstance (checkpointer , BaseCheckpointSaver )
@@ -173,47 +174,11 @@ async def test_async(
173174 "in_memory_async" : checkpointer_memory ,
174175 }
175176
176- async def fanout_to_subgraph () -> StateGraph :
177- # Subgraph nodes create a joke
178- async def edit (state : JokeInput ):
179- subject = state ["subject" ]
180- return {"subject" : f"{ subject } , and cats" }
181-
182- async def generate (state : JokeInput ):
183- return {"jokes" : [f"Joke about the year { state ['subject' ]} " ]}
184-
185- async def bump (state : JokeOutput ):
186- return {"jokes" : [state ["jokes" ][0 ] + " and another" ]}
187-
188- async def bump_loop (state : JokeOutput ):
189- return END if state ["jokes" ][0 ].endswith (" and another" * 10 ) else "bump"
190-
191- subgraph = StateGraph (JokeState , joke_subjects = JokeInput , output = JokeOutput )
192- subgraph .add_node ("edit" , edit )
193- subgraph .add_node ("generate" , generate )
194- subgraph .add_node ("bump" , bump )
195- subgraph .set_entry_point ("edit" )
196- subgraph .add_edge ("edit" , "generate" )
197- subgraph .add_edge ("generate" , "bump" )
198- subgraph .add_conditional_edges ("bump" , bump_loop )
199- subgraph .set_finish_point ("generate" )
200- subgraphc = subgraph .compile ()
201-
202- # parent graph maps the joke-generating subgraph
203- async def fanout (state : OverallState ):
204- return [Send ("generate_joke" , {"subject" : s }) for s in state ["subjects" ]]
205-
206- parentgraph = StateGraph (OverallState )
207- parentgraph .add_node ("generate_joke" , subgraphc )
208- parentgraph .add_conditional_edges (START , fanout )
209- parentgraph .add_edge ("generate_joke" , END )
210- return parentgraph
211-
212177 print ("\n \n Begin test_async" )
213178 for cname , checkpointer in checkpointers .items ():
214179 assert isinstance (checkpointer , BaseCheckpointSaver )
215180
216- graphc = (await fanout_to_subgraph ()).compile (checkpointer = checkpointer )
181+ graphc = (fanout_to_subgraph ()).compile (checkpointer = checkpointer )
217182 config = {"configurable" : {"thread_id" : cname }}
218183 start = time .monotonic ()
219184 out = [c async for c in graphc .astream (joke_subjects , config = config )]
0 commit comments