|
22 | 22 | from pydantic_graph.beta.graph import Graph |
23 | 23 | from pydantic_graph.beta.id_types import ForkID, JoinID, NodeID, generate_placeholder_node_id, replace_placeholder_id |
24 | 24 | from pydantic_graph.beta.join import Join, JoinNode, ReducerFunction |
| 25 | +from pydantic_graph.beta.mermaid import build_mermaid_graph |
25 | 26 | from pydantic_graph.beta.node import ( |
26 | 27 | EndNode, |
27 | 28 | Fork, |
|
59 | 60 | T = TypeVar('T', infer_variance=True) |
60 | 61 |
|
61 | 62 |
|
| 63 | +class GraphBuildingError(ValueError): |
| 64 | + """An error raised during graph-building.""" |
| 65 | + |
| 66 | + pass |
| 67 | + |
| 68 | + |
62 | 69 | @dataclass(init=False) |
63 | 70 | class GraphBuilder(Generic[StateT, DepsT, GraphInputT, GraphOutputT]): |
64 | 71 | """A builder for constructing executable graph definitions. |
@@ -476,7 +483,7 @@ def match_node( |
476 | 483 | Returns: |
477 | 484 | A DecisionBranch for the BaseNode type |
478 | 485 | """ |
479 | | - assert False # TODO: Need to cover this in a test |
| 486 | + # TODO: Need to cover this in a test |
480 | 487 | node = NodeStep(source) |
481 | 488 | path = Path(items=[DestinationMarker(node.id)]) |
482 | 489 | return DecisionBranch(source=source, matches=matches, path=path, destinations=[node]) |
@@ -532,7 +539,9 @@ def _insert_node(self, node: AnyNode) -> None: |
532 | 539 | elif isinstance(existing, NodeStep) and isinstance(node, NodeStep) and existing.node_type is node.node_type: |
533 | 540 | pass |
534 | 541 | elif existing is not node: |
535 | | - raise ValueError(f'All nodes must have unique node IDs. {node.id!r} was the ID for {existing} and {node}') |
| 542 | + raise GraphBuildingError( |
| 543 | + f'All nodes must have unique node IDs. {node.id!r} was the ID for {existing} and {node}' |
| 544 | + ) |
536 | 545 |
|
537 | 546 | def _edge_from_return_hint( |
538 | 547 | self, node: SourceNode[StateT, DepsT, Any], return_hint: TypeOrTypeExpression[Any] |
@@ -631,9 +640,6 @@ def build(self) -> Graph[StateT, DepsT, GraphInputT, GraphOutputT]: |
631 | 640 | nodes, edges_by_source = _flatten_paths(nodes, edges_by_source) |
632 | 641 | nodes, edges_by_source = _normalize_forks(nodes, edges_by_source) |
633 | 642 | parent_forks = _collect_dominating_forks(nodes, edges_by_source) |
634 | | - print(nodes) |
635 | | - print(edges_by_source) |
636 | | - print(parent_forks) |
637 | 643 |
|
638 | 644 | return Graph[StateT, DepsT, GraphInputT, GraphOutputT]( |
639 | 645 | name=self.name, |
@@ -720,10 +726,9 @@ def _normalize_forks( |
720 | 726 | if len(edges_from_source) == 1: |
721 | 727 | new_edges[source_id] = edges_from_source |
722 | 728 | continue |
723 | | - assert False # TODO: Need to cover this in a test, specifically by using `.to()` with multiple arguments |
724 | 729 | new_fork = Fork[Any, Any](id=ForkID(NodeID(f'{node.id}_broadcast_fork')), is_map=False, downstream_join_id=None) |
725 | 730 | new_nodes[new_fork.id] = new_fork |
726 | | - new_edges[source_id] = [Path(items=[BroadcastMarker(fork_id=new_fork.id, paths=edges_from_source)])] |
| 731 | + new_edges[source_id] = [Path(items=[DestinationMarker(new_fork.id)])] |
727 | 732 | new_edges[new_fork.id] = edges_from_source |
728 | 733 |
|
729 | 734 | return new_nodes, new_edges |
@@ -796,9 +801,22 @@ def _handle_path(path: Path, last_source_id: NodeID): |
796 | 801 | join.id, explicit_fork_id=join.parent_fork_id, prefer_closest=join.preferred_parent_fork == 'closest' |
797 | 802 | ) |
798 | 803 | if dominating_fork is None: |
799 | | - # TODO(P3): Print out the mermaid graph and explain the problem |
800 | | - assert False # TODO: Need to cover this in a test |
801 | | - raise ValueError(f'Join node {join.id} has no dominating fork') |
| 804 | + rendered_mermaid_graph = build_mermaid_graph(graph_nodes, graph_edges_by_source).render() |
| 805 | + error_message = f"""\ |
| 806 | +For every Join J in the graph, there must be a Fork F between the StartNode and J satisfying: |
| 807 | +* Every path from the StartNode to J passes through F |
| 808 | +* There are no cycles in the graph including both J and F. |
| 809 | +In this case, F is called a "dominating fork" for J. |
| 810 | +
|
| 811 | +This is used to determine when all tasks upstream of this Join are complete and we can proceed with execution. |
| 812 | +
|
| 813 | +Mermaid diagram: |
| 814 | +{rendered_mermaid_graph} |
| 815 | +
|
| 816 | +Join {join.id!r} in this graph has no dominating fork.\ |
| 817 | +""" |
| 818 | + # TODO: Need to cover this in a test |
| 819 | + raise GraphBuildingError(error_message) |
802 | 820 | dominating_forks[join.id] = dominating_fork |
803 | 821 |
|
804 | 822 | return dominating_forks |
|
0 commit comments