Skip to content

Commit 33cf7ff

Browse files
committed
Various improvements
1 parent fab8317 commit 33cf7ff

File tree

9 files changed

+238
-213
lines changed

9 files changed

+238
-213
lines changed

pydantic_graph/pydantic_graph/beta/decision.py

Lines changed: 27 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from collections.abc import Callable, Iterable, Sequence
1111
from dataclasses import dataclass
12-
from typing import TYPE_CHECKING, Any, Final, Generic
12+
from typing import TYPE_CHECKING, Any, Generic
1313

1414
from typing_extensions import Never, Self, TypeVar
1515

@@ -123,40 +123,29 @@ class DecisionBranch(Generic[SourceT]):
123123
"""Type variable for transformed output."""
124124

125125

126-
@dataclass(kw_only=True)
126+
@dataclass(kw_only=True, frozen=True)
127127
class DecisionBranchBuilder(Generic[StateT, DepsT, OutputT, SourceT, HandledT]):
128128
"""Builder for constructing decision branches with fluent API.
129129
130130
This builder provides methods to configure branches with destinations,
131131
forks, and transformations in a type-safe manner.
132+
133+
Instances of this class should be created using [`GraphBuilder.match`][pydantic_graph.beta.graph_builder.GraphBuilder],
134+
not instantiated directly.
132135
"""
133136

134-
# The use of `Final` on these attributes is necessary for them to be treated as read-only for purposes
135-
of variance-inference. This could be done with `frozen`, but that is not used here.
136-
decision: Final[Decision[StateT, DepsT, HandledT]]
137+
_decision: Decision[StateT, DepsT, HandledT]
137138
"""The parent decision node."""
138139

139-
source: Final[TypeOrTypeExpression[SourceT]]
140+
_source: TypeOrTypeExpression[SourceT]
140141
"""The expected source type for this branch."""
141142

142-
matches: Final[Callable[[Any], bool] | None]
143+
_matches: Callable[[Any], bool] | None
143144
"""Optional matching predicate."""
144145

145-
path_builder: Final[PathBuilder[StateT, DepsT, OutputT]]
146+
_path_builder: PathBuilder[StateT, DepsT, OutputT]
146147
"""Builder for the execution path."""
147148

