
Commit 7943715

fix: add rate limiting, timing attack prevention, global timeout
1 parent c0aec6e commit 7943715

File tree

7 files changed: +492 -184 lines changed


src/arbitrium/core/executor/async_executor.py

Lines changed: 161 additions & 125 deletions
@@ -15,6 +15,9 @@ class ExecutionCancelledError(Exception):
     pass
 
 
+DEFAULT_GLOBAL_EXECUTION_TIMEOUT = 3600  # 1 hour default
+
+
 class AsyncExecutor(BaseExecutor):
     def __init__(
         self,
@@ -25,6 +28,10 @@ def __init__(
         self.broadcast_fn = broadcast_fn
         self._current_execution_id: str | None = None
         self._cancel_event: asyncio.Event = asyncio.Event()
+        # Global execution timeout (prevents runaway executions)
+        self.global_timeout = (config or {}).get(
+            "global_timeout", DEFAULT_GLOBAL_EXECUTION_TIMEOUT
+        )
 
     def cancel(self) -> bool:
         if self._current_execution_id:
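The new global_timeout attribute falls back to a module-level default when the "global_timeout" key is missing from config, or when no config dict is supplied at all. A minimal, self-contained sketch of that config-with-default pattern; the Worker class below is illustrative only and not part of arbitrium:

DEFAULT_GLOBAL_EXECUTION_TIMEOUT = 3600  # 1 hour default


class Worker:
    def __init__(self, config: dict | None = None) -> None:
        # Fall back to the module default when the key is absent,
        # or when no config dict is supplied at all.
        self.global_timeout = (config or {}).get(
            "global_timeout", DEFAULT_GLOBAL_EXECUTION_TIMEOUT
        )


print(Worker().global_timeout)                          # 3600
print(Worker({"global_timeout": 600}).global_timeout)   # 600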
@@ -228,145 +235,41 @@ async def execute(
         self._cancel_event.clear()
 
         logger.info(
-            "[%s] === EXECUTION START === nodes=%d, edges=%d",
+            "[%s] === EXECUTION START === nodes=%d, edges=%d, global_timeout=%ds",
             execution_id[:8],
             len(nodes),
             len(edges),
+            self.global_timeout,
         )
 
         self._report_execution_start(execution_id, len(nodes), len(edges))
 
         try:
-            (
-                node_instances,
-                _dependencies,
-                connections,
-                execution_layers,
-                feedback_connections,
-            ) = self._build_execution_graph(nodes, edges)
-
-            has_feedback = len(feedback_connections) > 0
-            max_iterations = 20
-
-            logger.info(
-                "[%s] Execution plan: %d layers, %d feedback edges",
-                execution_id[:8],
-                len(execution_layers),
-                len(feedback_connections),
-            )
-
-            self._validate_workflow(node_instances)
-
-            context, node_outputs = self._initialize_execution_state(
-                execution_id
+            # Wrap entire execution in global timeout
+            return await asyncio.wait_for(
+                self._execute_workflow(nodes, edges, execution_id),
+                timeout=self.global_timeout,
             )
-
-            iteration = 0
-            while True:
-                iteration += 1
-                context.round_num = iteration
-                self._check_cancelled()
-
-                if has_feedback:
-                    logger.info(
-                        "[%s] === ITERATION %d START ===",
-                        execution_id[:8],
-                        iteration,
-                    )
-                    await self._broadcast(
-                        {
-                            "type": "iteration_start",
-                            "execution_id": execution_id,
-                            "iteration": iteration,
-                        }
-                    )
-
-                for layer_index, layer in enumerate(execution_layers):
-                    self._check_cancelled()
-
-                    logger.info(
-                        "[%s] Starting layer %d/%d with %d node(s): %s",
-                        execution_id[:8],
-                        layer_index,
-                        len(execution_layers) - 1,
-                        len(layer),
-                        layer,
-                    )
-
-                    self._report_layer_start(
-                        execution_id, layer_index, len(execution_layers), layer
-                    )
-
-                    tasks = []
-                    for node_id in layer:
-                        node = node_instances[node_id]
-                        inputs = self._gather_node_inputs(
-                            node_id, connections, node_outputs
-                        )
-                        for src, tgt, src_h, tgt_h in feedback_connections:
-                            if tgt == node_id and src in node_outputs:
-                                src_outputs = node_outputs[src]
-                                if src_h in src_outputs:
-                                    inputs[tgt_h] = src_outputs[src_h]
-                        task = self._execute_single_node(
-                            node_id, node, inputs, context, execution_id
-                        )
-                        tasks.append((node_id, task))
-
-                    results = await asyncio.gather(
-                        *[task for _, task in tasks], return_exceptions=True
-                    )
-
-                    for (node_id, _), result in zip(
-                        tasks, results, strict=False
-                    ):
-                        if isinstance(result, ExecutionCancelledError):
-                            raise result
-                        if isinstance(result, BaseException):
-                            raise result
-                        node_outputs[node_id] = result
-
-                if not has_feedback:
-                    break
-
-                done = False
-                for node_id, outputs in node_outputs.items():
-                    if outputs.get("done") is True:
-                        done = True
-                        logger.info(
-                            "[%s] Termination signal from node %s",
-                            execution_id[:8],
-                            node_id,
-                        )
-                        break
-
-                if done:
-                    break
-
-                if iteration >= max_iterations:
-                    logger.warning(
-                        "[%s] Max iterations (%d) reached",
-                        execution_id[:8],
-                        max_iterations,
-                    )
-                    break
-
-            context.node_outputs = node_outputs
-
-            logger.info(
-                "[%s] === EXECUTION COMPLETE === iterations=%d, outputs from %d nodes",
+        except TimeoutError:
+            logger.error(
+                "[%s] === GLOBAL TIMEOUT === execution exceeded %ds",
                 execution_id[:8],
-                iteration,
-                len(node_outputs),
+                self.global_timeout,
+            )
+            await self._broadcast(
+                {
+                    "type": "execution_error",
+                    "execution_id": execution_id,
+                    "error": {
+                        "type": "GlobalTimeoutError",
+                        "message": f"Execution timed out after {self.global_timeout}s",
+                    },
+                }
             )
-
-            self._report_execution_complete(execution_id, len(node_outputs))
-
             return {
                 "execution_id": execution_id,
-                "outputs": node_outputs,
+                "error": f"Execution timed out after {self.global_timeout}s",
             }
-
         except ExecutionCancelledError:
             logger.info("[%s] Execution cancelled", execution_id[:8])
             await self._broadcast(
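The core of the change: execute() now delegates to _execute_workflow() and bounds the whole run with asyncio.wait_for, converting a timeout into a structured error result instead of letting a runaway execution continue. A standalone sketch of that pattern, assuming Python 3.11+ (where asyncio.TimeoutError is the builtin TimeoutError); all names and values below are illustrative, not the repo's:

import asyncio

GLOBAL_TIMEOUT = 1  # seconds; illustrative value, not the repo default


async def _run_workflow() -> dict:
    # Stand-in for the real _execute_workflow coroutine.
    await asyncio.sleep(10)
    return {"outputs": {}}


async def execute(execution_id: str) -> dict:
    try:
        # wait_for cancels the inner task once the deadline passes,
        # then raises TimeoutError in the caller.
        return await asyncio.wait_for(_run_workflow(), timeout=GLOBAL_TIMEOUT)
    except TimeoutError:  # on 3.11+ asyncio.TimeoutError is this builtin
        return {
            "execution_id": execution_id,
            "error": f"Execution timed out after {GLOBAL_TIMEOUT}s",
        }


print(asyncio.run(execute("demo-123")))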
@@ -389,3 +292,136 @@ async def execute(
         finally:
             self._current_execution_id = None
             self._cancel_event.clear()
+
+    async def _execute_workflow(
+        self,
+        nodes: list[dict[str, Any]],
+        edges: list[dict[str, Any]],
+        execution_id: str,
+    ) -> dict[str, Any]:
+        """Core workflow execution logic, extracted for timeout wrapping."""
+        (
+            node_instances,
+            _dependencies,
+            connections,
+            execution_layers,
+            feedback_connections,
+        ) = self._build_execution_graph(nodes, edges)
+
+        has_feedback = len(feedback_connections) > 0
+        max_iterations = 20
+
+        logger.info(
+            "[%s] Execution plan: %d layers, %d feedback edges",
+            execution_id[:8],
+            len(execution_layers),
+            len(feedback_connections),
+        )
+
+        self._validate_workflow(node_instances)
+
+        context, node_outputs = self._initialize_execution_state(execution_id)
+
+        iteration = 0
+        while True:
+            iteration += 1
+            context.round_num = iteration
+            self._check_cancelled()
+
+            if has_feedback:
+                logger.info(
+                    "[%s] === ITERATION %d START ===",
+                    execution_id[:8],
+                    iteration,
+                )
+                await self._broadcast(
+                    {
+                        "type": "iteration_start",
+                        "execution_id": execution_id,
+                        "iteration": iteration,
+                    }
+                )
+
+            for layer_index, layer in enumerate(execution_layers):
+                self._check_cancelled()
+
+                logger.info(
+                    "[%s] Starting layer %d/%d with %d node(s): %s",
+                    execution_id[:8],
+                    layer_index,
+                    len(execution_layers) - 1,
+                    len(layer),
+                    layer,
+                )
+
+                self._report_layer_start(
+                    execution_id, layer_index, len(execution_layers), layer
+                )
+
+                tasks = []
+                for node_id in layer:
+                    node = node_instances[node_id]
+                    inputs = self._gather_node_inputs(
+                        node_id, connections, node_outputs
+                    )
+                    for src, tgt, src_h, tgt_h in feedback_connections:
+                        if tgt == node_id and src in node_outputs:
+                            src_outputs = node_outputs[src]
+                            if src_h in src_outputs:
+                                inputs[tgt_h] = src_outputs[src_h]
+                    task = self._execute_single_node(
+                        node_id, node, inputs, context, execution_id
+                    )
+                    tasks.append((node_id, task))
+
+                results = await asyncio.gather(
+                    *[task for _, task in tasks], return_exceptions=True
+                )
+
+                for (node_id, _), result in zip(tasks, results, strict=False):
+                    if isinstance(result, ExecutionCancelledError):
+                        raise result
+                    if isinstance(result, BaseException):
+                        raise result
+                    node_outputs[node_id] = result
+
+            if not has_feedback:
+                break
+
+            done = False
+            for node_id, outputs in node_outputs.items():
+                if outputs.get("done") is True:
+                    done = True
+                    logger.info(
+                        "[%s] Termination signal from node %s",
+                        execution_id[:8],
+                        node_id,
+                    )
+                    break
+
+            if done:
+                break
+
+            if iteration >= max_iterations:
+                logger.warning(
+                    "[%s] Max iterations (%d) reached",
+                    execution_id[:8],
+                    max_iterations,
+                )
+                break
+
+        context.node_outputs = node_outputs
+
+        logger.info(
+            "[%s] === EXECUTION COMPLETE === iterations=%d, outputs from %d nodes",
+            execution_id[:8],
+            iteration,
+            len(node_outputs),
+        )
+
+        self._report_execution_complete(execution_id, len(node_outputs))
+
+        return {
+            "execution_id": execution_id,
+            "outputs": node_outputs,
+        }
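Within _execute_workflow, each layer's nodes run concurrently through asyncio.gather(return_exceptions=True), and the first failure is re-raised only after the whole layer has settled. A small standalone sketch of that gather-and-reraise pattern; the node coroutine here is a stand-in, not the real _execute_single_node:

import asyncio


async def run_node(node_id: str) -> dict:
    # Stand-in for AsyncExecutor._execute_single_node.
    await asyncio.sleep(0.01)
    return {"result": f"output of {node_id}"}


async def run_layer(layer: list[str]) -> dict[str, dict]:
    tasks = [(node_id, run_node(node_id)) for node_id in layer]
    # return_exceptions=True lets sibling nodes finish; failures come back
    # as exception objects instead of aborting the gather early.
    results = await asyncio.gather(
        *[coro for _, coro in tasks], return_exceptions=True
    )
    node_outputs: dict[str, dict] = {}
    for (node_id, _), result in zip(tasks, results, strict=False):
        if isinstance(result, BaseException):
            raise result  # surface the first failure after the layer settles
        node_outputs[node_id] = result
    return node_outputs


print(asyncio.run(run_layer(["a", "b", "c"])))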
