66import warnings
77from langchain_community .callbacks import get_openai_callback
88from typing import Tuple
9+ from collections import deque
910
1011
1112class BaseGraph :
@@ -26,6 +27,8 @@ class BaseGraph:
2627
2728 Raises:
2829 Warning: If the entry point node is not the first node in the list.
30+ ValueError: If conditional_node does not have exactly two outgoing edges
31+
2932
3033 Example:
3134 >>> BaseGraph(
@@ -48,7 +51,7 @@ def __init__(self, nodes: list, edges: list, entry_point: str):
4851
4952 self .nodes = nodes
5053 self .edges = self ._create_edges ({e for e in edges })
51- self .entry_point = entry_point . node_name
54+ self .entry_point = entry_point
5255
5356 if nodes [0 ].node_name != entry_point .node_name :
5457 # raise a warning if the entry point is not the first node in the list
@@ -68,13 +71,16 @@ def _create_edges(self, edges: list) -> dict:
6871
6972 edge_dict = {}
7073 for from_node , to_node in edges :
71- edge_dict [from_node .node_name ] = to_node .node_name
74+ if from_node in edge_dict :
75+ edge_dict [from_node ].append (to_node )
76+ else :
77+ edge_dict [from_node ] = [to_node ]
7278 return edge_dict
7379
7480 def execute (self , initial_state : dict ) -> Tuple [dict , list ]:
7581 """
76- Executes the graph by traversing nodes starting from the entry point. The execution
77- follows the edges based on the result of each node's execution and continues until
82+ Executes the graph by traversing nodes in breadth-first order starting from the entry point.
83+ The execution follows the edges based on the result of each node's execution and continues until
7884 it reaches a node with no outgoing edges.
7985
8086 Args:
@@ -84,7 +90,6 @@ def execute(self, initial_state: dict) -> Tuple[dict, list]:
8490 Tuple[dict, list]: A tuple containing the final state and a list of execution info.
8591 """
8692
87- current_node_name = self .nodes [0 ]
8893 state = initial_state
8994
9095 # variables for tracking execution info
@@ -98,23 +103,22 @@ def execute(self, initial_state: dict) -> Tuple[dict, list]:
98103 "total_cost_USD" : 0.0 ,
99104 }
100105
101- for index in self .nodes :
102-
106+ queue = deque ([self .entry_point ])
107+ while queue :
108+ current_node = queue .popleft ()
103109 curr_time = time .time ()
104- current_node = index
105-
106- with get_openai_callback () as cb :
110+ with get_openai_callback () as callback :
107111 result = current_node .execute (state )
108112 node_exec_time = time .time () - curr_time
109113 total_exec_time += node_exec_time
110114
111115 cb = {
112- "node_name" : index .node_name ,
113- "total_tokens" : cb .total_tokens ,
114- "prompt_tokens" : cb .prompt_tokens ,
115- "completion_tokens" : cb .completion_tokens ,
116- "successful_requests" : cb .successful_requests ,
117- "total_cost_USD" : cb .total_cost ,
116+ "node_name" : current_node .node_name ,
117+ "total_tokens" : callback .total_tokens ,
118+ "prompt_tokens" : callback .prompt_tokens ,
119+ "completion_tokens" : callback .completion_tokens ,
120+ "successful_requests" : callback .successful_requests ,
121+ "total_cost_USD" : callback .total_cost ,
118122 "exec_time" : node_exec_time ,
119123 }
120124
@@ -128,21 +132,30 @@ def execute(self, initial_state: dict) -> Tuple[dict, list]:
128132 cb_total ["successful_requests" ] += cb ["successful_requests" ]
129133 cb_total ["total_cost_USD" ] += cb ["total_cost_USD" ]
130134
131- if current_node .node_type == "conditional_node" :
132- current_node_name = result
133- elif current_node_name in self .edges :
134- current_node_name = self .edges [current_node_name ]
135- else :
136- current_node_name = None
137-
138- exec_info .append ({
139- "node_name" : "TOTAL RESULT" ,
140- "total_tokens" : cb_total ["total_tokens" ],
141- "prompt_tokens" : cb_total ["prompt_tokens" ],
142- "completion_tokens" : cb_total ["completion_tokens" ],
143- "successful_requests" : cb_total ["successful_requests" ],
144- "total_cost_USD" : cb_total ["total_cost_USD" ],
145- "exec_time" : total_exec_time ,
146- })
135+ if current_node in self .edges :
136+ current_node_connections = self .edges [current_node ]
137+ if current_node .node_type == 'conditional_node' :
138+ # Assert that there are exactly two out edges from the conditional node
139+ if len (current_node_connections ) != 2 :
140+ raise ValueError (f"Conditional node should have exactly two out connections { current_node_connections .node_name } " )
141+ if result ["next_node" ] == 0 :
142+ queue .append (current_node_connections [0 ])
143+ else :
144+ queue .append (current_node_connections [1 ])
145+ # remove the conditional node result
146+ del result ["next_node" ]
147+ else :
148+ queue .extend (node for node in current_node_connections )
149+
150+
151+ exec_info .append ({
152+ "node_name" : "TOTAL RESULT" ,
153+ "total_tokens" : cb_total ["total_tokens" ],
154+ "prompt_tokens" : cb_total ["prompt_tokens" ],
155+ "completion_tokens" : cb_total ["completion_tokens" ],
156+ "successful_requests" : cb_total ["successful_requests" ],
157+ "total_cost_USD" : cb_total ["total_cost_USD" ],
158+ "exec_time" : total_exec_time ,
159+ })
147160
148161 return state , exec_info
0 commit comments