148-
@property
149-
def last_fork_id(self) -> ForkID | None:
150-
"""Get the ID of the last fork in the path.
151-
152-
Returns:
153-
The fork ID if a fork exists, None otherwise.
154-
"""
155-
last_fork = self.path_builder.last_fork
156-
if last_fork is None:
157-
return None
158-
return last_fork.fork_id
159-
160149
def to(
161150
self,
162151
destination: DestinationNode[StateT, DepsT, OutputT],
@@ -173,25 +162,25 @@ def to(
173162
A completed DecisionBranch with the specified destinations.
174163
"""
175164
return DecisionBranch(
176-
source=self.source, matches=self.matches, path=self.path_builder.to(destination, *extra_destinations)
165+
source=self._source, matches=self._matches, path=self._path_builder.to(destination, *extra_destinations)
177166
)
178167

179-
def fork(
168+
def broadcast(
180169
self,
181170
get_forks: Callable[[Self], Sequence[DecisionBranch[SourceT]]],
182171
/,
183172
) -> DecisionBranch[SourceT]:
184-
"""Create a fork in the execution path.
173+
"""Create a broadcast fork in the execution path.
185174
186175
Args:
187-
get_forks: Function that generates forked decision branches.
176+
get_forks: Function (typically a lambda) that returns the broadcast forks downstream of this decision branch.
188177
189178
Returns:
190-
A completed DecisionBranch with forked execution paths.
179+
A completed DecisionBranch with broadcast-forked execution paths.
191180
"""
192181
fork_decision_branches = get_forks(self)
193182
new_paths = [b.path for b in fork_decision_branches]
194-
return DecisionBranch(source=self.source, matches=self.matches, path=self.path_builder.fork(new_paths))
183+
return DecisionBranch(source=self._source, matches=self._matches, path=self._path_builder.broadcast(new_paths))
195184

196185
def transform(
197186
self, func: TransformFunction[StateT, DepsT, OutputT, NewOutputT], /
@@ -205,10 +194,10 @@ def transform(
205194
A new DecisionBranchBuilder where the provided transform is applied prior to generating the final output.
206195
"""
207196
return DecisionBranchBuilder(
208-
decision=self.decision,
209-
source=self.source,
210-
matches=self.matches,
211-
path_builder=self.path_builder.transform(func),
197+
_decision=self._decision,
198+
_source=self._source,
199+
_matches=self._matches,
200+
_path_builder=self._path_builder.transform(func),
212201
)
213202

214203
def map(
@@ -230,10 +219,10 @@ def map(
230219
A new DecisionBranchBuilder where mapping is performed prior to generating the final output.
231220
"""
232221
return DecisionBranchBuilder(
233-
decision=self.decision,
234-
source=self.source,
235-
matches=self.matches,
236-
path_builder=self.path_builder.map(fork_id=fork_id, downstream_join_id=downstream_join_id),
222+
_decision=self._decision,
223+
_source=self._source,
224+
_matches=self._matches,
225+
_path_builder=self._path_builder.map(fork_id=fork_id, downstream_join_id=downstream_join_id),
237226
)
238227

239228
def label(self, label: str) -> DecisionBranchBuilder[StateT, DepsT, OutputT, SourceT, HandledT]:
@@ -248,8 +237,8 @@ def label(self, label: str) -> DecisionBranchBuilder[StateT, DepsT, OutputT, Sou
248237
A new DecisionBranchBuilder where the label has been applied at the end of the current path being built.
249238
"""
250239
return DecisionBranchBuilder(
251-
decision=self.decision,
252-
source=self.source,
253-
matches=self.matches,
254-
path_builder=self.path_builder.label(label),
240+
_decision=self._decision,
241+
_source=self._source,
242+
_matches=self._matches,
243+
_path_builder=self._path_builder.label(label),
255244
)

pydantic_graph/pydantic_graph/beta/graph_builder.py

Lines changed: 91 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,10 @@
88
from __future__ import annotations
99

1010
import inspect
11-
from collections import defaultdict
11+
from collections import Counter, defaultdict
1212
from collections.abc import Callable, Iterable
13-
from dataclasses import dataclass
13+
from copy import deepcopy
14+
from dataclasses import dataclass, replace
1415
from types import NoneType
1516
from typing import Any, Generic, cast, get_origin, get_type_hints, overload
1617

@@ -20,7 +21,7 @@
2021
from pydantic_graph._utils import UNSET, Unset
2122
from pydantic_graph.beta.decision import Decision, DecisionBranch, DecisionBranchBuilder
2223
from pydantic_graph.beta.graph import Graph
23-
from pydantic_graph.beta.id_types import ForkID, JoinID, NodeID
24+
from pydantic_graph.beta.id_types import ForkID, JoinID, NodeID, is_placeholder_node_id
2425
from pydantic_graph.beta.join import Join, JoinNode, ReducerFunction
2526
from pydantic_graph.beta.node import (
2627
EndNode,
@@ -59,6 +60,7 @@
5960
T = TypeVar('T', infer_variance=True)
6061

6162

63+
# TODO: Make this kw-only and drop init=False?
6264
@dataclass(init=False)
6365
class GraphBuilder(Generic[StateT, DepsT, GraphInputT, GraphOutputT]):
6466
"""A builder for constructing executable graph definitions.
@@ -440,7 +442,9 @@ def match(
440442
node_id = NodeID(self._get_new_decision_id())
441443
decision = Decision[StateT, DepsT, Never](node_id, branches=[], note=None)
442444
new_path_builder = PathBuilder[StateT, DepsT, SourceT](working_items=[])
443-
return DecisionBranchBuilder(decision=decision, source=source, matches=matches, path_builder=new_path_builder)
445+
return DecisionBranchBuilder(
446+
_decision=decision, _source=source, _matches=matches, _path_builder=new_path_builder
447+
)
444448

445449
def match_node(
446450
self,
@@ -663,8 +667,11 @@ def build(self) -> Graph[StateT, DepsT, GraphInputT, GraphOutputT]:
663667
# TODO(P2): Allow the user to specify the parent forks; only infer them if _not_ specified
664668
# TODO(P2): Verify that any user-specified parent forks are _actually_ valid parent forks, and if not, generate a helpful error message
665669
# TODO(P3): Consider doing a deepcopy here to prevent modifications to the underlying nodes and edges
666-
nodes = self._nodes
667-
edges_by_source = self._edges_by_source
670+
671+
nodes = deepcopy(self._nodes)
672+
edges_by_source = deepcopy(self._edges_by_source)
673+
674+
nodes, edges_by_source = _replace_placeholder_node_ids(nodes, edges_by_source)
668675
nodes, edges_by_source = _flatten_paths(nodes, edges_by_source)
669676
nodes, edges_by_source = _normalize_forks(nodes, edges_by_source)
670677
parent_forks = _collect_dominating_forks(nodes, edges_by_source)
@@ -843,3 +850,81 @@ def _handle_path(path: Path, last_source_id: NodeID):
843850
dominating_forks[join_id] = dominating_fork
844851

845852
return dominating_forks
853+
854+
855+
def _replace_placeholder_node_ids(nodes: dict[NodeID, AnyNode], edges_by_source: dict[NodeID, list[Path]]):
856+
node_id_remapping = _build_placeholder_node_id_remapping(nodes)
857+
replaced_nodes = {
858+
node_id_remapping.get(name, name): _update_node_with_id_remapping(node, node_id_remapping)
859+
for name, node in nodes.items()
860+
}
861+
replaced_edges_by_source = {
862+
node_id_remapping.get(source, source): [_update_path_with_id_remapping(p, node_id_remapping) for p in paths]
863+
for source, paths in edges_by_source.items()
864+
}
865+
return replaced_nodes, replaced_edges_by_source
866+
867+
868+
def _build_placeholder_node_id_remapping(nodes: dict[NodeID, AnyNode]) -> dict[NodeID, NodeID]:
869+
"""The determinism of the generated remapping here is dependent on the determinism of the ordering of the `nodes` dict.
870+
871+
Note: If we want to generate more interesting names, we could try to make use of information about the edges
872+
into/out of the relevant nodes. I'm not sure if there's a good use case for that though so I didn't bother for now.
873+
"""
874+
counter = Counter[str]()
875+
remapping: dict[NodeID, NodeID] = {}
876+
for node_id, node in nodes.items():
877+
if not is_placeholder_node_id(node_id):
878+
continue
879+
label = type(node).__name__.lower()
880+
counter[label] = count = counter[label] + 1
881+
remapping[node_id] = NodeID(f'{label}_{count}')
882+
return remapping
883+
884+
885+
def _update_node_with_id_remapping(node: AnyNode, node_id_remapping: dict[NodeID, NodeID]) -> AnyNode:
886+
if isinstance(node, Step):
887+
# Even though steps are frozen, we use object.__setattr__ to overrule that and change the id value to make it
888+
# work with NodeStep.
889+
# Note: we have already deepcopied the inputs to this function so it should be okay to make mutations,
890+
# this could change if we change the code surrounding the code paths leading to this function call though.
891+
object.__setattr__(node, 'id', node_id_remapping.get(node.id, node.id))
892+
elif isinstance(node, Join):
893+
node = replace(node, id=JoinID(node_id_remapping.get(node.id, node.id)))
894+
elif isinstance(node, Fork):
895+
node = replace(node, id=ForkID(node_id_remapping.get(node.id, node.id)))
896+
elif isinstance(node, Decision):
897+
node = replace(
898+
node,
899+
id=node_id_remapping.get(node.id, node.id),
900+
branches=[
901+
replace(branch, path=_update_path_with_id_remapping(branch.path, node_id_remapping))
902+
for branch in node.branches
903+
],
904+
)
905+
return node
906+
907+
908+
def _update_path_with_id_remapping(path: Path, node_id_remapping: dict[NodeID, NodeID]) -> Path:
909+
path = replace(path) # prevent mutating the input; not technically necessary but could make debugging easier later
910+
for i, item in enumerate(path.items):
911+
if isinstance(item, MapMarker):
912+
downstream_join_id = item.downstream_join_id
913+
if downstream_join_id is not None:
914+
downstream_join_id = JoinID(node_id_remapping.get(downstream_join_id, downstream_join_id))
915+
path.items[i] = replace(
916+
item,
917+
fork_id=ForkID(node_id_remapping.get(item.fork_id, item.fork_id)),
918+
downstream_join_id=downstream_join_id,
919+
)
920+
elif isinstance(item, BroadcastMarker):
921+
path.items[i] = replace(
922+
item,
923+
fork_id=ForkID(node_id_remapping.get(item.fork_id, item.fork_id)),
924+
paths=[_update_path_with_id_remapping(p, node_id_remapping) for p in item.paths],
925+
)
926+
elif isinstance(item, DestinationMarker):
927+
path.items[i] = replace(
928+
item, destination_id=node_id_remapping.get(item.destination_id, item.destination_id)
929+
)
930+
return path

pydantic_graph/pydantic_graph/beta/id_types.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from __future__ import annotations
88

9+
import uuid
910
from dataclasses import dataclass
1011
from typing import NewType
1112

@@ -54,3 +55,24 @@ class ForkStackItem:
5455
The fork stack tracks the complete path through nested parallel executions,
5556
allowing the system to coordinate and join parallel branches correctly.
5657
"""
58+
59+
60+
def generate_placeholder_node_id(debug_label: str) -> str:
61+
"""Generate a placeholder node ID, to be replaced during graph building."""
62+
return f'{_NODE_ID_PLACEHOLDER_PREFIX}:{debug_label}:{uuid.uuid4()}'
63+
64+
65+
def is_placeholder_node_id(node_id: NodeID) -> bool:
66+
"""Returns whether a given NodeID is a placeholder node ID which should be replaced during graph building."""
67+
return node_id.startswith(_NODE_ID_PLACEHOLDER_PREFIX)
68+
69+
70+
_NODE_ID_PLACEHOLDER_PREFIX = '__placeholder__:'
71+
"""
72+
When Node IDs are required but not specified when building a graph, we generate placeholder values
73+
using this prefix followed by a random string.
74+
75+
During graph building, we replace these with simpler and deterministically-selected values.
76+
This ensures that the node IDs are stable when rebuilding the graph, and makes the generated mermaid diagrams etc.
77+
easier to read.
78+
"""

pydantic_graph/pydantic_graph/beta/mermaid.py

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,8 @@ def render(
131131
if direction is not None:
132132
lines.append(f' direction {direction}')
133133

134-
for node in self.nodes:
134+
nodes, edges = _topological_sort(self.nodes, self.edges)
135+
for node in nodes:
135136
# List all nodes in order they were created
136137
node_lines: list[str] = []
137138
if node.kind == 'start' or node.kind == 'end':
@@ -156,7 +157,7 @@ def render(
156157

157158
lines.append('')
158159

159-
for edge in self.edges:
160+
for edge in edges:
160161
# Use special [*] syntax for start/end nodes
161162
render_start_id = '[*]' if edge.start_id == StartNode.id else edge.start_id
162163
render_end_id = '[*]' if edge.end_id == EndNode.id else edge.end_id
@@ -167,3 +168,48 @@ def render(
167168
# TODO(P3): Support node notes/highlighting
168169

169170
return '\n'.join(lines)
171+
172+
173+
def _topological_sort(
174+
nodes: list[MermaidNode], edges: list[MermaidEdge]
175+
) -> tuple[list[MermaidNode], list[MermaidEdge]]:
176+
"""Sort nodes and edges in a logical topological order.
177+
178+
Uses BFS from the start node to assign depths, then sorts:
179+
- Nodes by their distance from start
180+
- Edges by the distance of their source and target nodes
181+
"""
182+
# Build adjacency list for BFS
183+
adjacency: dict[str, list[str]] = defaultdict(list)
184+
for edge in edges:
185+
adjacency[edge.start_id].append(edge.end_id)
186+
187+
# BFS to assign depth to each node (distance from start)
188+
depths: dict[str, int] = {}
189+
queue: list[tuple[str, int]] = [(StartNode.id, 0)]
190+
depths[StartNode.id] = 0
191+
192+
while queue:
193+
node_id, depth = queue.pop(0)
194+
for next_id in adjacency[node_id]:
195+
if next_id not in depths:
196+
depths[next_id] = depth + 1
197+
queue.append((next_id, depth + 1))
198+
199+
# Sort nodes by depth (distance from start), then by id for stability
200+
# Nodes not reachable from start get infinity depth (sorted to end)
201+
sorted_nodes = sorted(nodes, key=lambda n: (depths.get(n.id, float('inf')), n.id))
202+
203+
# Sort edges by source depth, then target depth
204+
# This ensures edges closer to start come first, edges closer to end come last
205+
sorted_edges = sorted(
206+
edges,
207+
key=lambda e: (
208+
depths.get(e.start_id, float('inf')),
209+
depths.get(e.end_id, float('inf')),
210+
e.start_id,
211+
e.end_id,
212+
),
213+
)
214+
215+
return sorted_nodes, sorted_edges

0 commit comments

Comments
 (0)