Skip to content

Commit c478324

Browse files
committed
WIP
1 parent 83e9312 commit c478324

File tree

3 files changed

+130
-6
lines changed

3 files changed

+130
-6
lines changed

pydantic_graph/pydantic_graph/beta/graph_builder.py

Lines changed: 118 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
)
4747
from pydantic_graph.beta.step import NodeStep, Step, StepContext, StepFunction, StepNode, StreamFunction
4848
from pydantic_graph.beta.util import TypeOrTypeExpression, get_callable_name, unpack_type_expression
49-
from pydantic_graph.exceptions import GraphBuildingError
49+
from pydantic_graph.exceptions import GraphBuildingError, GraphValidationError
5050
from pydantic_graph.nodes import BaseNode, End
5151

5252
StateT = TypeVar('StateT', infer_variance=True)
@@ -633,12 +633,16 @@ def _edge_from_return_hint(
633633
return edge.to(decision)
634634

635635
# Graph building
636-
def build(self) -> Graph[StateT, DepsT, GraphInputT, GraphOutputT]:
636+
def build(self, validate_graph_structure: bool = True) -> Graph[StateT, DepsT, GraphInputT, GraphOutputT]:
637637
"""Build the final executable graph from the accumulated nodes and edges.
638638
639639
This method performs validation, normalization, and analysis of the graph
640640
structure to create a complete, executable graph instance.
641641
642+
Args:
643+
validate_graph_structure: whether to perform validation of the graph structure
644+
See the docstring of `_validate_graph_structure` below for more details.
645+
642646
Returns:
643647
A complete Graph instance ready for execution
644648
@@ -651,9 +655,8 @@ def build(self) -> Graph[StateT, DepsT, GraphInputT, GraphOutputT]:
651655
nodes, edges_by_source = _replace_placeholder_node_ids(nodes, edges_by_source)
652656
nodes, edges_by_source = _flatten_paths(nodes, edges_by_source)
653657
nodes, edges_by_source = _normalize_forks(nodes, edges_by_source)
654-
# TODO(P2): Warn/error if the graph is not connected
655-
# TODO(P2): Warn/error if there is no start node / edges, or end node / edges
656-
# TODO(P2): Warn/error if any non-End node is a dead end
658+
if validate_graph_structure:
659+
_validate_graph_structure(nodes, edges_by_source)
657660
parent_forks = _collect_dominating_forks(nodes, edges_by_source)
658661

659662
return Graph[StateT, DepsT, GraphInputT, GraphOutputT](
@@ -669,6 +672,116 @@ def build(self) -> Graph[StateT, DepsT, GraphInputT, GraphOutputT]:
669672
)
670673

671674

675+
def _validate_graph_structure( # noqa C901
676+
nodes: dict[NodeID, AnyNode],
677+
edges_by_source: dict[NodeID, list[Path]],
678+
) -> None:
679+
"""Validate the graph structure for common issues.
680+
681+
This function raises an error if any of the following criteria are not met:
682+
1. There are edges from the start node
683+
2. There are edges to the end node
684+
3. No non-End node is a dead end (no outgoing edges)
685+
4. The end node is reachable from the start node
686+
5. All nodes are reachable from the start node
687+
688+
Note 1: Under some circumstances it may be reasonable to build a graph that violates one or more of
689+
the above conditions. We may eventually add support for more granular control over validation,
690+
but today, if you want to build a graph that violates any of these assumptions you need to pass
691+
`validate_graph_structure=False` to the call to `GraphBuilder.build`.
692+
693+
Note 2: Some of the earlier items in the above list are redundant with the later items.
694+
I've included the earlier items in the list as a reminder to ourselves if/when we add more granular validation
695+
because you might want to check the earlier items but not the later items, as described in Note 1.
696+
697+
Args:
698+
nodes: The nodes in the graph
699+
edges_by_source: The edges by source node
700+
701+
Raises:
702+
GraphBuildingError: If any of the aforementioned structural issues are found.
703+
"""
704+
how_to_suppress = ' If this is intentional, you can suppress this error by passing `validate_graph_structure=False` to the call to `GraphBuilder.build`.'
705+
706+
# Extract all destination IDs from edges and decision branches
707+
all_destinations: set[NodeID] = set()
708+
709+
def _collect_destinations_from_path(path: Path) -> None:
710+
for item in path.items:
711+
if isinstance(item, DestinationMarker):
712+
all_destinations.add(item.destination_id)
713+
714+
for paths in edges_by_source.values():
715+
for path in paths:
716+
_collect_destinations_from_path(path)
717+
718+
# Also collect destinations from decision branches
719+
for node in nodes.values():
720+
if isinstance(node, Decision):
721+
for branch in node.branches:
722+
_collect_destinations_from_path(branch.path)
723+
724+
# Check 1: Check if there are edges from the start node
725+
start_edges = edges_by_source.get(StartNode.id, [])
726+
if not start_edges:
727+
raise GraphValidationError('The graph has no edges from the start node.' + how_to_suppress)
728+
729+
# Check 2: Check if there are edges to the end node
730+
if EndNode.id not in all_destinations:
731+
raise GraphValidationError('The graph has no edges to the end node.' + how_to_suppress)
732+
733+
# Check 3: Find all nodes with no outgoing edges (dead ends)
734+
dead_end_nodes: list[NodeID] = []
735+
for node_id, node in nodes.items():
736+
# Skip the end node itself
737+
if isinstance(node, EndNode):
738+
continue
739+
740+
# Check if this node has any outgoing edges
741+
has_edges = node_id in edges_by_source and len(edges_by_source[node_id]) > 0
742+
743+
# Also check if it's a decision node with branches
744+
if isinstance(node, Decision):
745+
has_edges = has_edges or len(node.branches) > 0
746+
747+
if not has_edges:
748+
dead_end_nodes.append(node_id)
749+
750+
if dead_end_nodes:
751+
raise GraphValidationError(f'The following nodes have no outgoing edges: {dead_end_nodes}.' + how_to_suppress)
752+
753+
# Checks 4 and 5: Ensure all nodes (and in particular, the end node) are reachable from the start node
754+
reachable: set[NodeID] = {StartNode.id}
755+
to_visit = [StartNode.id]
756+
757+
while to_visit:
758+
current_id = to_visit.pop()
759+
760+
# Add destinations from regular edges
761+
for path in edges_by_source.get(current_id, []):
762+
for item in path.items:
763+
if isinstance(item, DestinationMarker):
764+
if item.destination_id not in reachable:
765+
reachable.add(item.destination_id)
766+
to_visit.append(item.destination_id)
767+
768+
# Add destinations from decision branches
769+
current_node = nodes.get(current_id)
770+
if isinstance(current_node, Decision):
771+
for branch in current_node.branches:
772+
for item in branch.path.items:
773+
if isinstance(item, DestinationMarker):
774+
if item.destination_id not in reachable:
775+
reachable.add(item.destination_id)
776+
to_visit.append(item.destination_id)
777+
778+
unreachable_nodes = [node_id for node_id in nodes if node_id not in reachable]
779+
if unreachable_nodes:
780+
raise GraphValidationError(
781+
f'The following nodes are not reachable from the start node: {unreachable_nodes}.' + how_to_suppress
782+
)
783+
784+
672785
def _flatten_paths(
673786
nodes: dict[NodeID, AnyNode], edges: dict[NodeID, list[Path]]
674787
) -> tuple[dict[NodeID, AnyNode], dict[NodeID, list[Path]]]:

pydantic_graph/pydantic_graph/exceptions.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,17 @@ def __init__(self, message: str):
2626
super().__init__(message)
2727

2828

29+
class GraphValidationError(ValueError):
30+
"""An error raised during graph validation."""
31+
32+
message: str
33+
"""The error message."""
34+
35+
def __init__(self, message: str):
36+
self.message = message
37+
super().__init__(message)
38+
39+
2940
class GraphRuntimeError(RuntimeError):
3041
"""Error caused by an issue during graph execution."""
3142

tests/graph/beta/test_graph_iteration.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -388,7 +388,7 @@ async def second_step(ctx: StepContext[IterState, None, int]) -> int:
388388
g.edge_from(second_step).to(g.end_node),
389389
)
390390

391-
graph = g.build()
391+
graph = g.build(validate_graph_structure=False)
392392
state = IterState()
393393

394394
override_done = False

0 commit comments

Comments
 (0)