 
 # # Disable Ray log deduplication
 # os.environ["RAY_DEDUP_LOGS"] = "0"
-
+import time
 import ray
 import bgym
-from agentlab.experiments.exp_utils import run_exp
+from agentlab.experiments.exp_utils import run_exp, _episode_timeout
+from ray.util import state
+import logging
 
+logger = logging.getLogger(__name__)
 
 run_exp = ray.remote(run_exp)
 
@@ -15,25 +18,70 @@ def execute_task_graph(exp_args_list: list[bgym.ExpArgs], avg_step_timeout=60):
     """Execute a task graph in parallel while respecting dependencies using Ray."""
 
     exp_args_map = {exp_args.exp_id: exp_args for exp_args in exp_args_list}
-    tasks = {}
+    task_map = {}
 
     def get_task(exp_arg: bgym.ExpArgs):
-        if exp_arg.exp_id not in tasks:
+        if exp_arg.exp_id not in task_map:
             # Get all dependency tasks first
             dependency_tasks = [get_task(exp_args_map[dep_key]) for dep_key in exp_arg.depends_on]
 
             # Create new task that depends on the dependency results
-            tasks[exp_arg.exp_id] = run_exp.remote(
+            task_map[exp_arg.exp_id] = run_exp.remote(
                 exp_arg, *dependency_tasks, avg_step_timeout=avg_step_timeout
             )
-        return tasks[exp_arg.exp_id]
+        return task_map[exp_arg.exp_id]
 
     # Build task graph
     for exp_arg in exp_args_list:
         get_task(exp_arg)
 
-    # Execute all tasks and gather results
+    max_timeout = max([_episode_timeout(exp_args, avg_step_timeout) for exp_args in exp_args_list])
+    return poll_for_timeout(task_map, max_timeout, poll_interval=max_timeout * 0.1)
+
+
+def poll_for_timeout(tasks: dict[str, ray.ObjectRef], timeout: float, poll_interval: float = 1.0):
43+ """Cancel tasks that exceeds the timeout
44+
45+ I tried various different methods for killing a job that hangs. so far it's
46+ the only one that seems to work reliably (hopefully)
47+ """
+    task_list = list(tasks.values())
     task_ids = list(tasks.keys())
-    results = ray.get(list(tasks.values()))
 
-    return {task_id: result for task_id, result in zip(task_ids, results)}
+    logger.warning(f"Any task exceeding {timeout} seconds will be cancelled.")
+
+    while True:
+        ready, not_ready = ray.wait(task_list, num_returns=len(task_list), timeout=poll_interval)
+        for task in not_ready:
+            elapsed_time = get_elapsed_time(task)
+            # print(f"Task {task.task_id().hex()} elapsed time: {elapsed_time}")
+            if elapsed_time is not None and elapsed_time > timeout:
+                msg = f"Task {task.task_id().hex()} has been running for {elapsed_time}s, more than the timeout: {timeout}s."
+                if elapsed_time < timeout + 60:
+                    logger.warning(msg + " Cancelling task.")
+                    ray.cancel(task, force=False, recursive=False)
+                else:
+                    logger.warning(msg + " Force killing.")
+                    ray.cancel(task, force=True, recursive=False)
+        if len(ready) == len(task_list):
+            results = []
+            for task in ready:
+                try:
+                    result = ray.get(task)
+                except Exception as e:
+                    result = e
+                results.append(result)
+
+            return {task_id: result for task_id, result in zip(task_ids, results)}
+
+
+def get_elapsed_time(task_ref: ray.ObjectRef):
+    task_id = task_ref.task_id().hex()
+    task_info = state.get_task(task_id, address="auto")
+    if task_info and task_info.start_time_ms is not None:
+        start_time_s = task_info.start_time_ms / 1000.0  # Convert ms to s
+        current_time_s = time.time()
+        elapsed_time = current_time_s - start_time_s
+        return elapsed_time
+    else:
+        return None  # Task has not started yet
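
For reference, here is a minimal, self-contained sketch (not part of this patch) of the same ray.wait / ray.cancel polling pattern that poll_for_timeout relies on. The task names fast_task and hanging_task and the 5-second deadline are made up for illustration; only standard Ray APIs (ray.init, ray.wait, ray.cancel, ray.get) are used.

# Hypothetical standalone example of the polling-and-cancel pattern; not from the PR.
import time
import ray


@ray.remote
def fast_task():
    return "done"


@ray.remote
def hanging_task():
    time.sleep(3600)  # Simulates a task that never finishes on its own.


if __name__ == "__main__":
    ray.init()
    refs = [fast_task.remote(), hanging_task.remote()]
    deadline = time.time() + 5  # Treat anything still running after 5 s as timed out.

    while True:
        # Poll once per second instead of blocking indefinitely on ray.get().
        ready, not_ready = ray.wait(refs, num_returns=len(refs), timeout=1.0)
        if not not_ready:
            break
        if time.time() > deadline:
            for ref in not_ready:
                # Graceful cancel first; a later pass with force=True would kill the worker.
                ray.cancel(ref, force=False, recursive=False)
            break

    for ref in refs:
        try:
            print(ray.get(ref, timeout=10))
        except Exception as e:  # Cancelled tasks surface an exception here instead of a result.
            print(type(e).__name__)

As in poll_for_timeout above, exceptions from cancelled or crashed tasks are caught at ray.get time and reported per task, so one hanging experiment does not lose the results of the others.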