|
22 | 22 | from pydantic_graph._utils import AbstractSpan, get_traceparent, logfire_span |
23 | 23 | from pydantic_graph.beta.decision import Decision |
24 | 24 | from pydantic_graph.beta.id_types import ForkStack, ForkStackItem, GraphRunId, JoinId, NodeId, NodeRunId, TaskId |
25 | | -from pydantic_graph.beta.join import Join, Reducer |
| 25 | +from pydantic_graph.beta.join import Join, JoinNode, Reducer |
26 | 26 | from pydantic_graph.beta.node import ( |
27 | 27 | EndNode, |
28 | 28 | Fork, |
@@ -441,7 +441,7 @@ def output(self) -> OutputT | None: |
441 | 441 | return self._next.value |
442 | 442 | return None |
443 | 443 |
|
444 | | - async def _iter_graph( |
| 444 | + async def _iter_graph( # noqa C901 |
445 | 445 | self, |
446 | 446 | ) -> AsyncGenerator[ |
447 | 447 | EndMarker[OutputT] | JoinItem | Sequence[GraphTask], EndMarker[OutputT] | JoinItem | Sequence[GraphTask] |
@@ -574,7 +574,7 @@ async def _handle_task( |
574 | 574 | step_context = StepContext[StateT, DepsT, Any](state, deps, inputs) |
575 | 575 | output = await node.call(step_context) |
576 | 576 | if isinstance(node, NodeStep): |
577 | | - return self._handle_node(node, output, fork_stack) |
| 577 | + return self._handle_node(output, fork_stack) |
578 | 578 | else: |
579 | 579 | return self._handle_edges(node, output, fork_stack) |
580 | 580 | elif isinstance(node, Join): |
@@ -613,12 +613,13 @@ def _handle_decision( |
613 | 613 |
|
614 | 614 | def _handle_node( |
615 | 615 | self, |
616 | | - node_step: NodeStep[StateT, DepsT], |
617 | 616 | next_node: BaseNode[StateT, DepsT, Any] | End[Any], |
618 | 617 | fork_stack: ForkStack, |
619 | | - ) -> Sequence[GraphTask] | EndMarker[OutputT]: |
| 618 | + ) -> Sequence[GraphTask] | JoinItem | EndMarker[OutputT]: |
620 | 619 | if isinstance(next_node, StepNode): |
621 | 620 | return [GraphTask(next_node.step.id, next_node.inputs, fork_stack)] |
| 621 | + elif isinstance(next_node, JoinNode): |
| 622 | + return JoinItem(next_node.join.id, next_node.inputs, fork_stack) |
622 | 623 | elif isinstance(next_node, BaseNode): |
623 | 624 | node_step = NodeStep(next_node.__class__) |
624 | 625 | return [GraphTask(node_step.id, next_node, fork_stack)] |
@@ -687,7 +688,10 @@ def _handle_path(self, path: Path, inputs: Any, fork_stack: ForkStack) -> Sequen |
687 | 688 |
|
688 | 689 | def _handle_edges(self, node: AnyNode, inputs: Any, fork_stack: ForkStack) -> Sequence[GraphTask]: |
689 | 690 | edges = self.graph.edges_by_source.get(node.id, []) |
690 | | - assert len(edges) == 1 or isinstance(node, Fork) # this should have already been ensured during graph building |
| 691 | + assert len(edges) == 1 or isinstance(node, Fork), ( |
| 692 | + edges, |
| 693 | + node.id, |
| 694 | + ) # this should have already been ensured during graph building |
691 | 695 |
|
692 | 696 | new_tasks: list[GraphTask] = [] |
693 | 697 | for path in edges: |
|
0 commit comments