|
22 | 22 | from pydantic_graph._utils import AbstractSpan, get_traceparent, logfire_span
|
23 | 23 | from pydantic_graph.beta.decision import Decision
|
24 | 24 | from pydantic_graph.beta.id_types import ForkStack, ForkStackItem, GraphRunId, JoinId, NodeId, NodeRunId, TaskId
|
25 |
| -from pydantic_graph.beta.join import Join, Reducer |
| 25 | +from pydantic_graph.beta.join import Join, JoinNode, Reducer |
26 | 26 | from pydantic_graph.beta.node import (
|
27 | 27 | EndNode,
|
28 | 28 | Fork,
|
@@ -441,7 +441,7 @@ def output(self) -> OutputT | None:
|
441 | 441 | return self._next.value
|
442 | 442 | return None
|
443 | 443 |
|
444 |
| - async def _iter_graph( |
| 444 | + async def _iter_graph( # noqa C901 |
445 | 445 | self,
|
446 | 446 | ) -> AsyncGenerator[
|
447 | 447 | EndMarker[OutputT] | JoinItem | Sequence[GraphTask], EndMarker[OutputT] | JoinItem | Sequence[GraphTask]
|
@@ -574,7 +574,7 @@ async def _handle_task(
|
574 | 574 | step_context = StepContext[StateT, DepsT, Any](state, deps, inputs)
|
575 | 575 | output = await node.call(step_context)
|
576 | 576 | if isinstance(node, NodeStep):
|
577 |
| - return self._handle_node(node, output, fork_stack) |
| 577 | + return self._handle_node(output, fork_stack) |
578 | 578 | else:
|
579 | 579 | return self._handle_edges(node, output, fork_stack)
|
580 | 580 | elif isinstance(node, Join):
|
@@ -613,12 +613,13 @@ def _handle_decision(
|
613 | 613 |
|
614 | 614 | def _handle_node(
|
615 | 615 | self,
|
616 |
| - node_step: NodeStep[StateT, DepsT], |
617 | 616 | next_node: BaseNode[StateT, DepsT, Any] | End[Any],
|
618 | 617 | fork_stack: ForkStack,
|
619 |
| - ) -> Sequence[GraphTask] | EndMarker[OutputT]: |
| 618 | + ) -> Sequence[GraphTask] | JoinItem | EndMarker[OutputT]: |
620 | 619 | if isinstance(next_node, StepNode):
|
621 | 620 | return [GraphTask(next_node.step.id, next_node.inputs, fork_stack)]
|
| 621 | + elif isinstance(next_node, JoinNode): |
| 622 | + return JoinItem(next_node.join.id, next_node.inputs, fork_stack) |
622 | 623 | elif isinstance(next_node, BaseNode):
|
623 | 624 | node_step = NodeStep(next_node.__class__)
|
624 | 625 | return [GraphTask(node_step.id, next_node, fork_stack)]
|
@@ -687,7 +688,10 @@ def _handle_path(self, path: Path, inputs: Any, fork_stack: ForkStack) -> Sequen
|
687 | 688 |
|
688 | 689 | def _handle_edges(self, node: AnyNode, inputs: Any, fork_stack: ForkStack) -> Sequence[GraphTask]:
|
689 | 690 | edges = self.graph.edges_by_source.get(node.id, [])
|
690 |
| - assert len(edges) == 1 or isinstance(node, Fork) # this should have already been ensured during graph building |
| 691 | + assert len(edges) == 1 or isinstance(node, Fork), ( |
| 692 | + edges, |
| 693 | + node.id, |
| 694 | + ) # this should have already been ensured during graph building |
691 | 695 |
|
692 | 696 | new_tasks: list[GraphTask] = []
|
693 | 697 | for path in edges:
|
|
0 commit comments