@@ -15,6 +15,9 @@ class ExecutionCancelledError(Exception):
     pass
 
 
+DEFAULT_GLOBAL_EXECUTION_TIMEOUT = 3600  # 1 hour default
+
+
 class AsyncExecutor(BaseExecutor):
     def __init__(
         self,
@@ -25,6 +28,10 @@ def __init__(
         self.broadcast_fn = broadcast_fn
         self._current_execution_id: str | None = None
         self._cancel_event: asyncio.Event = asyncio.Event()
+        # Global execution timeout (prevents runaway executions)
+        self.global_timeout = (config or {}).get(
+            "global_timeout", DEFAULT_GLOBAL_EXECUTION_TIMEOUT
+        )
 
     def cancel(self) -> bool:
         if self._current_execution_id:
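
A quick note on the config lookup above (not part of the PR diff itself): (config or {}).get(...) means both a None config and a missing "global_timeout" key fall back to the one-hour module default. A standalone sketch of that fallback, using a hypothetical resolve_global_timeout helper purely for illustration:

DEFAULT_GLOBAL_EXECUTION_TIMEOUT = 3600  # mirrors the module-level default added above


def resolve_global_timeout(config: dict | None) -> int:
    # A None config and a missing key both fall back to the default.
    return (config or {}).get("global_timeout", DEFAULT_GLOBAL_EXECUTION_TIMEOUT)


assert resolve_global_timeout(None) == 3600
assert resolve_global_timeout({}) == 3600
assert resolve_global_timeout({"global_timeout": 600}) == 600
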
@@ -228,145 +235,41 @@ async def execute(
         self._cancel_event.clear()
 
         logger.info(
-            "[%s] === EXECUTION START === nodes=%d, edges=%d",
+            "[%s] === EXECUTION START === nodes=%d, edges=%d, global_timeout=%ds",
             execution_id[:8],
             len(nodes),
             len(edges),
+            self.global_timeout,
         )
 
         self._report_execution_start(execution_id, len(nodes), len(edges))
 
         try:
-            (
-                node_instances,
-                _dependencies,
-                connections,
-                execution_layers,
-                feedback_connections,
-            ) = self._build_execution_graph(nodes, edges)
-
-            has_feedback = len(feedback_connections) > 0
-            max_iterations = 20
-
-            logger.info(
-                "[%s] Execution plan: %d layers, %d feedback edges",
-                execution_id[:8],
-                len(execution_layers),
-                len(feedback_connections),
-            )
-
-            self._validate_workflow(node_instances)
-
-            context, node_outputs = self._initialize_execution_state(
-                execution_id
+            # Wrap entire execution in global timeout
+            return await asyncio.wait_for(
+                self._execute_workflow(nodes, edges, execution_id),
+                timeout=self.global_timeout,
             )
-
-            iteration = 0
-            while True:
-                iteration += 1
-                context.round_num = iteration
-                self._check_cancelled()
-
-                if has_feedback:
-                    logger.info(
-                        "[%s] === ITERATION %d START ===",
-                        execution_id[:8],
-                        iteration,
-                    )
-                    await self._broadcast(
-                        {
-                            "type": "iteration_start",
-                            "execution_id": execution_id,
-                            "iteration": iteration,
-                        }
-                    )
-
-                for layer_index, layer in enumerate(execution_layers):
-                    self._check_cancelled()
-
-                    logger.info(
-                        "[%s] Starting layer %d/%d with %d node(s): %s",
-                        execution_id[:8],
-                        layer_index,
-                        len(execution_layers) - 1,
-                        len(layer),
-                        layer,
-                    )
-
-                    self._report_layer_start(
-                        execution_id, layer_index, len(execution_layers), layer
-                    )
-
-                    tasks = []
-                    for node_id in layer:
-                        node = node_instances[node_id]
-                        inputs = self._gather_node_inputs(
-                            node_id, connections, node_outputs
-                        )
-                        for src, tgt, src_h, tgt_h in feedback_connections:
-                            if tgt == node_id and src in node_outputs:
-                                src_outputs = node_outputs[src]
-                                if src_h in src_outputs:
-                                    inputs[tgt_h] = src_outputs[src_h]
-                        task = self._execute_single_node(
-                            node_id, node, inputs, context, execution_id
-                        )
-                        tasks.append((node_id, task))
-
-                    results = await asyncio.gather(
-                        *[task for _, task in tasks], return_exceptions=True
-                    )
-
-                    for (node_id, _), result in zip(
-                        tasks, results, strict=False
-                    ):
-                        if isinstance(result, ExecutionCancelledError):
-                            raise result
-                        if isinstance(result, BaseException):
-                            raise result
-                        node_outputs[node_id] = result
-
-                if not has_feedback:
-                    break
-
-                done = False
-                for node_id, outputs in node_outputs.items():
-                    if outputs.get("done") is True:
-                        done = True
-                        logger.info(
-                            "[%s] Termination signal from node %s",
-                            execution_id[:8],
-                            node_id,
-                        )
-                        break
-
-                if done:
-                    break
-
-                if iteration >= max_iterations:
-                    logger.warning(
-                        "[%s] Max iterations (%d) reached",
-                        execution_id[:8],
-                        max_iterations,
-                    )
-                    break
-
-            context.node_outputs = node_outputs
-
-            logger.info(
-                "[%s] === EXECUTION COMPLETE === iterations=%d, outputs from %d nodes",
+        except TimeoutError:
+            logger.error(
+                "[%s] === GLOBAL TIMEOUT === execution exceeded %ds",
                 execution_id[:8],
-                iteration,
-                len(node_outputs),
+                self.global_timeout,
+            )
+            await self._broadcast(
+                {
+                    "type": "execution_error",
+                    "execution_id": execution_id,
+                    "error": {
+                        "type": "GlobalTimeoutError",
+                        "message": f"Execution timed out after {self.global_timeout}s",
+                    },
+                }
             )
-
-            self._report_execution_complete(execution_id, len(node_outputs))
-
             return {
                 "execution_id": execution_id,
-                "outputs": node_outputs,
+                "error": f"Execution timed out after {self.global_timeout}s",
             }
-
         except ExecutionCancelledError:
             logger.info("[%s] Execution cancelled", execution_id[:8])
             await self._broadcast(
@@ -389,3 +292,136 @@ async def execute(
         finally:
             self._current_execution_id = None
             self._cancel_event.clear()
+
+    async def _execute_workflow(
+        self,
+        nodes: list[dict[str, Any]],
+        edges: list[dict[str, Any]],
+        execution_id: str,
+    ) -> dict[str, Any]:
+        """Core workflow execution logic, extracted for timeout wrapping."""
+        (
+            node_instances,
+            _dependencies,
+            connections,
+            execution_layers,
+            feedback_connections,
+        ) = self._build_execution_graph(nodes, edges)
+
+        has_feedback = len(feedback_connections) > 0
+        max_iterations = 20
+
+        logger.info(
+            "[%s] Execution plan: %d layers, %d feedback edges",
+            execution_id[:8],
+            len(execution_layers),
+            len(feedback_connections),
+        )
+
+        self._validate_workflow(node_instances)
+
+        context, node_outputs = self._initialize_execution_state(execution_id)
+
+        iteration = 0
+        while True:
+            iteration += 1
+            context.round_num = iteration
+            self._check_cancelled()
+
+            if has_feedback:
+                logger.info(
+                    "[%s] === ITERATION %d START ===",
+                    execution_id[:8],
+                    iteration,
+                )
+                await self._broadcast(
+                    {
+                        "type": "iteration_start",
+                        "execution_id": execution_id,
+                        "iteration": iteration,
+                    }
+                )
+
+            for layer_index, layer in enumerate(execution_layers):
+                self._check_cancelled()
+
+                logger.info(
+                    "[%s] Starting layer %d/%d with %d node(s): %s",
+                    execution_id[:8],
+                    layer_index,
+                    len(execution_layers) - 1,
+                    len(layer),
+                    layer,
+                )
+
+                self._report_layer_start(
+                    execution_id, layer_index, len(execution_layers), layer
+                )
+
+                tasks = []
+                for node_id in layer:
+                    node = node_instances[node_id]
+                    inputs = self._gather_node_inputs(
+                        node_id, connections, node_outputs
+                    )
+                    for src, tgt, src_h, tgt_h in feedback_connections:
+                        if tgt == node_id and src in node_outputs:
+                            src_outputs = node_outputs[src]
+                            if src_h in src_outputs:
+                                inputs[tgt_h] = src_outputs[src_h]
+                    task = self._execute_single_node(
+                        node_id, node, inputs, context, execution_id
+                    )
+                    tasks.append((node_id, task))
+
+                results = await asyncio.gather(
+                    *[task for _, task in tasks], return_exceptions=True
+                )
+
+                for (node_id, _), result in zip(tasks, results, strict=False):
+                    if isinstance(result, ExecutionCancelledError):
+                        raise result
+                    if isinstance(result, BaseException):
+                        raise result
+                    node_outputs[node_id] = result
+
+            if not has_feedback:
+                break
+
+            done = False
+            for node_id, outputs in node_outputs.items():
+                if outputs.get("done") is True:
+                    done = True
+                    logger.info(
+                        "[%s] Termination signal from node %s",
+                        execution_id[:8],
+                        node_id,
+                    )
+                    break
+
+            if done:
+                break
+
+            if iteration >= max_iterations:
+                logger.warning(
+                    "[%s] Max iterations (%d) reached",
+                    execution_id[:8],
+                    max_iterations,
+                )
+                break
+
+        context.node_outputs = node_outputs
+
+        logger.info(
+            "[%s] === EXECUTION COMPLETE === iterations=%d, outputs from %d nodes",
+            execution_id[:8],
+            iteration,
+            len(node_outputs),
+        )
+
+        self._report_execution_complete(execution_id, len(node_outputs))
+
+        return {
+            "execution_id": execution_id,
+            "outputs": node_outputs,
+        }
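
Context note on the timeout wrapper in execute() above (not part of the PR itself): asyncio.wait_for cancels the wrapped coroutine once the deadline passes, waits for that cancellation to complete, and then raises TimeoutError, which is what the new except TimeoutError branch relies on. Catching the builtin TimeoutError assumes Python 3.11+, where asyncio.TimeoutError is an alias of it. A minimal standalone sketch of that behaviour, with a hypothetical slow_workflow coroutine standing in for _execute_workflow and an artificially small timeout:

import asyncio


async def slow_workflow() -> dict:
    # Stand-in for a long-running _execute_workflow call.
    await asyncio.sleep(10)
    return {"outputs": {}}


async def main() -> None:
    try:
        await asyncio.wait_for(slow_workflow(), timeout=0.1)
    except TimeoutError:
        # By the time this runs, wait_for has already cancelled slow_workflow(),
        # so no orphaned work keeps running past the deadline.
        print("global timeout hit")


asyncio.run(main())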