Skip to content

Commit dc1508b

Browse files
committed
Try to improve coverage
1 parent 4eb1a0b commit dc1508b

File tree

6 files changed

+282
-37
lines changed

6 files changed

+282
-37
lines changed

pydantic_graph/pydantic_graph/beta/graph.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -512,7 +512,7 @@ def __post_init__(self):
512512
@property
513513
def task_group(self) -> TaskGroup:
514514
if self._task_group is None:
515-
raise RuntimeError("This graph iterator hasn't been started")
515+
raise RuntimeError("This graph iterator hasn't been started") # pragma: no cover
516516
return self._task_group
517517

518518
async def iter_graph( # noqa C901
@@ -528,7 +528,7 @@ async def iter_graph( # noqa C901
528528
# Handle task results
529529
async with self.iter_stream_receiver:
530530
while self.active_tasks or self.active_reducers:
531-
async for task_result in self.iter_stream_receiver:
531+
async for task_result in self.iter_stream_receiver: # pragma: no branch
532532
# If we encounter a mock task, add it to the active tasks to ensure we don't proceed until everything downstream is handled
533533
if (
534534
not task_result.source_is_finished
@@ -653,11 +653,10 @@ async def iter_graph( # noqa C901
653653
'Graph run completed, but no result was produced. This is either a bug in the graph or a bug in the graph runner.'
654654
)
655655

656-
async def _finish_task(self, task_id: TaskID, keep_cancel_scope: bool = False) -> None:
657-
if not keep_cancel_scope:
658-
scope = self.cancel_scopes.pop(task_id, None)
659-
if scope is not None:
660-
scope.cancel()
656+
async def _finish_task(self, task_id: TaskID) -> None:
657+
scope = self.cancel_scopes.pop(task_id, None)
658+
if scope is not None:
659+
scope.cancel()
661660
self.active_tasks.pop(task_id, None)
662661

663662
def _handle_execution_request(self, request: Sequence[GraphTask]) -> None:

pydantic_graph/pydantic_graph/beta/graph_builder.py

Lines changed: 50 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import inspect
1111
from collections import Counter, defaultdict
12-
from collections.abc import Callable, Iterable
12+
from collections.abc import AsyncIterable, Callable, Iterable
1313
from dataclasses import dataclass, replace
1414
from types import NoneType
1515
from typing import Any, Generic, Literal, cast, get_origin, get_type_hints, overload
@@ -44,7 +44,7 @@
4444
Path,
4545
PathBuilder,
4646
)
47-
from pydantic_graph.beta.step import NodeStep, Step, StepFunction, StepNode
47+
from pydantic_graph.beta.step import NodeStep, Step, StepAsyncIteratorFunction, StepContext, StepFunction, StepNode
4848
from pydantic_graph.beta.util import TypeOrTypeExpression, get_callable_name, unpack_type_expression
4949
from pydantic_graph.exceptions import GraphBuildingError
5050
from pydantic_graph.nodes import BaseNode, End
@@ -162,21 +162,21 @@ def end_node(self) -> EndNode[GraphOutputT]:
162162
return self._end_node
163163

164164
@overload
165-
def _step(
165+
def step(
166166
self,
167167
*,
168168
node_id: str | None = None,
169169
label: str | None = None,
170170
) -> Callable[[StepFunction[StateT, DepsT, InputT, OutputT]], Step[StateT, DepsT, InputT, OutputT]]: ...
171171
@overload
172-
def _step(
172+
def step(
173173
self,
174174
call: StepFunction[StateT, DepsT, InputT, OutputT],
175175
*,
176176
node_id: str | None = None,
177177
label: str | None = None,
178178
) -> Step[StateT, DepsT, InputT, OutputT]: ...
179-
def _step(
179+
def step(
180180
self,
181181
call: StepFunction[StateT, DepsT, InputT, OutputT] | None = None,
182182
*,
@@ -186,10 +186,10 @@ def _step(
186186
Step[StateT, DepsT, InputT, OutputT]
187187
| Callable[[StepFunction[StateT, DepsT, InputT, OutputT]], Step[StateT, DepsT, InputT, OutputT]]
188188
):
189-
"""Create a step from a step function (internal implementation).
189+
"""Create a step from a step function.
190190
191-
This internal method handles the actual step creation logic and
192-
automatic edge inference from type hints.
191+
This method can be used as a decorator or called directly to create
192+
a step node from an async function.
193193
194194
Args:
195195
call: The step function to wrap
@@ -204,7 +204,7 @@ def _step(
204204
def decorator(
205205
func: StepFunction[StateT, DepsT, InputT, OutputT],
206206
) -> Step[StateT, DepsT, InputT, OutputT]:
207-
return self._step(call=func, node_id=node_id, label=label)
207+
return self.step(call=func, node_id=node_id, label=label)
208208

209209
return decorator
210210

@@ -215,29 +215,48 @@ def decorator(
215215
return step
216216

217217
@overload
218-
def step(
218+
def step_async_iterable(
219219
self,
220220
*,
221221
node_id: str | None = None,
222222
label: str | None = None,
223-
) -> Callable[[StepFunction[StateT, DepsT, InputT, OutputT]], Step[StateT, DepsT, InputT, OutputT]]: ...
223+
) -> Callable[
224+
[StepAsyncIteratorFunction[StateT, DepsT, InputT, OutputT]], Step[StateT, DepsT, InputT, AsyncIterable[OutputT]]
225+
]: ...
224226
@overload
225-
def step(
227+
def step_async_iterable(
226228
self,
227-
call: StepFunction[StateT, DepsT, InputT, OutputT],
229+
call: StepAsyncIteratorFunction[StateT, DepsT, InputT, OutputT],
228230
*,
229231
node_id: str | None = None,
230232
label: str | None = None,
231-
) -> Step[StateT, DepsT, InputT, OutputT]: ...
232-
def step(
233+
) -> Step[StateT, DepsT, InputT, AsyncIterable[OutputT]]: ...
234+
@overload
235+
def step_async_iterable(
233236
self,
234-
call: StepFunction[StateT, DepsT, InputT, OutputT] | None = None,
237+
call: StepAsyncIteratorFunction[StateT, DepsT, InputT, OutputT] | None = None,
235238
*,
236239
node_id: str | None = None,
237240
label: str | None = None,
238241
) -> (
239-
Step[StateT, DepsT, InputT, OutputT]
240-
| Callable[[StepFunction[StateT, DepsT, InputT, OutputT]], Step[StateT, DepsT, InputT, OutputT]]
242+
Step[StateT, DepsT, InputT, AsyncIterable[OutputT]]
243+
| Callable[
244+
[StepAsyncIteratorFunction[StateT, DepsT, InputT, OutputT]],
245+
Step[StateT, DepsT, InputT, AsyncIterable[OutputT]],
246+
]
247+
): ...
248+
def step_async_iterable(
249+
self,
250+
call: StepAsyncIteratorFunction[StateT, DepsT, InputT, OutputT] | None = None,
251+
*,
252+
node_id: str | None = None,
253+
label: str | None = None,
254+
) -> (
255+
Step[StateT, DepsT, InputT, AsyncIterable[OutputT]]
256+
| Callable[
257+
[StepAsyncIteratorFunction[StateT, DepsT, InputT, OutputT]],
258+
Step[StateT, DepsT, InputT, AsyncIterable[OutputT]],
259+
]
241260
):
242261
"""Create a step from a step function.
243262
@@ -253,9 +272,19 @@ def step(
253272
Either a Step instance or a decorator function
254273
"""
255274
if call is None:
256-
return self._step(node_id=node_id, label=label)
257-
else:
258-
return self._step(call=call, node_id=node_id, label=label)
275+
276+
def decorator(
277+
func: StepAsyncIteratorFunction[StateT, DepsT, InputT, OutputT],
278+
) -> Step[StateT, DepsT, InputT, AsyncIterable[OutputT]]:
279+
return self.step_async_iterable(call=func, node_id=node_id, label=label)
280+
281+
return decorator
282+
283+
# We need to wrap the call so that we can call `await` even though the result is an async iterator
284+
async def wrapper(ctx: StepContext[StateT, DepsT, InputT]):
285+
return call(ctx)
286+
287+
return self.step(call=wrapper, node_id=node_id, label=label)
259288

260289
@overload
261290
def join(

pydantic_graph/pydantic_graph/beta/join.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,8 +143,6 @@ class ReduceFirstValue(Generic[T]):
143143

144144
def __call__(self, ctx: ReducerContext[object, object], current: T, inputs: T) -> T:
145145
"""The reducer function."""
146-
if ctx.cancelled_sibling_tasks:
147-
return current
148146
ctx.cancel_sibling_tasks()
149147
return inputs
150148

pydantic_graph/pydantic_graph/beta/step.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from __future__ import annotations
99

10-
from collections.abc import Awaitable
10+
from collections.abc import AsyncIterator, Awaitable
1111
from dataclasses import dataclass
1212
from typing import Any, Generic, Protocol, cast, get_origin, overload
1313

@@ -90,6 +90,33 @@ def __call__(self, ctx: StepContext[StateT, DepsT, InputT]) -> Awaitable[OutputT
9090
raise NotImplementedError
9191

9292

93+
class StepAsyncIteratorFunction(Protocol[StateT, DepsT, InputT, OutputT]):
94+
"""Protocol for step functions that can be executed in the graph.
95+
96+
Step functions are async callables that receive a step context and return
97+
a result. This protocol enables serialization and deserialization of step
98+
calls similar to how evaluators work.
99+
100+
Type Parameters:
101+
StateT: The type of the graph state
102+
DepsT: The type of the dependencies
103+
InputT: The type of the input data
104+
OutputT: The type of the output data
105+
"""
106+
107+
def __call__(self, ctx: StepContext[StateT, DepsT, InputT]) -> AsyncIterator[OutputT]:
108+
"""Execute the step function with the given context.
109+
110+
Args:
111+
ctx: The step context containing state, dependencies, and inputs
112+
113+
Returns:
114+
An awaitable that resolves to the step's output
115+
"""
116+
raise NotImplementedError
117+
yield
118+
119+
93120
AnyStepFunction = StepFunction[Any, Any, Any, Any]
94121
"""Type alias for a step function with any type parameters."""
95122

tests/graph/beta/test_graph_edge_cases.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -205,12 +205,11 @@ def get_early_stopping_reducer():
205205

206206
def reduce(ctx: ReducerContext[EarlyStopState, object], current: int, inputs: int) -> int:
207207
nonlocal count
208-
if not ctx.cancelled_sibling_tasks:
209-
count += 1
210-
current += inputs
211-
if count >= 2:
212-
ctx.state.stopped = True # update the state so we can assert on it later
213-
ctx.cancel_sibling_tasks()
208+
count += 1
209+
current += inputs
210+
if count >= 2:
211+
ctx.state.stopped = True # update the state so we can assert on it later
212+
ctx.cancel_sibling_tasks()
214213
return current
215214

216215
return reduce

0 commit comments

Comments
 (0)