66import warnings
77from langchain_community .callbacks import get_openai_callback
88from typing import Tuple
9- from collections import deque
109
1110
1211class BaseGraph :
@@ -27,8 +26,6 @@ class BaseGraph:
2726
2827 Raises:
2928 Warning: If the entry point node is not the first node in the list.
30- ValueError: If conditional_node does not have exactly two outgoing edges
31-
3229
3330 Example:
3431 >>> BaseGraph(
@@ -51,7 +48,7 @@ def __init__(self, nodes: list, edges: list, entry_point: str):
5148
5249 self .nodes = nodes
5350 self .edges = self ._create_edges ({e for e in edges })
54- self .entry_point = entry_point
51+ self .entry_point = entry_point . node_name
5552
5653 if nodes [0 ].node_name != entry_point .node_name :
5754 # raise a warning if the entry point is not the first node in the list
@@ -71,16 +68,13 @@ def _create_edges(self, edges: list) -> dict:
7168
7269 edge_dict = {}
7370 for from_node , to_node in edges :
74- if from_node in edge_dict :
75- edge_dict [from_node ].append (to_node )
76- else :
77- edge_dict [from_node ] = [to_node ]
71+ edge_dict [from_node .node_name ] = to_node .node_name
7872 return edge_dict
7973
8074 def execute (self , initial_state : dict ) -> Tuple [dict , list ]:
8175 """
82- Executes the graph by traversing nodes in breadth-first order starting from the entry point.
83- The execution follows the edges based on the result of each node's execution and continues until
76+ Executes the graph by traversing nodes starting from the entry point. The execution
77+ follows the edges based on the result of each node's execution and continues until
8478 it reaches a node with no outgoing edges.
8579
8680 Args:
@@ -90,6 +84,7 @@ def execute(self, initial_state: dict) -> Tuple[dict, list]:
9084 Tuple[dict, list]: A tuple containing the final state and a list of execution info.
9185 """
9286
87+ current_node_name = self .nodes [0 ]
9388 state = initial_state
9489
9590 # variables for tracking execution info
@@ -103,22 +98,23 @@ def execute(self, initial_state: dict) -> Tuple[dict, list]:
10398 "total_cost_USD" : 0.0 ,
10499 }
105100
106- queue = deque ([self .entry_point ])
107- while queue :
108- current_node = queue .popleft ()
101+ for index in self .nodes :
102+
109103 curr_time = time .time ()
110- with get_openai_callback () as callback :
104+ current_node = index
105+
106+ with get_openai_callback () as cb :
111107 result = current_node .execute (state )
112108 node_exec_time = time .time () - curr_time
113109 total_exec_time += node_exec_time
114110
115111 cb = {
116- "node_name" : current_node .node_name ,
117- "total_tokens" : callback .total_tokens ,
118- "prompt_tokens" : callback .prompt_tokens ,
119- "completion_tokens" : callback .completion_tokens ,
120- "successful_requests" : callback .successful_requests ,
121- "total_cost_USD" : callback .total_cost ,
112+ "node_name" : index .node_name ,
113+ "total_tokens" : cb .total_tokens ,
114+ "prompt_tokens" : cb .prompt_tokens ,
115+ "completion_tokens" : cb .completion_tokens ,
116+ "successful_requests" : cb .successful_requests ,
117+ "total_cost_USD" : cb .total_cost ,
122118 "exec_time" : node_exec_time ,
123119 }
124120
@@ -132,31 +128,21 @@ def execute(self, initial_state: dict) -> Tuple[dict, list]:
132128 cb_total ["successful_requests" ] += cb ["successful_requests" ]
133129 cb_total ["total_cost_USD" ] += cb ["total_cost_USD" ]
134130
135-
136-
137- current_node_connections = self .edges [current_node ]
138- if current_node .node_type == 'conditional_node' :
139- # Assert that there are exactly two out edges from the conditional node
140- if len (current_node_connections ) != 2 :
141- raise ValueError (f"Conditional node should have exactly two out connections { current_node_connections .node_name } " )
142- if result ["next_node" ] == 0 :
143- queue .append (current_node_connections [0 ])
144- else :
145- queue .append (current_node_connections [1 ])
146- # remove the conditional node result
147- del result ["next_node" ]
148- else :
149- queue .extend (node for node in current_node_connections )
150-
151-
152- exec_info .append ({
153- "node_name" : "TOTAL RESULT" ,
154- "total_tokens" : cb_total ["total_tokens" ],
155- "prompt_tokens" : cb_total ["prompt_tokens" ],
156- "completion_tokens" : cb_total ["completion_tokens" ],
157- "successful_requests" : cb_total ["successful_requests" ],
158- "total_cost_USD" : cb_total ["total_cost_USD" ],
159- "exec_time" : total_exec_time ,
160- })
161-
162- return state , exec_info
131+ if current_node .node_type == "conditional_node" :
132+ current_node_name = result
133+ elif current_node_name in self .edges :
134+ current_node_name = self .edges [current_node_name ]
135+ else :
136+ current_node_name = None
137+
138+ exec_info .append ({
139+ "node_name" : "TOTAL RESULT" ,
140+ "total_tokens" : cb_total ["total_tokens" ],
141+ "prompt_tokens" : cb_total ["prompt_tokens" ],
142+ "completion_tokens" : cb_total ["completion_tokens" ],
143+ "successful_requests" : cb_total ["successful_requests" ],
144+ "total_cost_USD" : cb_total ["total_cost_USD" ],
145+ "exec_time" : total_exec_time ,
146+ })
147+
148+ return state , exec_info
0 commit comments