Skip to content

Commit f710a0b

Browse files
committed
Make work with async iterable
1 parent 0391387 commit f710a0b

File tree

1 file changed

+137
-93
lines changed
  • pydantic_graph/pydantic_graph/beta

1 file changed

+137
-93
lines changed

pydantic_graph/pydantic_graph/beta/graph.py

Lines changed: 137 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,12 @@
1010
import inspect
1111
import types
1212
import uuid
13-
from collections.abc import AsyncGenerator, AsyncIterator, Iterable, Sequence
13+
from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator, Iterable, Sequence
1414
from contextlib import AbstractContextManager, ExitStack, asynccontextmanager
1515
from dataclasses import dataclass, field
1616
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeGuard, cast, get_args, get_origin, overload
1717

18-
import anyio
19-
from anyio import CancelScope, WouldBlock, create_memory_object_stream, create_task_group
18+
from anyio import CancelScope, create_memory_object_stream, create_task_group
2019
from anyio.abc import TaskGroup
2120
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
2221
from typing_extensions import TypeVar, assert_never
@@ -377,16 +376,17 @@ def __init__(
377376
run_id = GraphRunID(str(uuid.uuid4()))
378377
initial_fork_stack: ForkStack = (ForkStackItem(StartNode.id, NodeRunID(run_id), 0),)
379378
self._first_task = GraphTask(node_id=StartNode.id, inputs=inputs, fork_stack=initial_fork_stack)
380-
self._iterator = _GraphIterator[StateT, DepsT, OutputT](self.graph, self.state, self.deps).iter_graph(
381-
self._first_task
382-
)
379+
self._iterator_instance = _GraphIterator[StateT, DepsT, OutputT](self.graph, self.state, self.deps)
380+
self._iterator = self._iterator_instance.iter_graph(self._first_task)
383381

384382
self.__traceparent = traceparent
385383

386384
async def __aenter__(self):
387385
return self
388386

389387
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):
388+
self._iterator_instance.iter_stream_sender.close()
389+
self._iterator_instance.iter_stream_receiver.close()
390390
await self._iterator.aclose()
391391

392392
@overload
@@ -472,10 +472,17 @@ def output(self) -> OutputT | None:
472472
return None
473473

474474

475+
@dataclass
476+
class _GraphTaskAsyncIterable:
477+
iterable: AsyncIterable[Sequence[GraphTask]]
478+
fork_stack: ForkStack
479+
480+
475481
@dataclass
476482
class _GraphTaskResult:
477483
source: GraphTask
478-
result: EndMarker[Any] | Sequence[GraphTask]
484+
result: EndMarker[Any] | Sequence[GraphTask] | JoinItem
485+
source_is_finished: bool = True
479486

480487

481488
@dataclass
@@ -486,6 +493,8 @@ class _GraphIterator(Generic[StateT, DepsT, OutputT]):
486493

487494
cancel_scopes: dict[TaskID, CancelScope] = field(init=False)
488495
active_tasks: dict[TaskID, GraphTask] = field(init=False)
496+
pending_task_results: set[TaskID] = field(init=False)
497+
cancelled_tasks: set[TaskID] = field(init=False)
489498
active_reducers: dict[tuple[JoinID, NodeRunID], JoinState] = field(init=False)
490499
iter_stream_sender: MemoryObjectSendStream[_GraphTaskResult] = field(init=False)
491500
iter_stream_receiver: MemoryObjectReceiveStream[_GraphTaskResult] = field(init=False)
@@ -497,6 +506,9 @@ def __post_init__(self):
497506
self.active_reducers = {}
498507
self.iter_stream_sender, self.iter_stream_receiver = create_memory_object_stream[_GraphTaskResult]()
499508

