Skip to content

Commit dab1a48

Browse files
authored
yet another way to kill timedout jobs (#108)
1 parent 6684e3d commit dab1a48

File tree

3 files changed

+71
-22
lines changed

3 files changed

+71
-22
lines changed

src/agentlab/experiments/exp_utils.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,11 @@
2727

2828
def run_exp(exp_arg: ExpArgs, *dependencies, avg_step_timeout=60):
    """Run ``exp_arg.run()`` and return its result.

    No timeout is enforced here. The previous in-process approach (wrapping the
    call in ``timeout_manager``) was not robust enough to stop a hanging job, so
    enforcing a time limit is left to the caller, e.g. via ``ray.cancel``.

    Args:
        exp_arg: Experiment to execute; only its ``run()`` method is used.
        *dependencies: Accepted but never read. Presumably these are the
            results of prerequisite experiments, passed so that a scheduler
            orders this call after them -- confirm against the caller.
        avg_step_timeout: Unused by this function; kept so that existing
            callers passing it keep working.

    Returns:
        Whatever ``exp_arg.run()`` returns.
    """
    return exp_arg.run()
3335

3436

3537
def _episode_timeout(exp_arg: ExpArgs, avg_step_timeout=60):
@@ -62,13 +64,12 @@ def timeout_manager(seconds: int = None):
6264

6365
def alarm_handler(signum, frame):
6466

65-
logger.warning(
66-
f"Operation timed out after {seconds}s, sending SIGINT and raising TimeoutError."
67-
)
67+
logger.warning(f"Operation timed out after {seconds}s, raising TimeoutError.")
6868
# send sigint
69-
os.kill(os.getpid(), signal.SIGINT)
69+
# os.kill(os.getpid(), signal.SIGINT) # this doesn't seem to do much I don't know why
7070

7171
# Still raise TimeoutError for immediate handling
72+
# This works, but it doesn't seem enough to kill the job
7273
raise TimeoutError(f"Operation timed out after {seconds} seconds")
7374

7475
previous_handler = signal.signal(signal.SIGALRM, alarm_handler)

src/agentlab/experiments/graph_execution_ray.py

Lines changed: 57 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,14 @@
22

33
# # Disable Ray log deduplication
44
# os.environ["RAY_DEDUP_LOGS"] = "0"
5-
5+
import time
66
import ray
77
import bgym
8-
from agentlab.experiments.exp_utils import run_exp
8+
from agentlab.experiments.exp_utils import run_exp, _episode_timeout
9+
from ray.util import state
10+
import logging
911

12+
logger = logging.getLogger(__name__)
1013

1114
run_exp = ray.remote(run_exp)
1215

@@ -15,25 +18,70 @@ def execute_task_graph(exp_args_list: list[bgym.ExpArgs], avg_step_timeout=60):
1518
"""Execute a task graph in parallel while respecting dependencies using Ray."""
1619

1720
exp_args_map = {exp_args.exp_id: exp_args for exp_args in exp_args_list}
18-
tasks = {}
21+
task_map = {}
1922

2023
def get_task(exp_arg: bgym.ExpArgs):
21-
if exp_arg.exp_id not in tasks:
24+
if exp_arg.exp_id not in task_map:
2225
# Get all dependency tasks first
2326
dependency_tasks = [get_task(exp_args_map[dep_key]) for dep_key in exp_arg.depends_on]
2427

2528
# Create new task that depends on the dependency results
26-
tasks[exp_arg.exp_id] = run_exp.remote(
29+
task_map[exp_arg.exp_id] = run_exp.remote(
2730
exp_arg, *dependency_tasks, avg_step_timeout=avg_step_timeout
2831
)
29-
return tasks[exp_arg.exp_id]
32+
return task_map[exp_arg.exp_id]
3033

3134
# Build task graph
3235
for exp_arg in exp_args_list:
3336
get_task(exp_arg)
3437

35-
# Execute all tasks and gather results
38+
max_timeout = max([_episode_timeout(exp_args, avg_step_timeout) for exp_args in exp_args_list])
39+
return poll_for_timeout(task_map, max_timeout, poll_interval=max_timeout * 0.1)
40+
41+
42+
def poll_for_timeout(
    tasks: dict[str, ray.ObjectRef],
    timeout: float,
    poll_interval: float = 1.0,
    force_kill_grace: float = 60.0,
):
    """Wait for every task to finish, cancelling those that exceed ``timeout``.

    The tasks are polled with ``ray.wait`` every ``poll_interval`` seconds. On
    each poll, any unfinished task whose elapsed running time (as reported by
    ``get_elapsed_time``) is above ``timeout`` is cancelled:

    * while it is less than ``force_kill_grace`` seconds over the limit, with
      ``ray.cancel(force=False)``;
    * once it is further over the limit, with ``ray.cancel(force=True)``.

    Tasks whose elapsed time is unknown (``None``) are left alone.

    Args:
        tasks: Mapping from task id to the Ray object ref of that task.
        timeout: Maximum allowed running time of a single task, in seconds.
        poll_interval: Seconds to wait between two checks.
        force_kill_grace: Seconds past ``timeout`` after which cancellation is
            escalated from a regular cancel to a forced one.

    Returns:
        A dict mapping each task id to the task's result. If fetching a result
        raises (which is what is expected for a cancelled task), the exception
        instance is stored as the value instead of being raised.
    """
    if not tasks:
        # Nothing to wait for; avoid calling ray.wait with num_returns=0.
        return {}

    task_list = list(tasks.values())

    logger.warning(f"Any task exceeding {timeout} seconds will be cancelled.")

    while True:
        ready, not_ready = ray.wait(task_list, num_returns=len(task_list), timeout=poll_interval)
        for task in not_ready:
            elapsed_time = get_elapsed_time(task)
            if elapsed_time is None or elapsed_time <= timeout:
                continue

            msg = f"Task {task.task_id().hex()} hase been running for {elapsed_time}s, more than the timeout: {timeout}s."
            if elapsed_time < timeout + force_kill_grace:
                logger.warning(msg + " Cancelling task.")
                ray.cancel(task, force=False, recursive=False)
            else:
                logger.warning(msg + " Force killing.")
                ray.cancel(task, force=True, recursive=False)

        if len(ready) == len(task_list):
            break

    # Pair each id with its own ref directly, so the mapping does not depend on
    # the order in which ray.wait reported the refs as ready.
    results = {}
    for task_id, task in tasks.items():
        try:
            results[task_id] = ray.get(task)
        except Exception as e:
            # Deliberate best effort: a cancelled or crashed task must not
            # prevent the other results from being returned.
            results[task_id] = e
    return results
76+
77+
78+
def get_elapsed_time(task_ref: ray.ObjectRef):
    """Return the seconds elapsed since the task behind ``task_ref`` started.

    The start time is read from Ray's state API (``state.get_task``) and
    compared with the local ``time.time()``.

    Args:
        task_ref: Object ref returned by a ``.remote()`` call.

    Returns:
        The elapsed time in seconds as a float, or ``None`` when the state API
        has no record of the task or the task has no start time yet.
    """
    task_id = task_ref.task_id().hex()
    task_info = state.get_task(task_id, address="auto")

    # When a task has been retried, the state API reports one entry per attempt
    # instead of a single object. Measure the most recent attempt, since that
    # is the one that can still be running.
    if isinstance(task_info, (list, tuple)):
        task_info = max(
            task_info,
            key=lambda attempt: getattr(attempt, "attempt_number", 0) or 0,
            default=None,
        )

    if not task_info or task_info.start_time_ms is None:
        return None  # Task has not started yet

    start_time_s = task_info.start_time_ms / 1000.0  # Convert ms to s
    return time.time() - start_time_s

tests/experiments/test_launch_exp.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import math
12
import tempfile
23
from pathlib import Path
34

@@ -63,19 +64,18 @@ def _test_launch_system(backend="ray", cause_timeout=False):
6364
if row.stack_trace is not None:
6465
print(row.stack_trace)
6566
if cause_timeout:
66-
assert row.err_msg is not None
67-
assert "Timeout" in row.err_msg
68-
assert row.cum_reward == 0
67+
# assert row.err_msg is not None
68+
assert math.isnan(row.cum_reward) or row.cum_reward == 0
6969
else:
7070
assert row.err_msg is None
7171
assert row.cum_reward == 1.0
7272

7373
study_summary = inspect_results.summarize_study(results_df)
7474
assert len(study_summary) == 1
7575
assert study_summary.std_err.iloc[0] == 0
76-
assert study_summary.n_completed.iloc[0] == "3/3"
7776

7877
if not cause_timeout:
78+
assert study_summary.n_completed.iloc[0] == "3/3"
7979
assert study_summary.avg_reward.iloc[0] == 1.0
8080

8181

@@ -91,7 +91,7 @@ def test_launch_system_ray():
9191
_test_launch_system(backend="ray")
9292

9393

94-
def _test_timeout_ray():
94+
def test_timeout_ray():
    """Run the launch-system check on the ray backend with ``cause_timeout=True``."""
    _test_launch_system(cause_timeout=True, backend="ray")
9696

9797

@@ -120,7 +120,7 @@ def test_4o_mini_on_miniwob_tiny_test():
120120

121121

122122
if __name__ == "__main__":
123-
_test_timeout_ray()
123+
test_timeout_ray()
124124
# test_4o_mini_on_miniwob_tiny_test()
125125
# test_launch_system_ray()
126126
# test_launch_system_sequntial()

0 commit comments

Comments
 (0)