Skip to content

Commit 519c5d6

Browse files
committed
Rename IDs and update some docs
1 parent d5f7729 commit 519c5d6

File tree

11 files changed

+87
-87
lines changed

11 files changed

+87
-87
lines changed

docs/graph/beta/index.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
# Beta Graph API
22

33
!!! warning "Beta API"
4-
This is the new beta graph API. It provides enhanced capabilities for parallel execution, conditional branching, and complex workflows. The original graph API is still available and documented in the [main graph documentation](../../graph.md).
4+
This is the new beta graph API. It provides enhanced capabilities for parallel execution, conditional branching, and complex workflows.
5+
The original graph API is still available (and compatible for interop with the new beta API) and is documented in the [main graph documentation](../../graph.md).
56

67
## Overview
78

@@ -13,7 +14,7 @@ The beta graph API in `pydantic-graph` provides a powerful builder pattern for c
1314
- **Broadcast operations** for sending the same data to multiple parallel paths
1415
- **Join nodes and Reducers** for aggregating results from parallel execution
1516

16-
This API is designed for advanced workflows where you need explicit control over parallelism, routing, and data aggregation.
17+
This API is designed for advanced workflows where you want declarative control over parallelism, routing, and data aggregation.
1718

1819
## Installation
1920

pydantic_graph/pydantic_graph/beta/decision.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
from typing_extensions import Never, Self, TypeVar
1515

