 GraphIterator Module
 """

-from typing import List, Optional
+import asyncio
 import copy
-from tqdm import tqdm
+from typing import List, Optional
+
+from tqdm.asyncio import tqdm
+
 from .base_node import BaseNode


+_default_batchsize = 16
+
+
 class GraphIteratorNode(BaseNode):
     """
     A node responsible for instantiating and running multiple graph instances in parallel.
@@ -23,12 +29,20 @@ class GraphIteratorNode(BaseNode):
         node_name (str): The unique identifier name for the node, defaulting to "GraphIterator".
2430 """
2531
26- def __init__ (self , input : str , output : List [str ], node_config : Optional [dict ]= None , node_name : str = "GraphIterator" ):
32+ def __init__ (
33+ self ,
34+ input : str ,
35+ output : List [str ],
36+ node_config : Optional [dict ] = None ,
37+ node_name : str = "GraphIterator" ,
38+ ):
2739 super ().__init__ (node_name , "node" , input , output , 2 , node_config )
2840
29- self .verbose = False if node_config is None else node_config .get ("verbose" , False )
41+ self .verbose = (
42+ False if node_config is None else node_config .get ("verbose" , False )
43+ )
3044
31- def execute (self , state : dict ) -> dict :
45+ def execute (self , state : dict ) -> dict :
3246 """
3347 Executes the node's logic to instantiate and run multiple graph instances in parallel.
3448
@@ -43,37 +57,78 @@ def execute(self, state: dict) -> dict:
             KeyError: If the input keys are not found in the state, indicating that the
                 necessary information for running the graph instances is missing.
         """
+        batchsize = self.node_config.get("batchsize", _default_batchsize)

         if self.verbose:
-            print(f"--- Executing {self.node_name} Node ---")
+            print(f"--- Executing {self.node_name} Node with batchsize {batchsize} ---")
+
+        try:
+            eventloop = asyncio.get_event_loop()
+        except RuntimeError:
+            eventloop = None
+
+        if eventloop and eventloop.is_running():
+            state = eventloop.run_until_complete(self._async_execute(state, batchsize))
+        else:
+            state = asyncio.run(self._async_execute(state, batchsize))
+
+        return state
+
+    async def _async_execute(self, state: dict, batchsize: int) -> dict:
78+ """asynchronously executes the node's logic with multiple graph instances
79+ running in parallel, using a semaphore of some size for concurrency regulation
+
+        Args:
+            state: The current state of the graph.
+            batchsize: The maximum number of concurrent instances allowed.
+
+        Returns:
+            The updated state with the output key containing the results
+            aggregated out of all parallel graph instances.

-        # Interpret input keys based on the provided input expression
+        Raises:
+            KeyError: If the input keys are not found in the state.
+        """
+
+        # interprets input keys based on the provided input expression
         input_keys = self.get_input_keys(state)

-        # Fetching data from the state based on the input keys
+        # fetches data from the state based on the input keys
         input_data = [state[key] for key in input_keys]

         user_prompt = input_data[0]
         urls = input_data[1]

         graph_instance = self.node_config.get("graph_instance", None)
+
         if graph_instance is None:
-            raise ValueError("Graph instance is required for graph iteration.")
-
-        # set the prompt and source for each url
+            raise ValueError("graph instance is required for concurrent execution")
+
+        # sets the prompt for the graph instance
         graph_instance.prompt = user_prompt
-        graphs_instances = []
+
+        participants = []
+
+        # semaphore to limit the number of concurrent tasks
+        semaphore = asyncio.Semaphore(batchsize)
+
+        async def _async_run(graph):
+            async with semaphore:
+                return await asyncio.to_thread(graph.run)
+
+        # creates a copy of the graph instance for each endpoint
         for url in urls:
-            # make a copy of the graph instance
-            copy_graph_instance = copy.copy(graph_instance)
-            copy_graph_instance.source = url
-            graphs_instances.append(copy_graph_instance)
-
-        # run the graph for each url and use tqdm for progress bar
-        graphs_answers = []
-        for graph in tqdm(graphs_instances, desc="Processing Graph Instances", disable=not self.verbose):
-            result = graph.run()
-            graphs_answers.append(result)
-
-        state.update({self.output[0]: graphs_answers})
+            instance = copy.copy(graph_instance)
+            instance.source = url
+
+            participants.append(instance)
+
+        futures = [_async_run(graph) for graph in participants]
+
+        answers = await tqdm.gather(
+            *futures, desc="processing graph instances", disable=not self.verbose
+        )
+
+        state.update({self.output[0]: answers})
+
         return state
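
For reference, here is a minimal standalone sketch of the concurrency pattern the new _async_execute method relies on: a semaphore caps how many blocking run() calls execute at once, each call is offloaded to a worker thread with asyncio.to_thread, and tqdm.asyncio's gather drives the progress bar. The blocking_task helper, the example URLs, and the batch size of 4 are illustrative stand-ins, not part of this change.

import asyncio
import time

from tqdm.asyncio import tqdm


def blocking_task(url: str) -> str:
    """Stand-in for a graph's synchronous, blocking run() method."""
    time.sleep(0.5)
    return f"result for {url}"


async def run_all(urls: list, batchsize: int = 4) -> list:
    # at most `batchsize` tasks may hold the semaphore at any moment
    semaphore = asyncio.Semaphore(batchsize)

    async def bounded(url: str) -> str:
        async with semaphore:
            # push the blocking call onto a worker thread so the event loop stays free
            return await asyncio.to_thread(blocking_task, url)

    futures = [bounded(url) for url in urls]
    # gather with a progress bar, mirroring the node's use of tqdm.gather
    return await tqdm.gather(*futures, desc="processing tasks")


if __name__ == "__main__":
    results = asyncio.run(run_all([f"https://example.com/{i}" for i in range(10)]))
    print(results)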
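
A hypothetical end-to-end use of GraphIteratorNode is sketched below. The import path, the StubGraph class, the state keys, the input expression "user_prompt & urls", and the batchsize value are all assumptions made for illustration; only the constructor arguments and the node_config keys mirror the diff above.

from graph_iterator_node import GraphIteratorNode  # hypothetical path; adjust to the module's actual location


class StubGraph:
    """Stand-in for a real graph instance: exposes prompt, source, and a blocking run()."""

    def __init__(self):
        self.prompt = None
        self.source = None

    def run(self):
        return {"answer": f"scraped {self.source} for: {self.prompt}"}


node = GraphIteratorNode(
    input="user_prompt & urls",
    output=["results"],
    node_config={"graph_instance": StubGraph(), "batchsize": 4, "verbose": True},
)

state = {
    "user_prompt": "List the article titles",
    "urls": ["https://example.com/a", "https://example.com/b"],
}

state = node.execute(state)
print(state["results"])  # one result per url, gathered concurrently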