Skip to content

Commit 592686d

Browse files
committed
Try to improve coverage
1 parent dc1508b commit 592686d

File tree

5 files changed

+57
-95
lines changed

5 files changed

+57
-95
lines changed

pydantic_graph/pydantic_graph/beta/graph.py

Lines changed: 8 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -84,10 +84,6 @@ def __init__(self, value: OutputT):
8484
def value(self) -> OutputT:
8585
return self._value
8686

87-
@property
88-
def end_marker_id(self) -> str:
89-
return f'EndMarker:{id(self)}'
90-
9187

9288
@dataclass
9389
class JoinItem:
@@ -493,8 +489,6 @@ class _GraphIterator(Generic[StateT, DepsT, OutputT]):
493489

494490
cancel_scopes: dict[TaskID, CancelScope] = field(init=False)
495491
active_tasks: dict[TaskID, GraphTask] = field(init=False)
496-
pending_task_results: set[TaskID] = field(init=False)
497-
cancelled_tasks: set[TaskID] = field(init=False)
498492
active_reducers: dict[tuple[JoinID, NodeRunID], JoinState] = field(init=False)
499493
iter_stream_sender: MemoryObjectSendStream[_GraphTaskResult] = field(init=False)
500494
iter_stream_receiver: MemoryObjectReceiveStream[_GraphTaskResult] = field(init=False)
@@ -506,9 +500,6 @@ def __post_init__(self):
506500
self.active_reducers = {}
507501
self.iter_stream_sender, self.iter_stream_receiver = create_memory_object_stream[_GraphTaskResult]()
508502