16-
from pydantic_graph.beta.id_types import ForkId, JoinId, NodeId
16+
from pydantic_graph.beta.id_types import ForkID, JoinID, NodeID
1717
from pydantic_graph.beta.paths import Path, PathBuilder
1818
from pydantic_graph.beta.step import StepFunction
1919
from pydantic_graph.beta.util import TypeOrTypeExpression
@@ -42,7 +42,7 @@ class Decision(Generic[StateT, DepsT, HandledT]):
4242
branches based on the input data type or custom matching logic.
4343
"""
4444

45-
id: NodeId
45+
id: NodeID
4646
"""Unique identifier for this decision node."""
4747

4848
branches: list[DecisionBranch[Any]]
@@ -145,7 +145,7 @@ class DecisionBranchBuilder(Generic[StateT, DepsT, OutputT, SourceT, HandledT]):
145145
"""Builder for the execution path."""
146146

147147
@property
148-
def last_fork_id(self) -> ForkId | None:
148+
def last_fork_id(self) -> ForkID | None:
149149
"""Get the ID of the last fork in the path.
150150
151151
Returns:
@@ -214,8 +214,8 @@ def transform(
214214
def spread(
215215
self: DecisionBranchBuilder[StateT, DepsT, Iterable[T], SourceT, HandledT],
216216
*,
217-
fork_id: ForkId | None = None,
218-
downstream_join_id: JoinId | None = None,
217+
fork_id: ForkID | None = None,
218+
downstream_join_id: JoinID | None = None,
219219
) -> DecisionBranchBuilder[StateT, DepsT, T, SourceT, HandledT]:
220220
"""Spread the branch's output.
221221

pydantic_graph/pydantic_graph/beta/graph.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from pydantic_graph import exceptions
2222
from pydantic_graph._utils import AbstractSpan, get_traceparent, logfire_span
2323
from pydantic_graph.beta.decision import Decision
24-
from pydantic_graph.beta.id_types import ForkStack, ForkStackItem, GraphRunId, JoinId, NodeId, NodeRunId, TaskId
24+
from pydantic_graph.beta.id_types import ForkStack, ForkStackItem, GraphRunID, JoinID, NodeID, NodeRunID, TaskID
2525
from pydantic_graph.beta.join import Join, JoinNode, Reducer
2626
from pydantic_graph.beta.node import (
2727
EndNode,
@@ -82,7 +82,7 @@ class JoinItem:
8282
node, along with metadata about which execution 'fork' it originated from.
8383
"""
8484

85-
join_id: JoinId
85+
join_id: JoinID
8686
"""The ID of the join node this item is targeting."""
8787

8888
inputs: Any
@@ -125,16 +125,16 @@ class Graph(Generic[StateT, DepsT, InputT, OutputT]):
125125
auto_instrument: bool
126126
"""Whether to automatically create instrumentation spans."""
127127

128-
nodes: dict[NodeId, AnyNode]
128+
nodes: dict[NodeID, AnyNode]
129129
"""All nodes in the graph indexed by their ID."""
130130

131-
edges_by_source: dict[NodeId, list[Path]]
131+
edges_by_source: dict[NodeID, list[Path]]
132132
"""Outgoing paths from each source node."""
133133

134-
parent_forks: dict[JoinId, ParentFork[NodeId]]
134+
parent_forks: dict[JoinID, ParentFork[NodeID]]
135135
"""Parent fork information for each join node."""
136136

137-
def get_parent_fork(self, join_id: JoinId) -> ParentFork[NodeId]:
137+
def get_parent_fork(self, join_id: JoinID) -> ParentFork[NodeID]:
138138
"""Get the parent fork information for a join node.
139139
140140
Args:
@@ -288,7 +288,7 @@ class GraphTask:
288288
"""
289289

290290
# With our current BaseNode thing, next_node_id and next_node_inputs are merged into `next_node` itself
291-
node_id: NodeId
291+
node_id: NodeID
292292
"""The ID of the node to execute."""
293293

294294
inputs: Any
@@ -300,7 +300,7 @@ class GraphTask:
300300
Used by the GraphRun to decide when to proceed through joins.
301301
"""
302302

303-
task_id: TaskId = field(default_factory=lambda: TaskId(str(uuid.uuid4())))
303+
task_id: TaskID = field(default_factory=lambda: TaskID(str(uuid.uuid4())))
304304
"""Unique identifier for this task."""
305305

306306

@@ -346,14 +346,14 @@ def __init__(
346346
self.inputs = inputs
347347
"""The initial input data."""
348348

349-
self._active_reducers: dict[tuple[JoinId, NodeRunId], tuple[Reducer[Any, Any, Any, Any], ForkStack]] = {}
349+
self._active_reducers: dict[tuple[JoinID, NodeRunID], tuple[Reducer[Any, Any, Any, Any], ForkStack]] = {}
350350
"""Active reducers for join operations."""
351351

352352
self._next: EndMarker[OutputT] | JoinItem | Sequence[GraphTask] | None = None
353353
"""The next item to be processed."""
354354

355-
run_id = GraphRunId(str(uuid.uuid4()))
356-
initial_fork_stack: ForkStack = (ForkStackItem(StartNode.id, NodeRunId(run_id), 0),)
355+
run_id = GraphRunID(str(uuid.uuid4()))
356+
initial_fork_stack: ForkStack = (ForkStackItem(StartNode.id, NodeRunID(run_id), 0),)
357357
self._first_task = GraphTask(node_id=StartNode.id, inputs=inputs, fork_stack=initial_fork_stack)
358358
self._iterator = self._iter_graph()
359359

@@ -446,7 +446,7 @@ async def _iter_graph( # noqa C901
446446
) -> AsyncGenerator[
447447
EndMarker[OutputT] | JoinItem | Sequence[GraphTask], EndMarker[OutputT] | JoinItem | Sequence[GraphTask]
448448
]:
449-
tasks_by_id: dict[TaskId, GraphTask] = {}
449+
tasks_by_id: dict[TaskID, GraphTask] = {}
450450
pending: set[asyncio.Task[EndMarker[OutputT] | JoinItem | Sequence[GraphTask]]] = set()
451451

452452
def _start_task(t_: GraphTask) -> None:
@@ -490,7 +490,7 @@ def _handle_result(result: EndMarker[OutputT] | JoinItem | Sequence[GraphTask])
490490
reducer.reduce(StepContext(self.state, self.deps, result.inputs))
491491
except StopIteration:
492492
# cancel all concurrently running tasks with the same fork_run_id of the parent fork
493-
task_ids_to_cancel = set[TaskId]()
493+
task_ids_to_cancel = set[TaskID]()
494494
for task_id, t in tasks_by_id.items():
495495
for item in t.fork_stack:
496496
if item.fork_id == parent_fork_id and item.node_run_id == fork_run_id:
@@ -510,7 +510,7 @@ def _handle_result(result: EndMarker[OutputT] | JoinItem | Sequence[GraphTask])
510510
done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
511511
for task in done:
512512
task_result = task.result()
513-
source_task = tasks_by_id.pop(TaskId(task.get_name()))
513+
source_task = tasks_by_id.pop(TaskID(task.get_name()))
514514
maybe_overridden_result = yield task_result
515515
if _handle_result(maybe_overridden_result):
516516
return
@@ -632,8 +632,8 @@ def _get_completed_fork_runs(
632632
self,
633633
t: GraphTask,
634634
active_tasks: Iterable[GraphTask],
635-
) -> list[tuple[JoinId, NodeRunId]]:
636-
completed_fork_runs: list[tuple[JoinId, NodeRunId]] = []
635+
) -> list[tuple[JoinID, NodeRunID]]:
636+
completed_fork_runs: list[tuple[JoinID, NodeRunID]] = []
637637

638638
fork_run_indices = {fsi.node_run_id: i for i, fsi in enumerate(t.fork_stack)}
639639
for join_id, fork_run_id in self._active_reducers.keys():
@@ -661,7 +661,7 @@ def _handle_path(self, path: Path, inputs: Any, fork_stack: ForkStack) -> Sequen
661661
except TypeError:
662662
raise RuntimeError(f'Cannot spread non-iterable value: {inputs!r}')
663663

664-
node_run_id = NodeRunId(str(uuid.uuid4()))
664+
node_run_id = NodeRunID(str(uuid.uuid4()))
665665

666666
# If the spread specifies a downstream join id, eagerly create a reducer for it
667667
if item.downstream_join_id is not None:
@@ -698,7 +698,7 @@ def _handle_edges(self, node: AnyNode, inputs: Any, fork_stack: ForkStack) -> Se
698698
new_tasks.extend(self._handle_path(path, inputs, fork_stack))
699699
return new_tasks
700700

701-
def _is_fork_run_completed(self, tasks: Iterable[GraphTask], join_id: JoinId, fork_run_id: NodeRunId) -> bool:
701+
def _is_fork_run_completed(self, tasks: Iterable[GraphTask], join_id: JoinID, fork_run_id: NodeRunID) -> bool:
702702
# Check if any of the tasks in the graph have this fork_run_id in their fork_stack
703703
# If this is the case, then the fork run is not yet completed
704704
parent_fork = self.graph.get_parent_fork(join_id)

pydantic_graph/pydantic_graph/beta/graph_builder.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from pydantic_graph import _utils, exceptions
2020
from pydantic_graph.beta.decision import Decision, DecisionBranch, DecisionBranchBuilder
2121
from pydantic_graph.beta.graph import Graph
22-
from pydantic_graph.beta.id_types import ForkId, JoinId, NodeId
22+
from pydantic_graph.beta.id_types import ForkID, JoinID, NodeID
2323
from pydantic_graph.beta.join import Join, JoinNode, Reducer
2424
from pydantic_graph.beta.node import (
2525
EndNode,
@@ -99,7 +99,7 @@ def decorator(
9999
node_id = node_id or get_callable_name(reducer_type)
100100

101101
return Join[StateT, DepsT, Any, Any](
102-
id=JoinId(NodeId(node_id)),
102+
id=JoinID(NodeID(node_id)),
103103
reducer_type=reducer_type,
104104
)
105105

@@ -137,10 +137,10 @@ class GraphBuilder(Generic[StateT, DepsT, GraphInputT, GraphOutputT]):
137137
auto_instrument: bool
138138
"""Whether to automatically create instrumentation spans."""
139139

140-
_nodes: dict[NodeId, AnyNode]
140+
_nodes: dict[NodeID, AnyNode]
141141
"""Internal storage for nodes in the graph."""
142142

143-
_edges_by_source: dict[NodeId, list[Path]]
143+
_edges_by_source: dict[NodeID, list[Path]]
144144
"""Internal storage for edges by source node."""
145145

146146
_decision_index: int
@@ -253,7 +253,7 @@ def decorator(
253253

254254
node_id = node_id or get_callable_name(call)
255255

256-
step = Step[StateT, DepsT, InputT, OutputT](id=NodeId(node_id), call=call, user_label=label)
256+
step = Step[StateT, DepsT, InputT, OutputT](id=NodeID(node_id), call=call, user_label=label)
257257

258258
return step
259259

@@ -414,8 +414,8 @@ def add_spreading_edge(
414414
*,
415415
pre_spread_label: str | None = None,
416416
post_spread_label: str | None = None,
417-
fork_id: ForkId | None = None,
418-
downstream_join_id: JoinId | None = None,
417+
fork_id: ForkID | None = None,
418+
downstream_join_id: JoinID | None = None,
419419
) -> None:
420420
"""Add an edge that spreads iterable data across parallel paths.
421421
@@ -461,7 +461,7 @@ def decision(self, *, note: str | None = None) -> Decision[StateT, DepsT, Never]
461461
Returns:
462462
A new Decision node with no branches
463463
"""
464-
return Decision(id=NodeId(self._get_new_decision_id()), branches=[], note=note)
464+
return Decision(id=NodeID(self._get_new_decision_id()), branches=[], note=note)
465465

466466
def match(
467467
self,
@@ -478,7 +478,7 @@ def match(
478478
Returns:
479479
A DecisionBranchBuilder for constructing the branch
480480
"""
481-
node_id = NodeId(self._get_new_decision_id())
481+
node_id = NodeID(self._get_new_decision_id())
482482
decision = Decision[StateT, DepsT, Never](node_id, branches=[], note=None)
483483
new_path_builder = PathBuilder[StateT, DepsT, SourceT](working_items=[])
484484
return DecisionBranchBuilder(decision=decision, source=source, matches=matches, path_builder=new_path_builder)
@@ -721,8 +721,8 @@ def build(self) -> Graph[StateT, DepsT, GraphInputT, GraphOutputT]:
721721

722722

723723
def _normalize_forks(
724-
nodes: dict[NodeId, AnyNode], edges: dict[NodeId, list[Path]]
725-
) -> tuple[dict[NodeId, AnyNode], dict[NodeId, list[Path]]]:
724+
nodes: dict[NodeID, AnyNode], edges: dict[NodeID, list[Path]]
725+
) -> tuple[dict[NodeID, AnyNode], dict[NodeID, list[Path]]]:
726726
"""Normalize the graph structure so only broadcast forks have multiple outgoing edges.
727727
728728
This function ensures that any node with multiple outgoing edges is converted
@@ -736,7 +736,7 @@ def _normalize_forks(
736736
A tuple of normalized nodes and edges
737737
"""
738738
new_nodes = nodes.copy()
739-
new_edges: dict[NodeId, list[Path]] = {}
739+
new_edges: dict[NodeID, list[Path]] = {}
740740

