99
1010import inspect
1111from collections import Counter , defaultdict
12- from collections .abc import Callable , Iterable
12+ from collections .abc import AsyncIterable , Callable , Iterable
1313from dataclasses import dataclass , replace
1414from types import NoneType
1515from typing import Any , Generic , Literal , cast , get_origin , get_type_hints , overload
4444 Path ,
4545 PathBuilder ,
4646)
47- from pydantic_graph .beta .step import NodeStep , Step , StepFunction , StepNode
47+ from pydantic_graph .beta .step import NodeStep , Step , StepAsyncIteratorFunction , StepContext , StepFunction , StepNode
4848from pydantic_graph .beta .util import TypeOrTypeExpression , get_callable_name , unpack_type_expression
4949from pydantic_graph .exceptions import GraphBuildingError
5050from pydantic_graph .nodes import BaseNode , End
@@ -162,21 +162,21 @@ def end_node(self) -> EndNode[GraphOutputT]:
162162 return self ._end_node
163163
164164 @overload
165- def _step (
165+ def step (
166166 self ,
167167 * ,
168168 node_id : str | None = None ,
169169 label : str | None = None ,
170170 ) -> Callable [[StepFunction [StateT , DepsT , InputT , OutputT ]], Step [StateT , DepsT , InputT , OutputT ]]: ...
171171 @overload
172- def _step (
172+ def step (
173173 self ,
174174 call : StepFunction [StateT , DepsT , InputT , OutputT ],
175175 * ,
176176 node_id : str | None = None ,
177177 label : str | None = None ,
178178 ) -> Step [StateT , DepsT , InputT , OutputT ]: ...
179- def _step (
179+ def step (
180180 self ,
181181 call : StepFunction [StateT , DepsT , InputT , OutputT ] | None = None ,
182182 * ,
@@ -186,10 +186,10 @@ def _step(
186186 Step [StateT , DepsT , InputT , OutputT ]
187187 | Callable [[StepFunction [StateT , DepsT , InputT , OutputT ]], Step [StateT , DepsT , InputT , OutputT ]]
188188 ):
189- """Create a step from a step function (internal implementation) .
189+ """Create a step from a step function.
190190
191- This internal method handles the actual step creation logic and
192- automatic edge inference from type hints .
191+ This method can be used as a decorator or called directly to create
192+ a step node from an async function .
193193
194194 Args:
195195 call: The step function to wrap
@@ -204,7 +204,7 @@ def _step(
204204 def decorator (
205205 func : StepFunction [StateT , DepsT , InputT , OutputT ],
206206 ) -> Step [StateT , DepsT , InputT , OutputT ]:
207- return self ._step (call = func , node_id = node_id , label = label )
207+ return self .step (call = func , node_id = node_id , label = label )
208208
209209 return decorator
210210
@@ -215,29 +215,48 @@ def decorator(
215215 return step
216216
217217 @overload
218- def step (
218+ def step_async_iterable (
219219 self ,
220220 * ,
221221 node_id : str | None = None ,
222222 label : str | None = None ,
223- ) -> Callable [[StepFunction [StateT , DepsT , InputT , OutputT ]], Step [StateT , DepsT , InputT , OutputT ]]: ...
223+ ) -> Callable [
224+ [StepAsyncIteratorFunction [StateT , DepsT , InputT , OutputT ]], Step [StateT , DepsT , InputT , AsyncIterable [OutputT ]]
225+ ]: ...
224226 @overload
225- def step (
227+ def step_async_iterable (
226228 self ,
227- call : StepFunction [StateT , DepsT , InputT , OutputT ],
229+ call : StepAsyncIteratorFunction [StateT , DepsT , InputT , OutputT ],
228230 * ,
229231 node_id : str | None = None ,
230232 label : str | None = None ,
231- ) -> Step [StateT , DepsT , InputT , OutputT ]: ...
232- def step (
233+ ) -> Step [StateT , DepsT , InputT , AsyncIterable [OutputT ]]: ...
234+ @overload
235+ def step_async_iterable (
233236 self ,
234- call : StepFunction [StateT , DepsT , InputT , OutputT ] | None = None ,
237+ call : StepAsyncIteratorFunction [StateT , DepsT , InputT , OutputT ] | None = None ,
235238 * ,
236239 node_id : str | None = None ,
237240 label : str | None = None ,
238241 ) -> (
239- Step [StateT , DepsT , InputT , OutputT ]
240- | Callable [[StepFunction [StateT , DepsT , InputT , OutputT ]], Step [StateT , DepsT , InputT , OutputT ]]
242+ Step [StateT , DepsT , InputT , AsyncIterable [OutputT ]]
243+ | Callable [
244+ [StepAsyncIteratorFunction [StateT , DepsT , InputT , OutputT ]],
245+ Step [StateT , DepsT , InputT , AsyncIterable [OutputT ]],
246+ ]
247+ ): ...
248+ def step_async_iterable (
249+ self ,
250+ call : StepAsyncIteratorFunction [StateT , DepsT , InputT , OutputT ] | None = None ,
251+ * ,
252+ node_id : str | None = None ,
253+ label : str | None = None ,
254+ ) -> (
255+ Step [StateT , DepsT , InputT , AsyncIterable [OutputT ]]
256+ | Callable [
257+ [StepAsyncIteratorFunction [StateT , DepsT , InputT , OutputT ]],
258+ Step [StateT , DepsT , InputT , AsyncIterable [OutputT ]],
259+ ]
241260 ):
242261 """Create a step from a step function.
243262
@@ -253,9 +272,19 @@ def step(
253272 Either a Step instance or a decorator function
254273 """
255274 if call is None :
256- return self ._step (node_id = node_id , label = label )
257- else :
258- return self ._step (call = call , node_id = node_id , label = label )
275+
276+ def decorator (
277+ func : StepAsyncIteratorFunction [StateT , DepsT , InputT , OutputT ],
278+ ) -> Step [StateT , DepsT , InputT , AsyncIterable [OutputT ]]:
279+ return self .step_async_iterable (call = func , node_id = node_id , label = label )
280+
281+ return decorator
282+
283+ # We need to wrap the call so that we can call `await` even though the result is an async iterator
284+ async def wrapper (ctx : StepContext [StateT , DepsT , InputT ]):
285+ return call (ctx )
286+
287+ return self .step (call = wrapper , node_id = node_id , label = label )
259288
260289 @overload
261290 def join (
0 commit comments