Skip to content

Commit e100f6a

Browse files
authored
Merge pull request #23 from aurelio-labs/feat/str-nodes
feat: str nodes
2 parents bc68b4d + dbd0645 commit e100f6a

File tree

3 files changed

+89
-15
lines changed

3 files changed

+89
-15
lines changed

graphai/graph.py

Lines changed: 41 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
class Graph:
88
def __init__(self, max_steps: int = 10):
9-
self.nodes = []
9+
self.nodes = {}
1010
self.edges = []
1111
self.start_node = None
1212
self.end_nodes = []
@@ -15,7 +15,9 @@ def __init__(self, max_steps: int = 10):
1515
self.max_steps = max_steps
1616

1717
def add_node(self, node):
18-
self.nodes.append(node)
18+
if node.name in self.nodes:
19+
raise Exception(f"Node with name '{node.name}' already exists.")
20+
self.nodes[node.name] = node
1921
if node.is_start:
2022
if self.start_node is not None:
2123
raise Exception(
@@ -27,10 +29,37 @@ def add_node(self, node):
2729
if node.is_end:
2830
self.end_nodes.append(node)
2931

30-
def add_edge(self, source: _Node, destination: _Node):
31-
# TODO add logic to check that source and destination are nodes
32-
# and they exist in the graph object already
33-
edge = Edge(source, destination)
32+
def add_edge(self, source: _Node | str, destination: _Node | str):
33+
"""Adds an edge between two nodes that already exist in the graph.
34+
35+
Args:
36+
source: The source node or its name.
37+
destination: The destination node or its name.
38+
"""
39+
source_node, destination_node = None, None
40+
# get source node from graph
41+
if isinstance(source, str):
42+
source_node = self.nodes.get(source)
43+
else:
44+
# Check if it's a node-like object by looking for required attributes
45+
if hasattr(source, 'name'):
46+
source_node = self.nodes.get(source.name)
47+
if source_node is None:
48+
raise ValueError(
49+
f"Node with name '{source.name if hasattr(source, 'name') else source}' not found."
50+
)
51+
# get destination node from graph
52+
if isinstance(destination, str):
53+
destination_node = self.nodes.get(destination)
54+
else:
55+
# Check if it's a node-like object by looking for required attributes
56+
if hasattr(destination, 'name'):
57+
destination_node = self.nodes.get(destination.name)
58+
if destination_node is None:
59+
raise ValueError(
60+
f"Node with name '{destination.name if hasattr(destination, 'name') else destination}' not found."
61+
)
62+
edge = Edge(source_node, destination_node)
3463
self.edges.append(edge)
3564

3665
def add_router(self, sources: list[_Node], router: _Node, destinations: List[_Node]):
@@ -139,7 +168,7 @@ def visualize(self):
139168

140169
G = nx.DiGraph()
141170

142-
for node in self.nodes:
171+
for node in self.nodes.values():
143172
G.add_node(node.name)
144173

145174
for edge in self.edges:
@@ -173,10 +202,11 @@ def visualize(self):
173202
pos[node] = (pos[node][0] - x_center, pos[node][1])
174203

175204
# Scale the layout
176-
max_x = max(abs(p[0]) for p in pos.values())
177-
max_y = max(abs(p[1]) for p in pos.values())
178-
scale = min(0.8 / max_x, 0.8 / max_y)
179-
pos = {node: (x * scale, y * scale) for node, (x, y) in pos.items()}
205+
max_x = max(abs(p[0]) for p in pos.values()) if pos else 1
206+
max_y = max(abs(p[1]) for p in pos.values()) if pos else 1
207+
if max_x > 0 and max_y > 0:
208+
scale = min(0.8 / max_x, 0.8 / max_y)
209+
pos = {node: (x * scale, y * scale) for node, (x, y) in pos.items()}
180210

181211
else:
182212
print("Warning: The graph contains cycles. Visualization will use a spring layout.")

graphai/nodes/base.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def _node(
3131
start: bool = False,
3232
end: bool = False,
3333
stream: bool = False,
34+
name: str | None = None,
3435
) -> Callable:
3536
"""Decorator validating node structure.
3637
"""
@@ -117,7 +118,7 @@ async def invoke(cls, input: Dict[str, Any], callback: Optional[Callback] = None
117118
return out
118119

119120
NodeClass.__name__ = func.__name__
120-
NodeClass.name = func.__name__
121+
NodeClass.name = name or func.__name__
121122
NodeClass.__doc__ = func.__doc__
122123
NodeClass.is_start = start
123124
NodeClass.is_end = end
@@ -132,14 +133,15 @@ def __call__(
132133
start: bool = False,
133134
end: bool = False,
134135
stream: bool = False,
136+
name: str | None = None,
135137
):
136138
# We must wrap the call to the decorator in a function for it to work
137139
# correctly with or without parenthesis
138-
def wrap(func: Callable, start=start, end=end, stream=stream) -> Callable:
139-
return self._node(func=func, start=start, end=end, stream=stream)
140+
def wrap(func: Callable, start=start, end=end, stream=stream, name=name) -> Callable:
141+
return self._node(func=func, start=start, end=end, stream=stream, name=name)
140142
if func:
141143
# Decorator is called without parenthesis
142-
return wrap(func=func, start=start, end=end, stream=stream)
144+
return wrap(func=func, start=start, end=end, stream=stream, name=name)
143145
# Decorator is called with parenthesis
144146
return wrap
145147

tests/unit/test_graph.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,45 @@ async def node_end(input: str):
4545

4646
response = await graph.execute(input={"input": "ABC"})
4747
assert response == {"input": "ABCDEFG"}
48+
49+
@pytest.mark.asyncio
50+
async def test_graph_with_name(self):
51+
@node(start=True, name="start")
52+
async def node_start(input: str):
53+
"""Start node"""
54+
return {"input": input+"D"}
55+
56+
@node(name="a")
57+
async def node_a(input: str):
58+
"""Node A"""
59+
return {"input": input+"E"}
60+
61+
assert node_a.name == "a"
62+
63+
@node(name="b")
64+
async def node_b(input: str):
65+
"""Node B"""
66+
return {"input": input+"F"}
67+
68+
@node(end=True, name="end")
69+
async def node_end(input: str):
70+
"""End node"""
71+
return {"input": input+"G"}
72+
73+
graph = Graph()
74+
75+
nodes = [
76+
(node_start, "start"),
77+
(node_a, "a"),
78+
(node_b, "b"),
79+
(node_end, "end")
80+
]
81+
82+
for i, (node_fn, name) in enumerate(nodes):
83+
graph.add_node(node_fn)
84+
if i > 0:
85+
# add each edge using the name only
86+
graph.add_edge(nodes[i-1][1], name)
87+
88+
response = await graph.execute(input={"input": "ABC"})
89+
assert response == {"input": "ABCDEFG"}

0 commit comments

Comments
 (0)