741741
paths_to_handle: list[Path] = []
742742

@@ -750,7 +750,7 @@ def _normalize_forks(
750750
if len(edges_from_source) == 1:
751751
new_edges[source_id] = edges_from_source
752752
continue
753-
new_fork = Fork[Any, Any](id=ForkId(NodeId(f'{node.id}_broadcast_fork')), is_spread=False)
753+
new_fork = Fork[Any, Any](id=ForkID(NodeID(f'{node.id}_broadcast_fork')), is_spread=False)
754754
new_nodes[new_fork.id] = new_fork
755755
new_edges[source_id] = [Path(items=[BroadcastMarker(fork_id=new_fork.id, paths=edges_from_source)])]
756756
new_edges[new_fork.id] = edges_from_source
@@ -772,8 +772,8 @@ def _normalize_forks(
772772

773773

774774
def _collect_dominating_forks(
775-
graph_nodes: dict[NodeId, AnyNode], graph_edges_by_source: dict[NodeId, list[Path]]
776-
) -> dict[JoinId, ParentFork[NodeId]]:
775+
graph_nodes: dict[NodeID, AnyNode], graph_edges_by_source: dict[NodeID, list[Path]]
776+
) -> dict[JoinID, ParentFork[NodeID]]:
777777
"""Find the dominating fork for each join node in the graph.
778778
779779
This function analyzes the graph structure to find the parent fork that
@@ -791,10 +791,10 @@ def _collect_dominating_forks(
791791
ValueError: If any join node lacks a dominating fork
792792
"""
793793
nodes = set(graph_nodes)
794-
start_ids: set[NodeId] = {StartNode.id}
795-
edges: dict[NodeId, list[NodeId]] = defaultdict(list)
794+
start_ids: set[NodeID] = {StartNode.id}
795+
edges: dict[NodeID, list[NodeID]] = defaultdict(list)
796796

797-
fork_ids: set[NodeId] = set(start_ids)
797+
fork_ids: set[NodeID] = set(start_ids)
798798
for source_id in nodes:
799799
working_source_id = source_id
800800
node = graph_nodes.get(source_id)
@@ -803,7 +803,7 @@ def _collect_dominating_forks(
803803
fork_ids.add(node.id)
804804
continue
805805

806-
def _handle_path(path: Path, last_source_id: NodeId):
806+
def _handle_path(path: Path, last_source_id: NodeID):
807807
"""Process a path and collect edges and fork information.
808808
809809
Args:
@@ -840,7 +840,7 @@ def _handle_path(path: Path, last_source_id: NodeId):
840840
)
841841

842842
join_ids = {node.id for node in graph_nodes.values() if isinstance(node, Join)}
843-
dominating_forks: dict[JoinId, ParentFork[NodeId]] = {}
843+
dominating_forks: dict[JoinID, ParentFork[NodeID]] = {}
844844
for join_id in join_ids:
845845
dominating_fork = finder.find_parent_fork(join_id)
846846
if dominating_fork is None:

pydantic_graph/pydantic_graph/beta/id_types.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,23 +9,23 @@
99
from dataclasses import dataclass
1010
from typing import NewType
1111

12-
NodeId = NewType('NodeId', str)
12+
NodeID = NewType('NodeID', str)
1313
"""Unique identifier for a node in the graph."""
1414

15-
NodeRunId = NewType('NodeRunId', str)
15+
NodeRunID = NewType('NodeRunID', str)
1616
"""Unique identifier for a specific execution instance of a node."""
1717

1818
# The following aliases are just included for clarity; making them NewTypes is a hassle
19-
JoinId = NodeId
19+
JoinID = NodeID
2020
"""Alias for NodeId when referring to join nodes."""
2121

22-
ForkId = NodeId
22+
ForkID = NodeID
2323
"""Alias for NodeId when referring to fork nodes."""
2424

25-
GraphRunId = NewType('GraphRunId', str)
25+
GraphRunID = NewType('GraphRunID', str)
2626
"""Unique identifier for a complete graph execution run."""
2727

28-
TaskId = NewType('TaskId', str)
28+
TaskID = NewType('TaskID', str)
2929
"""Unique identifier for a task within the graph execution."""
3030

3131

@@ -38,9 +38,9 @@ class ForkStackItem:
3838
and coordinate parallel branches of execution.
3939
"""
4040

41-
fork_id: ForkId
41+
fork_id: ForkID
4242
"""The ID of the node that created this fork."""
43-
node_run_id: NodeRunId
43+
node_run_id: NodeRunID
4444
"""The ID associated to the specific run of the node that created this fork."""
4545
thread_index: int
4646
"""The index of the execution "thread" created during the node run that created this fork.

pydantic_graph/pydantic_graph/beta/join.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from typing_extensions import TypeVar
1515

1616
from pydantic_graph import BaseNode, End, GraphRunContext
17-
from pydantic_graph.beta.id_types import ForkId, JoinId
17+
from pydantic_graph.beta.id_types import ForkID, JoinID
1818
from pydantic_graph.beta.step import StepContext
1919

2020
StateT = TypeVar('StateT', infer_variance=True)
@@ -206,7 +206,7 @@ class Join(Generic[StateT, DepsT, InputT, OutputT]):
206206
"""
207207

208208
def __init__(
209-
self, id: JoinId, reducer_type: type[Reducer[StateT, DepsT, InputT, OutputT]], joins: ForkId | None = None
209+
self, id: JoinID, reducer_type: type[Reducer[StateT, DepsT, InputT, OutputT]], joins: ForkID | None = None
210210
) -> None:
211211
"""Initialize a join operation.
212212

0 commit comments

Comments
 (0)