509-
self.pending_task_results = set()
510-
self.cancelled_tasks = set()
511-
512503
@property
513504
def task_group(self) -> TaskGroup:
514505
if self._task_group is None:
@@ -529,21 +520,6 @@ async def iter_graph( # noqa C901
529520
async with self.iter_stream_receiver:
530521
while self.active_tasks or self.active_reducers:
531522
async for task_result in self.iter_stream_receiver: # pragma: no branch
532-
# If we encounter a mock task, add it to the active tasks to ensure we don't proceed until everything downstream is handled
533-
if (
534-
not task_result.source_is_finished
535-
and task_result.source.task_id not in self.active_tasks
536-
):
537-
self.active_tasks[task_result.source.task_id] = task_result.source
538-
539-
if task_result.source.task_id in self.cancelled_tasks:
540-
if task_result.source_is_finished:
541-
self.cancelled_tasks.remove(task_result.source.task_id)
542-
continue
543-
544-
if task_result.source_is_finished:
545-
self.pending_task_results.discard(task_result.source.task_id)
546-
547523
if isinstance(task_result.result, JoinItem):
548524
maybe_overridden_result = task_result.result
549525
else:
@@ -574,7 +550,7 @@ async def iter_graph( # noqa C901
574550
join_state.current = join_node.reduce(context, join_state.current, result.inputs)
575551
if join_state.cancelled_sibling_tasks:
576552
await self._cancel_sibling_tasks(parent_fork_id, fork_run_id)
577-
if task_result.source_is_finished:
553+
if task_result.source_is_finished: # pragma: no branch
578554
await self._finish_task(task_result.source.task_id)
579555
else:
580556
for new_task in maybe_overridden_result:
@@ -636,13 +612,18 @@ async def iter_graph( # noqa C901
636612
join_node, join_state.current, join_state.downstream_fork_stack
637613
)
638614
maybe_overridden_result = yield new_tasks
639-
if isinstance(maybe_overridden_result, EndMarker):
615+
if isinstance(maybe_overridden_result, EndMarker): # pragma: no cover
616+
# This is theoretically reachable but it would be awkward.
617+
# Probably a better way to get coverage here would be to unify the code pat
618+
# with the other `if isinstance(maybe_overridden_result, EndMarker):`
640619
self.task_group.cancel_scope.cancel()
641620
return
642621
for new_task in maybe_overridden_result:
643622
self.active_tasks[new_task.task_id] = new_task
644623
new_task_ids = {t.task_id for t in maybe_overridden_result}
645-
for t in new_tasks:
624+
for t in new_tasks: # pragma: no cover
625+
# Same note as above about how this is theoretically reachable but we should
626+
# just get coverage by unifying the code paths
646627
if t.task_id not in new_task_ids:
647628
await self._finish_task(t.task_id)
648629
self._handle_execution_request(maybe_overridden_result)
@@ -674,7 +655,6 @@ async def _run_tracked_task(self, t_: GraphTask):
674655
await self.iter_stream_sender.send(_GraphTaskResult(t_, new_tasks, False))
675656
await self.iter_stream_sender.send(_GraphTaskResult(t_, []))
676657
else:
677-
self.pending_task_results.add(t_.task_id)
678658
await self.iter_stream_sender.send(_GraphTaskResult(t_, result))
679659

680660
async def _run_task(
@@ -872,8 +852,6 @@ async def _cancel_sibling_tasks(self, parent_fork_id: ForkID, node_run_id: NodeR
872852
else:
873853
pass
874854
for task_id in task_ids_to_cancel:
875-
if task_id in self.pending_task_results:
876-
self.cancelled_tasks.add(task_id)
877855
await self._finish_task(task_id)
878856

879857

pydantic_graph/pydantic_graph/beta/graph_builder.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
Path,
4545
PathBuilder,
4646
)
47-
from pydantic_graph.beta.step import NodeStep, Step, StepAsyncIteratorFunction, StepContext, StepFunction, StepNode
47+
from pydantic_graph.beta.step import NodeStep, Step, StepContext, StepFunction, StepNode, StreamFunction
4848
from pydantic_graph.beta.util import TypeOrTypeExpression, get_callable_name, unpack_type_expression
4949
from pydantic_graph.exceptions import GraphBuildingError
5050
from pydantic_graph.nodes import BaseNode, End
@@ -215,50 +215,50 @@ def decorator(
215215
return step
216216

217217
@overload
218-
def step_async_iterable(
218+
def stream(
219219
self,
220220
*,
221221
node_id: str | None = None,
222222
label: str | None = None,
223223
) -> Callable[
224-
[StepAsyncIteratorFunction[StateT, DepsT, InputT, OutputT]], Step[StateT, DepsT, InputT, AsyncIterable[OutputT]]
224+
[StreamFunction[StateT, DepsT, InputT, OutputT]], Step[StateT, DepsT, InputT, AsyncIterable[OutputT]]
225225
]: ...
226226
@overload
227-
def step_async_iterable(
227+
def stream(
228228
self,
229-
call: StepAsyncIteratorFunction[StateT, DepsT, InputT, OutputT],
229+
call: StreamFunction[StateT, DepsT, InputT, OutputT],
230230
*,
231231
node_id: str | None = None,
232232
label: str | None = None,
233233
) -> Step[StateT, DepsT, InputT, AsyncIterable[OutputT]]: ...
234234
@overload
235-
def step_async_iterable(
235+
def stream(
236236
self,
237-
call: StepAsyncIteratorFunction[StateT, DepsT, InputT, OutputT] | None = None,
237+
call: StreamFunction[StateT, DepsT, InputT, OutputT] | None = None,
238238
*,
239239
node_id: str | None = None,
240240
label: str | None = None,
241241
) -> (
242242
Step[StateT, DepsT, InputT, AsyncIterable[OutputT]]
243243
| Callable[
244-
[StepAsyncIteratorFunction[StateT, DepsT, InputT, OutputT]],
244+
[StreamFunction[StateT, DepsT, InputT, OutputT]],
245245
Step[StateT, DepsT, InputT, AsyncIterable[OutputT]],
246246
]
247247
): ...
248-
def step_async_iterable(
248+
def stream(
249249
self,
250-
call: StepAsyncIteratorFunction[StateT, DepsT, InputT, OutputT] | None = None,
250+
call: StreamFunction[StateT, DepsT, InputT, OutputT] | None = None,
251251
*,
252252
node_id: str | None = None,
253253
label: str | None = None,
254254
) -> (
255255
Step[StateT, DepsT, InputT, AsyncIterable[OutputT]]
256256
| Callable[
257-
[StepAsyncIteratorFunction[StateT, DepsT, InputT, OutputT]],
257+
[StreamFunction[StateT, DepsT, InputT, OutputT]],
258258
Step[StateT, DepsT, InputT, AsyncIterable[OutputT]],
259259
]
260260
):
261-
"""Create a step from a step function.
261+
"""Create a step from an async iterator (which functions like a "stream").
262262
263263
This method can be used as a decorator or called directly to create
264264
a step node from an async function.
@@ -274,9 +274,9 @@ def step_async_iterable(
274274
if call is None:
275275

276276
def decorator(
277-
func: StepAsyncIteratorFunction[StateT, DepsT, InputT, OutputT],
277+
func: StreamFunction[StateT, DepsT, InputT, OutputT],
278278
) -> Step[StateT, DepsT, InputT, AsyncIterable[OutputT]]:
279-
return self.step_async_iterable(call=func, node_id=node_id, label=label)
279+
return self.stream(call=func, node_id=node_id, label=label)
280280

281281
return decorator
282282

pydantic_graph/pydantic_graph/beta/join.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,6 @@ def state(self) -> StateT:
7070
def deps(self) -> DepsT:
7171
return self._deps
7272

73-
@property
74-
def cancelled_sibling_tasks(self):
75-
return self._join_state.cancelled_sibling_tasks
76-
7773
def cancel_sibling_tasks(self):
7874
self._join_state.cancelled_sibling_tasks = True
7975

pydantic_graph/pydantic_graph/beta/step.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,7 @@ def inputs(self) -> InputT:
6767
class StepFunction(Protocol[StateT, DepsT, InputT, OutputT]):
6868
"""Protocol for step functions that can be executed in the graph.
6969
70-
Step functions are async callables that receive a step context and return
71-
a result. This protocol enables serialization and deserialization of step
72-
calls similar to how evaluators work.
70+
Step functions are async callables that receive a step context and return a result.
7371
7472
Type Parameters:
7573
StateT: The type of the graph state
@@ -90,12 +88,10 @@ def __call__(self, ctx: StepContext[StateT, DepsT, InputT]) -> Awaitable[OutputT
9088
raise NotImplementedError
9189

9290

93-
class StepAsyncIteratorFunction(Protocol[StateT, DepsT, InputT, OutputT]):
94-
"""Protocol for step functions that can be executed in the graph.
91+
class StreamFunction(Protocol[StateT, DepsT, InputT, OutputT]):
92+
"""Protocol for stream functions that can be executed in the graph.
9593
96-
Step functions are async callables that receive a step context and return
97-
a result. This protocol enables serialization and deserialization of step
98-
calls similar to how evaluators work.
94+
Stream functions are async callables that receive a step context and return an async iterator.
9995
10096
Type Parameters:
10197
StateT: The type of the graph state
@@ -105,13 +101,13 @@ class StepAsyncIteratorFunction(Protocol[StateT, DepsT, InputT, OutputT]):
105101
"""
106102

107103
def __call__(self, ctx: StepContext[StateT, DepsT, InputT]) -> AsyncIterator[OutputT]:
108-
"""Execute the step function with the given context.
104+
"""Execute the stream function with the given context.
109105
110106
Args:
111107
ctx: The step context containing state, dependencies, and inputs
112108
113109
Returns:
114-
An awaitable that resolves to the step's output
110+
An async iterator yielding the streamed output
115111
"""
116112
raise NotImplementedError
117113
yield

tests/graph/beta/test_graph_iteration.py

Lines changed: 29 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,7 @@ async def test_iter_with_async_iterable_map():
286286

287287
g = GraphBuilder(state_type=IterState, output_type=list[int])
288288

289-
@g.step_async_iterable
289+
@g.stream()
290290
async def generate_async(ctx: StepContext[IterState, None, None]) -> AsyncIterator[int]:
291291
for i in [1, 2, 3, 4]:
292292
yield i
@@ -346,21 +346,18 @@ async def process(ctx: StepContext[IterState, None, int]) -> int:
346346

347347
async with graph.iter(state=state) as run:
348348
while True:
349-
try:
350-
event = await run.next()
351-
if isinstance(event, list):
352-
# Filter out tasks where the node is 'process' and input is > 3
353-
filtered_tasks = [
354-
task
355-
for task in event
356-
if not (task.node_id == NodeID('process') and isinstance(task.inputs, int) and task.inputs > 3)
357-
]
358-
if filtered_tasks != event:
359-
# Override with filtered tasks
360-
event = await run.next(filtered_tasks)
361-
if isinstance(event, EndMarker):
362-
break
363-
except StopAsyncIteration:
349+
event = await run.next()
350+
if isinstance(event, list):
351+
# Filter out tasks where the node is 'process' and input is > 3
352+
filtered_tasks = [
353+
task
354+
for task in event
355+
if not (task.node_id == NodeID('process') and isinstance(task.inputs, int) and task.inputs > 3)
356+
]
357+
if filtered_tasks != event:
358+
# Override with filtered tasks
359+
event = await run.next(filtered_tasks)
360+
if isinstance(event, EndMarker):
364361
break
365362

366363
# Only items <= 3 should have been processed
@@ -397,24 +394,21 @@ async def second_step(ctx: StepContext[IterState, None, int]) -> int:
397394
override_done = False
398395
async with graph.iter(state=state) as run:
399396
while True:
400-
try:
401-
event = await run.next()
402-
if isinstance(event, EndMarker) and not override_done:
403-
# Instead of ending, create a new task
404-
# Get the fork_stack from the EndMarker's source
405-
fork_stack = run.next_task[0].fork_stack if isinstance(run.next_task, list) else ()
406-
407-
new_task = GraphTask(
408-
node_id=NodeID('second_step'),
409-
inputs=event.value,
410-
fork_stack=fork_stack,
411-
)
412-
413-
override_done = True
414-
event = await run.next([new_task])
415-
if isinstance(event, EndMarker) and override_done:
416-
break
417-
except StopAsyncIteration:
397+
event = await run.next()
398+
if isinstance(event, EndMarker) and not override_done:
399+
# Instead of ending, create a new task
400+
# Get the fork_stack from the EndMarker's source
401+
fork_stack = run.next_task[0].fork_stack if isinstance(run.next_task, list) else ()
402+
403+
new_task = GraphTask(
404+
node_id=NodeID('second_step'),
405+
inputs=event.value,
406+
fork_stack=fork_stack,
407+
)
408+
409+
override_done = True
410+
event = await run.next([new_task])
411+
if isinstance(event, EndMarker) and override_done:
418412
break
419413

420414
result = run.output
@@ -461,9 +455,7 @@ async def step3(ctx: StepContext[IterState, None, int]) -> str: # pragma: no co
461455
if any(task.node_id == NodeID('step2') for task in event):
462456
# Override with an EndMarker to terminate early
463457
early_exit_done = True
464-
event = await run.next(EndMarker('early_exit'))
465-
if isinstance(event, EndMarker):
466-
break
458+
await run.next(EndMarker('early_exit'))
467459
except StopAsyncIteration:
468460
break
469461

0 commit comments

Comments
 (0)