509+
self.pending_task_results = set()
510+
self.cancelled_tasks = set()
511+
500512
@property
501513
def task_group(self) -> TaskGroup:
502514
if self._task_group is None:
@@ -516,30 +528,70 @@ async def iter_graph( # noqa C901
516528
# Handle task results
517529
async with self.iter_stream_receiver:
518530
while self.active_tasks or self.active_reducers:
519-
while self.active_tasks:
520-
try:
521-
task_result = self.iter_stream_receiver.receive_nowait()
522-
except WouldBlock:
523-
await anyio.sleep(0.0)
531+
async for task_result in self.iter_stream_receiver:
532+
# If we encounter a mock task, add it to the active tasks to ensure we don't proceed until everything downstream is handled
533+
if (
534+
not task_result.source_is_finished
535+
and task_result.source.task_id not in self.active_tasks
536+
):
537+
self.active_tasks[task_result.source.task_id] = task_result.source
538+
539+
if task_result.source.task_id in self.cancelled_tasks:
540+
if task_result.source_is_finished:
541+
self.cancelled_tasks.remove(task_result.source.task_id)
524542
continue
525543

526-
maybe_overridden_result = yield task_result.result
544+
if task_result.source_is_finished:
545+
self.pending_task_results.discard(task_result.source.task_id)
546+
547+
if isinstance(task_result.result, JoinItem):
548+
maybe_overridden_result = task_result.result
549+
else:
550+
maybe_overridden_result = yield task_result.result
527551
if isinstance(maybe_overridden_result, EndMarker):
528552
self.task_group.cancel_scope.cancel()
529553
return
530-
for new_task in maybe_overridden_result:
531-
self.active_tasks[new_task.task_id] = new_task
532-
await self._finish_task(task_result.source.task_id)
554+
elif isinstance(maybe_overridden_result, JoinItem):
555+
result = maybe_overridden_result
556+
parent_fork_id = self.graph.get_parent_fork(result.join_id).fork_id
557+
for i, x in enumerate(result.fork_stack[::-1]):
558+
if x.fork_id == parent_fork_id:
559+
downstream_fork_stack = result.fork_stack[: len(result.fork_stack) - i]
560+
fork_run_id = x.node_run_id
561+
break
562+
else: # pragma: no cover
563+
raise RuntimeError('Parent fork run not found')
564+
565+
join_node = self.graph.nodes[result.join_id]
566+
assert isinstance(join_node, Join), f'Expected a `Join` but got {join_node}'
567+
join_state = self.active_reducers.get((result.join_id, fork_run_id))
568+
if join_state is None:
569+
current = join_node.initial_factory()
570+
join_state = self.active_reducers[(result.join_id, fork_run_id)] = JoinState(
571+
current, downstream_fork_stack
572+
)
573+
context = ReducerContext(state=self.state, deps=self.deps, join_state=join_state)
574+
join_state.current = join_node.reduce(context, join_state.current, result.inputs)
575+
if join_state.cancelled_sibling_tasks:
576+
await self._cancel_sibling_tasks(parent_fork_id, fork_run_id)
577+
if task_result.source_is_finished:
578+
await self._finish_task(task_result.source.task_id)
579+
else:
580+
for new_task in maybe_overridden_result:
581+
self.active_tasks[new_task.task_id] = new_task
582+
if task_result.source_is_finished:
583+
await self._finish_task(task_result.source.task_id)
533584

534585
tasks_by_id_values = list(self.active_tasks.values())
535586
join_tasks: list[GraphTask] = []
587+
536588
for join_id, fork_run_id in self._get_completed_fork_runs(
537589
task_result.source, tasks_by_id_values
538590
):
539591
join_state = self.active_reducers.pop((join_id, fork_run_id))
540592
join_node = self.graph.nodes[join_id]
541593
assert isinstance(join_node, Join), f'Expected a `Join` but got {join_node}'
542-
new_tasks = self._handle_edges(
594+
new_tasks = self._handle_non_fork_edges(
543595
join_node, join_state.current, join_state.downstream_fork_stack
544596
)
545597
join_tasks.extend(new_tasks)
@@ -548,14 +600,17 @@ async def iter_graph( # noqa C901
548600
self.active_tasks[new_task.task_id] = new_task
549601
self._handle_execution_request(join_tasks)
550602

551-
if not isinstance(task_result.result, EndMarker):
552-
new_task_ids = {t.task_id for t in maybe_overridden_result}
553-
for t in task_result.result:
554-
if t.task_id not in new_task_ids:
555-
await self._finish_task(
556-
t.task_id
557-
) # TODO: Rename to cancel_task or something, instead of "finish", these didn't really get started
558-
self._handle_execution_request(maybe_overridden_result)
603+
if isinstance(maybe_overridden_result, Sequence):
604+
if isinstance(task_result.result, Sequence):
605+
new_task_ids = {t.task_id for t in maybe_overridden_result}
606+
for t in task_result.result:
607+
if t.task_id not in new_task_ids:
608+
await self._finish_task(t.task_id)
609+
self._handle_execution_request(maybe_overridden_result)
610+
611+
if not self.active_tasks:
612+
# if there are no active tasks, we'll be waiting forever for the next result..
613+
break
559614

560615
if self.active_reducers: # pragma: no branch
561616
# In this case, there are no pending tasks. We can therefore finalize all active reducers whose
@@ -577,7 +632,7 @@ async def iter_graph( # noqa C901
577632
) # we're handling it now, so we can pop it
578633
join_node = self.graph.nodes[join_id]
579634
assert isinstance(join_node, Join), f'Expected a `Join` but got {join_node}'
580-
new_tasks = self._handle_edges(
635+
new_tasks = self._handle_non_fork_edges(
581636
join_node, join_state.current, join_state.downstream_fork_stack
582637
)
583638
maybe_overridden_result = yield new_tasks
@@ -620,45 +675,18 @@ async def _run_tracked_task(self, t_: GraphTask):
620675
with CancelScope() as scope:
621676
self.cancel_scopes[t_.task_id] = scope
622677
result = await self._run_task(t_)
623-
624-
if isinstance(result, EndMarker):
625-
await self.iter_stream_sender.send(_GraphTaskResult(t_, result))
626-
return
627-
628-
new_tasks: list[GraphTask] = []
629-
if isinstance(result, JoinItem):
630-
parent_fork_id = self.graph.get_parent_fork(result.join_id).fork_id
631-
for i, x in enumerate(result.fork_stack[::-1]):
632-
if x.fork_id == parent_fork_id:
633-
downstream_fork_stack = result.fork_stack[: len(result.fork_stack) - i]
634-
fork_run_id = x.node_run_id
635-
break
636-
else: # pragma: no cover
637-
raise RuntimeError('Parent fork run not found')
638-
639-
join_node = self.graph.nodes[result.join_id]
640-
assert isinstance(join_node, Join), f'Expected a `Join` but got {join_node}'
641-
join_state = self.active_reducers.get((result.join_id, fork_run_id))
642-
if join_state is None:
643-
current = join_node.initial_factory()
644-
join_state = self.active_reducers[(result.join_id, fork_run_id)] = JoinState(
645-
current, downstream_fork_stack
646-
)
647-
648-
context = ReducerContext(state=self.state, deps=self.deps, join_state=join_state)
649-
join_state.current = join_node.reduce(context, join_state.current, result.inputs)
650-
if join_state.cancelled_sibling_tasks:
651-
await self._cancel_sibling_tasks(parent_fork_id, fork_run_id)
678+
if isinstance(result, _GraphTaskAsyncIterable):
679+
async for new_tasks in result.iterable:
680+
await self.iter_stream_sender.send(_GraphTaskResult(t_, new_tasks, False))
652681
await self.iter_stream_sender.send(_GraphTaskResult(t_, []))
653682
else:
654-
new_tasks.extend(result)
655-
656-
await self.iter_stream_sender.send(_GraphTaskResult(t_, new_tasks))
683+
self.pending_task_results.add(t_.task_id)
684+
await self.iter_stream_sender.send(_GraphTaskResult(t_, result))
657685

658686
async def _run_task(
659687
self,
660688
task: GraphTask,
661-
) -> EndMarker[OutputT] | Sequence[GraphTask] | JoinItem:
689+
) -> EndMarker[OutputT] | Sequence[GraphTask] | _GraphTaskAsyncIterable | JoinItem:
662690
state = self.state
663691
deps = self.deps
664692

@@ -769,53 +797,63 @@ def _handle_path(self, path: Path, inputs: Any, fork_stack: ForkStack) -> Sequen
769797
else:
770798
assert_never(item)
771799

772-
def _handle_edges(self, node: AnyNode, inputs: Any, fork_stack: ForkStack) -> Sequence[GraphTask]:
800+
def _handle_edges(
801+
self, node: AnyNode, inputs: Any, fork_stack: ForkStack
802+
) -> Sequence[GraphTask] | _GraphTaskAsyncIterable:
803+
if isinstance(node, Fork):
804+
return self._handle_fork_edges(node, inputs, fork_stack)
805+
else:
806+
return self._handle_non_fork_edges(node, inputs, fork_stack)
807+
808+
def _handle_non_fork_edges(self, node: AnyNode, inputs: Any, fork_stack: ForkStack) -> Sequence[GraphTask]:
809+
# TODO: Replace `node` with a type that implies it is not a Fork
810+
edges = self.graph.edges_by_source.get(node.id, [])
811+
assert len(edges) == 1 # this should have already been ensured during graph building
812+
return self._handle_path(edges[0], inputs, fork_stack)
813+
814+
def _handle_fork_edges(
815+
self, node: Fork[Any, Any], inputs: Any, fork_stack: ForkStack
816+
) -> Sequence[GraphTask] | _GraphTaskAsyncIterable:
773817
edges = self.graph.edges_by_source.get(node.id, [])
774818
assert len(edges) == 1 or (isinstance(node, Fork) and not node.is_map), (
775819
edges,
776820
node.id,
777821
) # this should have already been ensured during graph building
778822

779823
new_tasks: list[GraphTask] = []
824+
node_run_id = NodeRunID(str(uuid.uuid4()))
825+
if node.is_map:
826+
# If the map specifies a downstream join id, eagerly create a join state for it
827+
if (join_id := node.downstream_join_id) is not None:
828+
join_node = self.graph.nodes[join_id]
829+
assert isinstance(join_node, Join)
830+
self.active_reducers[(join_id, node_run_id)] = JoinState(join_node.initial_factory(), fork_stack)
831+
832+
# Eagerly raise a clear error if the input value is not iterable as expected
833+
if _is_any_iterable(inputs):
834+
for thread_index, input_item in enumerate(inputs):
835+
item_tasks = self._handle_path(
836+
edges[0], input_item, fork_stack + (ForkStackItem(node.id, node_run_id, thread_index),)
837+
)
838+
new_tasks += item_tasks
839+
elif _is_any_async_iterable(inputs):
780840

781-
if isinstance(node, Fork):
782-
node_run_id = NodeRunID(str(uuid.uuid4()))
783-
if node.is_map:
784-
# If the map specifies a downstream join id, eagerly create a join state for it
785-
if (join_id := node.downstream_join_id) is not None:
786-
join_node = self.graph.nodes[join_id]
787-
assert isinstance(join_node, Join)
788-
self.active_reducers[(join_id, node_run_id)] = JoinState(join_node.initial_factory(), fork_stack)
789-
790-
# Eagerly raise a clear error if the input value is not iterable as expected
791-
if _is_any_iterable(inputs):
792-
for thread_index, input_item in enumerate(inputs):
841+
async def handle_async_iterable() -> AsyncIterator[Sequence[GraphTask]]:
842+
thread_index = 0
843+
async for input_item in inputs:
793844
item_tasks = self._handle_path(
794845
edges[0], input_item, fork_stack + (ForkStackItem(node.id, node_run_id, thread_index),)
795846
)
796-
new_tasks += item_tasks
797-
# elif isinstance(inputs, AsyncIterable):
798-
#
799-
# async def handle_async_iterable():
800-
# thread_index = 0
801-
# async for input_item in inputs:
802-
# item_tasks = self._handle_path(
803-
# edges[0], input_item, fork_stack + (ForkStackItem(node.id, node_run_id, thread_index),)
804-
# )
805-
# yield item_tasks
806-
# thread_index += 1
807-
#
808-
# task = GraphTask(node.id, inputs, fork_stack)
809-
#
810-
# self.tg.start_soon(self._run_tracked_task, task)
811-
else:
812-
raise RuntimeError(f'Cannot map non-iterable value: {inputs!r}')
847+
yield item_tasks
848+
thread_index += 1
849+
850+
return _GraphTaskAsyncIterable(handle_async_iterable(), fork_stack)
851+
813852
else:
814-
for i, path in enumerate(edges):
815-
new_tasks += self._handle_path(path, inputs, fork_stack + (ForkStackItem(node.id, node_run_id, i),))
853+
raise RuntimeError(f'Cannot map non-iterable value: {inputs!r}')
816854
else:
817-
new_tasks += self._handle_path(edges[0], inputs, fork_stack)
818-
855+
for i, path in enumerate(edges):
856+
new_tasks += self._handle_path(path, inputs, fork_stack + (ForkStackItem(node.id, node_run_id, i),))
819857
return new_tasks
820858

821859
def _is_fork_run_completed(self, tasks: Iterable[GraphTask], join_id: JoinID, fork_run_id: NodeRunID) -> bool:
@@ -840,8 +878,14 @@ async def _cancel_sibling_tasks(self, parent_fork_id: ForkID, node_run_id: NodeR
840878
else:
841879
pass
842880
for task_id in task_ids_to_cancel:
881+
if task_id in self.pending_task_results:
882+
self.cancelled_tasks.add(task_id)
843883
await self._finish_task(task_id)
844884

845885

846886
def _is_any_iterable(x: Any) -> TypeGuard[Iterable[Any]]:
847887
return isinstance(x, Iterable)
888+
889+
890+
def _is_any_async_iterable(x: Any) -> TypeGuard[AsyncIterable[Any]]:
891+
return isinstance(x, AsyncIterable)

0 commit comments

Comments
 (0)