1+ """
2+ base_graph module
3+ """
14import time
25import warnings
3- from langchain_community .callbacks import get_openai_callback
46from typing import Tuple
7+ from langchain_community .callbacks import get_openai_callback
8+ from ..integrations import BurrBridge
59
610# Import telemetry functions
711from ..telemetry import log_graph_execution , log_event
@@ -56,7 +60,7 @@ def __init__(self, nodes: list, edges: list, entry_point: str, use_burr: bool =
5660 # raise a warning if the entry point is not the first node in the list
5761 warnings .warn (
5862 "Careful! The entry point node is different from the first node in the graph." )
59-
63+
6064 # Burr configuration
6165 self .use_burr = use_burr
6266 self .burr_config = burr_config or {}
@@ -79,7 +83,8 @@ def _create_edges(self, edges: list) -> dict:
7983
8084 def _execute_standard (self , initial_state : dict ) -> Tuple [dict , list ]:
8185 """
82- Executes the graph by traversing nodes starting from the entry point using the standard method.
86+ Executes the graph by traversing nodes starting from the
87+ entry point using the standard method.
8388
8489 Args:
8590 initial_state (dict): The initial state to pass to the entry point node.
@@ -114,23 +119,25 @@ def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]:
114119 curr_time = time .time ()
115120 current_node = next (node for node in self .nodes if node .node_name == current_node_name )
116121
117-
118122 # check if there is a "source" key in the node config
119123 if current_node .__class__ .__name__ == "FetchNode" :
120124 # get the second key name of the state dictionary
121125 source_type = list (state .keys ())[1 ]
122126 if state .get ("user_prompt" , None ):
123- prompt = state ["user_prompt" ] if type (state ["user_prompt" ]) == str else None
124- # quick fix for local_dir source type
127+ # Set 'prompt' if 'user_prompt' is a string, otherwise None
128+ prompt = state ["user_prompt" ] if isinstance (state ["user_prompt" ], str ) else None
129+
130+ # Convert 'local_dir' source type to 'html_dir'
125131 if source_type == "local_dir" :
126132 source_type = "html_dir"
127133 elif source_type == "url" :
128- if type ( state [ source_type ]) == list :
129- # iterate through the list of urls and see if they are strings
134+ # If the source is a list, add string URLs to 'source'
135+ if isinstance ( state [ source_type ], list ):
130136 for url in state [source_type ]:
131- if type (url ) == str :
137+ if isinstance (url , str ) :
132138 source .append (url )
133- elif type (state [source_type ]) == str :
139+ # If the source is a single string, add it to 'source'
140+ elif isinstance (state [source_type ], str ):
134141 source .append (state [source_type ])
135142
136143 # check if there is an "llm_model" variable in the class
@@ -164,7 +171,6 @@ def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]:
164171 result = current_node .execute (state )
165172 except Exception as e :
166173 error_node = current_node .node_name
167-
168174 graph_execution_time = time .time () - start_time
169175 log_graph_execution (
170176 graph_name = self .graph_name ,
@@ -221,7 +227,7 @@ def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]:
221227 graph_execution_time = time .time () - start_time
222228 response = state .get ("answer" , None ) if source_type == "url" else None
223229 content = state .get ("parsed_doc" , None ) if response is not None else None
224-
230+
225231 log_graph_execution (
226232 graph_name = self .graph_name ,
227233 source = source ,
@@ -251,26 +257,25 @@ def execute(self, initial_state: dict) -> Tuple[dict, list]:
251257
252258 self .initial_state = initial_state
253259 if self .use_burr :
254- from ..integrations import BurrBridge
255-
260+
256261 bridge = BurrBridge (self , self .burr_config )
257262 result = bridge .execute (initial_state )
258263 return (result ["_state" ], [])
259264 else :
260265 return self ._execute_standard (initial_state )
261-
266+
262267 def append_node (self , node ):
263268 """
264269 Adds a node to the graph.
265270
266271 Args:
267272 node (BaseNode): The node instance to add to the graph.
268273 """
269-
274+
270275 # if node name already exists in the graph, raise an exception
271276 if node .node_name in {n .node_name for n in self .nodes }:
272277 raise ValueError (f"Node with name '{ node .node_name } ' already exists in the graph. You can change it by setting the 'node_name' attribute." )
273-
278+
274279 # get the last node in the list
275280 last_node = self .nodes [- 1 ]
276281 # add the edge connecting the last node to the new node
0 commit comments