Skip to content

Commit f3c9f12

Browse files
committed
feat: add join edge and release notes
1 parent c4c3a8e commit f3c9f12

File tree

3 files changed

+186
-47
lines changed

3 files changed

+186
-47
lines changed

docs/user-guide/release-notes.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,16 @@ See below for all notable changes to the GraphAI library.
44

55
### Added
66
- Direct Starlette support for `GraphEvent` and `EventCallback` objects
7+
- Parallel node execution for graphs with multiple outgoing edges from a single node
8+
- Automatic concurrent execution when a node has multiple successors
9+
- State merging from parallel branches
10+
- Configurable through standard `add_edge()` calls
11+
- New `add_join()` method for explicit convergence of parallel branches
12+
- Synchronizes multiple parallel branches to a single destination node
13+
- Ensures convergence node executes only once with merged state
14+
- Prevents duplicate execution of downstream nodes
15+
- New `add_parallel()` convenience method for creating parallel branches
16+
- Syntactic sugar for adding multiple edges from one source to multiple destinations
717

818
## [0.0.9] - 2025-09-05
919

graphai/graph.py

Lines changed: 90 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import asyncio
33
from typing import Any, Iterable, Protocol
44
from graphlib import TopologicalSorter, CycleError
5+
56
from graphai.callback import Callback
67
from graphai.utils import logger
78

