22Module for creating the base graphs
33 """
44import time
5+ import warnings
56from langchain_community .callbacks import get_openai_callback
67
78
@@ -10,31 +11,37 @@ class BaseGraph:
1011 BaseGraph manages the execution flow of a graph composed of interconnected nodes.
1112
1213 Attributes:
13- nodes (dict ): A dictionary mapping each node's name to its corresponding node instance.
14- edges (dict ): A dictionary representing the directed edges of the graph where each
14+ nodes (list ): A dictionary mapping each node's name to its corresponding node instance.
15+ edges (list ): A dictionary representing the directed edges of the graph where each
1516 key-value pair corresponds to the from-node and to-node relationship.
1617 entry_point (str): The name of the entry point node from which the graph execution begins.
1718
1819 Methods:
19- execute(initial_state): Executes the graph's nodes starting from the entry point and
20+ execute(initial_state): Executes the graph's nodes starting from the entry point and
2021 traverses the graph based on the provided initial state.
2122
2223 Args:
2324 nodes (iterable): An iterable of node instances that will be part of the graph.
24- edges (iterable): An iterable of tuples where each tuple represents a directed edge
25+ edges (iterable): An iterable of tuples where each tuple represents a directed edge
2526 in the graph, defined by a pair of nodes (from_node, to_node).
2627 entry_point (BaseNode): The node instance that represents the entry point of the graph.
2728 """
2829
29- def __init__ (self , nodes : dict , edges : dict , entry_point : str ):
30+ def __init__ (self , nodes : list , edges : list , entry_point : str ):
3031 """
3132 Initializes the graph with nodes, edges, and the entry point.
3233 """
33- self .nodes = {node .node_name : node for node in nodes }
34- self .edges = self ._create_edges (edges )
34+
35+ self .nodes = nodes
36+ self .edges = self ._create_edges ({e for e in edges })
3537 self .entry_point = entry_point .node_name
3638
37- def _create_edges (self , edges : dict ) -> dict :
39+ if nodes [0 ].node_name != entry_point .node_name :
40+ # raise a warning if the entry point is not the first node in the list
41+ warnings .warn (
42+ "Careful! The entry point node is different from the first node if the graph." )
43+
44+ def _create_edges (self , edges : list ) -> dict :
3845 """
3946 Helper method to create a dictionary of edges from the given iterable of tuples.
4047
@@ -51,8 +58,8 @@ def _create_edges(self, edges: dict) -> dict:
5158
5259 def execute (self , initial_state : dict ) -> dict :
5360 """
54- Executes the graph by traversing nodes starting from the entry point. The execution
55- follows the edges based on the result of each node's execution and continues until
61+ Executes the graph by traversing nodes starting from the entry point. The execution
62+ follows the edges based on the result of each node's execution and continues until
5663 it reaches a node with no outgoing edges.
5764
5865 Args:
@@ -61,7 +68,8 @@ def execute(self, initial_state: dict) -> dict:
6168 Returns:
6269 dict: The state after execution has completed, which may have been altered by the nodes.
6370 """
64- current_node_name = self .entry_point
71+ print (self .nodes )
72+ current_node_name = self .nodes [0 ]
6573 state = initial_state
6674
6775 # variables for tracking execution info
@@ -75,10 +83,10 @@ def execute(self, initial_state: dict) -> dict:
7583 "total_cost_USD" : 0.0 ,
7684 }
7785
78- while current_node_name is not None :
86+ for index in self . nodes :
7987
8088 curr_time = time .time ()
81- current_node = self . nodes [ current_node_name ]
89+ current_node = index
8290
8391 with get_openai_callback () as cb :
8492 result = current_node .execute (state )
0 commit comments