44from graphlib import TopologicalSorter , CycleError
55from graphai .callback import Callback
66from graphai .utils import logger
7- import asyncio
87
98
109# to fix mypy error
@@ -71,16 +70,34 @@ def __init__(
7170
7271 # Allow getting and setting the graph's internal state
7372 def get_state (self ) -> dict [str , Any ]:
74- """Get the current graph state."""
73+ """Get the current graph state.
74+
75+ Returns:
76+ The current graph state.
77+ """
7578 return self .state
7679
7780 def set_state (self , state : dict [str , Any ]) -> Graph :
78- """Set the graph state."""
81+ """Set the graph state.
82+
83+ Args:
84+ state: The new state to set for the graph.
85+
86+ Returns:
87+ The graph instance.
88+ """
7989 self .state = state
8090 return self
8191
8292 def update_state (self , values : dict [str , Any ]) -> Graph :
83- """Update the graph state with new values."""
93+ """Update the graph state with new values.
94+
95+ Args:
96+ values: The new values to update the graph state with.
97+
98+ Returns:
99+ The graph instance.
100+ """
84101 self .state .update (values )
85102 return self
86103
@@ -90,6 +107,14 @@ def reset_state(self) -> Graph:
90107 return self
91108
92109 def add_node (self , node : NodeProtocol ) -> Graph :
110+ """Adds a node to the graph.
111+
112+ Args:
113+ node: The node to add to the graph.
114+
115+ Raises:
116+ Exception: If a node with the same name already exists in the graph.
117+ """
93118 if node .name in self .nodes :
94119 raise Exception (f"Node with name '{ node .name } ' already exists." )
95120 self .nodes [node .name ] = node
@@ -153,6 +178,14 @@ def add_router(
153178 router : NodeProtocol ,
154179 destinations : list [NodeProtocol ],
155180 ) -> Graph :
181+ """Adds a router node, allowing for a decision to be made on which branch to
182+ follow based on the `choice` output of the router node.
183+
184+ Args:
185+ sources: The list of source nodes for the router.
186+ router: The router node.
187+ destinations: The list of destination nodes for the router.
188+ """
156189 if not router .is_router :
157190 raise TypeError ("A router object must be passed to the router parameter." )
158191 [self .add_edge (source , router ) for source in sources ]
@@ -169,8 +202,7 @@ def set_end_node(self, node: NodeProtocol) -> Graph:
169202 return self
170203
171204 def compile (self , * , strict : bool = False ) -> Graph :
172- """
173- Validate the graph:
205+ """Validate the graph:
174206 - exactly one start node present (or Graph.start_node set)
175207 - at least one end node present
176208 - all edges reference known nodes
0 commit comments