@@ -64,6 +65,7 @@ def __init__(
6465
self.edges: list[Any] = []
6566
self.start_node: NodeProtocol | None = None
6667
self.end_nodes: list[NodeProtocol] = []
68+
self.join_nodes: set[NodeProtocol] = set()
6769
self.Callback: type[Callback] = Callback
6870
self.max_steps = max_steps
6971
self.state = initial_state or {}
@@ -130,6 +132,18 @@ def add_node(self, node: NodeProtocol) -> Graph:
130132
self.end_nodes.append(node)
131133
return self
132134

135+
def _get_node(self, node_candidate: NodeProtocol | str) -> NodeProtocol:
136+
# first get node from graph
137+
if isinstance(node_candidate, str):
138+
node = self.nodes.get(node_candidate)
139+
else:
140+
# check if it's a node-like object by looking for required attributes
141+
if hasattr(node_candidate, "name"):
142+
node = self.nodes.get(node_candidate.name)
143+
if node is None:
144+
raise ValueError(f"Node with name '{node_candidate}' not found.")
145+
return node
146+
133147
def add_edge(
134148
self, source: NodeProtocol | str, destination: NodeProtocol | str
135149
) -> Graph:
@@ -141,33 +155,10 @@ def add_edge(
141155
"""
142156
source_node, destination_node = None, None
143157
# get source node from graph
144-
source_name: str
145-
if isinstance(source, str):
146-
source_node = self.nodes.get(source)
147-
source_name = source
148-
else:
149-
# Check if it's a node-like object by looking for required attributes
150-
if hasattr(source, "name"):
151-
source_node = self.nodes.get(source.name)
152-
source_name = source.name
153-
else:
154-
source_name = str(source)
155-
if source_node is None:
156-
raise ValueError(f"Node with name '{source_name}' not found.")
158+
source_node = self._get_node(node_candidate=source)
157159
# get destination node from graph
158-
destination_name: str
159-
if isinstance(destination, str):
160-
destination_node = self.nodes.get(destination)
161-
destination_name = destination
162-
else:
163-
# Check if it's a node-like object by looking for required attributes
164-
if hasattr(destination, "name"):
165-
destination_node = self.nodes.get(destination.name)
166-
destination_name = destination.name
167-
else:
168-
destination_name = str(destination)
169-
if destination_node is None:
170-
raise ValueError(f"Node with name '{destination_name}' not found.")
160+
destination_node = self._get_node(node_candidate=destination)
161+
# create edge
171162
edge = Edge(source_node, destination_node)
172163
self.edges.append(edge)
173164
return self
@@ -214,7 +205,6 @@ def compile(self, *, strict: bool = False) -> Graph:
214205
nodes = getattr(self, "nodes", None)
215206
if not isinstance(nodes, dict) or not nodes:
216207
raise GraphCompileError("No nodes have been added to the graph")
217-
218208
start_name: str | None = None
219209
# Bind and narrow the attribute for mypy
220210
start_node: _HasName | None = getattr(self, "start_node", None)
@@ -230,21 +220,17 @@ def compile(self, *, strict: bool = False) -> Graph:
230220
raise GraphCompileError(f"Multiple start nodes defined: {starts}")
231221
if len(starts) == 1:
232222
start_name = starts[0]
233-
234223
if not start_name:
235224
raise GraphCompileError("No start node defined")
236-
237225
# at least one end node
238226
if not any(
239227
getattr(n, "is_end", False) or getattr(n, "end", False)
240228
for n in nodes.values()
241229
):
242230
raise GraphCompileError("No end node defined")
243-
244231
# normalize edges into adjacency {src: set(dst)}
245232
raw_edges = getattr(self, "edges", None)
246233
adj: dict[str, set[str]] = {name: set() for name in nodes.keys()}
247-
248234
def _add_edge(src: str, dst: str) -> None:
249235
if src not in nodes:
250236
raise GraphCompileError(f"Edge references unknown source node: {src}")
@@ -253,7 +239,6 @@ def _add_edge(src: str, dst: str) -> None:
253239
f"Edge from {src} references unknown node(s): ['{dst}']"
254240
)
255241
adj[src].add(dst)
256-
257242
if raw_edges is None:
258243
pass
259244
elif isinstance(raw_edges, dict):
@@ -273,13 +258,11 @@ def _add_edge(src: str, dst: str) -> None:
273258
iterator = iter(raw_edges)
274259
except TypeError:
275260
raise GraphCompileError("Internal edge map has unsupported type")
276-
277261
for item in iterator:
278262
# (src, dst) OR (src, Iterable[dst])
279263
if isinstance(item, (tuple, list)) and len(item) == 2:
280264
raw_src, rhs = item
281265
src = _require_name(raw_src, "source")
282-
283266
if isinstance(rhs, str) or getattr(rhs, "name", None):
284267
dst = _require_name(rhs, "destination")
285268
_add_edge(src, rhs)
@@ -294,7 +277,6 @@ def _add_edge(src: str, dst: str) -> None:
294277
"Edge tuple second item must be a destination or an iterable of destinations"
295278
)
296279
continue
297-
298280
# Mapping-style: {"source": "...", "destination": "..."} or {"src": "...", "dst": "..."}
299281
if isinstance(item, dict):
300282
src = _require_name(item.get("source", item.get("src")), "source")
@@ -303,7 +285,6 @@ def _add_edge(src: str, dst: str) -> None:
303285
)
304286
_add_edge(src, dst)
305287
continue
306-
307288
# Object with attributes .source/.destination (or .src/.dst)
308289
if hasattr(item, "source") or hasattr(item, "src"):
309290
src = _require_name(
@@ -315,13 +296,11 @@ def _add_edge(src: str, dst: str) -> None:
315296
)
316297
_add_edge(src, dst)
317298
continue
318-
319299
# If none matched, this is an unsupported edge record
320300
raise GraphCompileError(
321301
"Edges must be dict[str, Iterable[str]] or an iterable of (src, dst), "
322302
"(src, Iterable[dst]), mapping{'source'/'destination'}, or objects with .source/.destination"
323303
)
324-
325304
# reachability from start
326305
seen: set[str] = set()
327306
stack = [start_name]
@@ -331,11 +310,9 @@ def _add_edge(src: str, dst: str) -> None:
331310
continue
332311
seen.add(cur)
333312
stack.extend(adj.get(cur, ()))
334-
335313
unreachable = sorted(set(nodes.keys()) - seen)
336314
if unreachable:
337315
raise GraphCompileError(f"Unreachable nodes: {unreachable}")
338-
339316
# optional cycle detection (strict mode)
340317
if strict:
341318
preds: dict[str, set[str]] = {n: set() for n in nodes.keys()}
@@ -346,7 +323,6 @@ def _add_edge(src: str, dst: str) -> None:
346323
list(TopologicalSorter(preds).static_order())
347324
except CycleError as e:
348325
raise GraphCompileError("cycle detected in graph (strict mode)") from e
349-
350326
return self
351327

352328
def _validate_output(self, output: dict[str, Any], node_name: str):
@@ -358,7 +334,13 @@ def _validate_output(self, output: dict[str, Any], node_name: str):
358334

359335
def _get_next_nodes(self, current_node: NodeProtocol) -> list[NodeProtocol]:
360336
"""Return all successor nodes for the given node."""
361-
return [edge.destination for edge in self.edges if edge.source == current_node]
337+
# we skip JoinEdge because they don't have regular destinations
338+
# and next nodes for those are handled in the execute method
339+
return [
340+
edge.destination
341+
for edge in self.edges
342+
if isinstance(edge, Edge) and edge.source == current_node
343+
]
362344

363345
async def _invoke_node(
364346
self, node: NodeProtocol, state: dict[str, Any], callback: Callback
@@ -379,6 +361,7 @@ async def _execute_branch(
379361
state: dict[str, Any],
380362
callback: Callback,
381363
steps: int,
364+
stop_at_join: bool = False,
382365
):
383366
"""Recursively execute a branch starting from `current_node`.
384367
When a node has multiple successors, run them concurrently and merge their outputs."""
@@ -392,6 +375,9 @@ async def _execute_branch(
392375
del output["choice"]
393376
current_node = self._get_node_by_name(node_name=next_node_name)
394377
continue
378+
if stop_at_join and current_node in self.join_nodes:
379+
# for parallel branches, wait at JoinEdge until all branches are complete
380+
return state
395381

396382
next_nodes = self._get_next_nodes(current_node)
397383
if not next_nodes:
@@ -404,17 +390,43 @@ async def _execute_branch(
404390
# Run each branch concurrently
405391
results = await asyncio.gather(
406392
*[
407-
self._execute_branch(n, state.copy(), callback, steps + 1)
393+
self._execute_branch(
394+
current_node=n,
395+
state=state.copy(),
396+
callback=callback,
397+
steps=steps + 1,
398+
stop_at_join=True, # force parallel branches to wait at JoinEdge
399+
)
408400
for n in next_nodes
409401
]
410402
)
403+
# merge states returned by each branch
411404
merged = state.copy()
412405
for res in results:
413-
# merge states returned by each branch
414406
for k, v in res.items():
415407
if k != "callback":
416408
merged[k] = v
417-
return merged
409+
if set(next_nodes) & self.join_nodes:
410+
# if any of the next nodes are join nodes, we need to continue from the
411+
# JoinEdge.destination node
412+
join_edge = next(
413+
(
414+
e for e in self.edges if isinstance(e, JoinEdge)
415+
and any(n in e.sources for n in next_nodes)
416+
),
417+
None
418+
)
419+
if not join_edge:
420+
raise Exception("No JoinEdge found for next_nodes")
421+
# set current_node (for next iteration) to the JoinEdge.destination
422+
current_node = join_edge.destination
423+
# continue to the destination node with our merged state
424+
state = merged
425+
continue
426+
else:
427+
# if this happens we have multiple branches that do not join so we
428+
# can just return the merged states
429+
return merged
418430
steps += 1
419431
if steps >= self.max_steps:
420432
raise Exception(
@@ -502,20 +514,46 @@ def _get_node_by_name(self, node_name: str) -> NodeProtocol:
502514

503515
def _get_next_node(self, current_node):
504516
for edge in self.edges:
505-
if edge.source == current_node:
517+
if isinstance(edge, Edge) and edge.source == current_node:
506518
return edge.destination
519+
# we skip JoinEdge because they don't have regular destinations
520+
# and next nodes for those are handled in the execute method
507521
raise Exception(
508522
f"No outgoing edge found for current node '{current_node.name}'."
509523
)
510524

511525
def add_parallel(
512526
self, source: NodeProtocol | str, destinations: list[NodeProtocol | str]
513527
):
514-
"""Add multiple outgoing edges from a single source node to be executed in parallel."""
528+
"""Add multiple outgoing edges from a single source node to be executed in parallel.
529+
530+
Args:
531+
source: The source node for the parallel branches.
532+
destinations: The list of destination nodes for the parallel branches.
533+
"""
515534
for dest in destinations:
516535
self.add_edge(source, dest)
517536
return self
518537

538+
def add_join(
539+
self, sources: list[NodeProtocol | str], destination: NodeProtocol | str
540+
):
541+
"""Joins multiple parallel branches into a single branch.
542+
543+
Args:
544+
sources: The list of source nodes for the join.
545+
destination: The destination node for the join.
546+
"""
547+
# get source nodes from graph
548+
source_nodes = [self._get_node(node_candidate=source) for source in sources]
549+
# get destination node from graph
550+
destination_node = self._get_node(node_candidate=destination)
551+
# create join edge
552+
edge = JoinEdge(source_nodes, destination_node)
553+
self.edges.append(edge)
554+
self.join_nodes.update(source_nodes)
555+
return self
556+
519557
def visualize(self, *, save_path: str | None = None):
520558
"""Render the current graph. If matplotlib is not installed,
521559
raise a helpful error telling users to install the viz extra.
@@ -611,3 +649,8 @@ class Edge:
611649
def __init__(self, source, destination):
612650
self.source = source
613651
self.destination = destination
652+
653+
class JoinEdge:
654+
def __init__(self, sources, destination):
655+
self.sources = sources
656+
self.destination = destination

0 commit comments

Comments
 (0)