Skip to content

Commit f8d1e47

Browse files
recursixgasse
andauthored
Replacing Dask with Ray (#100)
* dask-dependencies * minor * replace with ray * adjust tests and move a few things * markdown report * automatic relaunch * add dependencies * reformat * fix unit-test * catch timeout * fixing bugs and making things work * adress comments and black format * new dependencies viewer * Update benchmark to use visualwebarena instead of webarena * Fix import and uncomment code in get_ray_url.py * Add ignore_dependencies option to Study and _agents_on_benchmark functions * Update load_most_recent method to include contains parameter * Update load_most_recent method to accept contains parameter and add warning for ignored dependencies in _agents_on_benchmark * Refactor backend preparation in Study class and improve logging for ignored dependencies * finallly some results with claude on webarena * Add warnings for Windows timeouts and clarify parallel backend options; update get_results method to conditionally save outputs * black * ensure timeout is int (For the 3rd time?) * Refactor timeout handling in context manager; update test to reduce avg_step_timeout and rename test function * black * Change parallel backend from "joblib" to "ray" in run_experiments function * Update src/agentlab/experiments/study.py Co-authored-by: Maxime Gasse <[email protected]> * Update src/agentlab/analyze/inspect_results.py Co-authored-by: Maxime Gasse <[email protected]> * Refactor logging initialization and update layout configurations in dependency graph plotting; adjust node size and font size for better visualization --------- Co-authored-by: Maxime Gasse <[email protected]>
1 parent f6ac587 commit f8d1e47

17 files changed

+943
-288
lines changed

main.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
repository.
77
"""
88

9-
import bgym
109
import logging
1110
from agentlab.agents.generic_agent import (
1211
RANDOM_SEARCH_AGENT,
@@ -26,7 +25,7 @@
2625

2726
# ## select the benchmark to run on
2827
benchmark = "miniwob_tiny_test"
29-
# benchmark = "miniwob_all"
28+
# benchmark = "miniwob"
3029
# benchmark = "workarena_l1"
3130
# benchmark = "workarena_l2"
3231
# benchmark = "workarena_l3"
@@ -53,13 +52,18 @@
5352

5453
if relaunch:
5554
# relaunch an existing study
56-
study = Study.load_most_recent()
57-
study.find_incomplete(relaunch_mode="incomplete_or_error")
55+
study = Study.load_most_recent(contains=None)
56+
study.find_incomplete(include_errors=True)
5857

5958
else:
60-
study = Study(agent_args, benchmark)
61-
62-
study.run(n_jobs=n_jobs, parallel_backend="joblib", strict_reproducibility=reproducibility_mode)
59+
study = Study(agent_args, benchmark, logging_level_stdout=logging.WARNING)
60+
61+
study.run(
62+
n_jobs=n_jobs,
63+
parallel_backend="ray",
64+
strict_reproducibility=reproducibility_mode,
65+
n_relaunch=3,
66+
)
6367

6468
if reproducibility_mode:
6569
study.append_to_journal(strict_reproducibility=True)

reproducibility_journal.csv

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ ThibaultLSDC,GenericAgent-gpt-4o,workarena_l1,0.4.1,2024-10-23_22-30-06,2024-10-
2626
ThibaultLSDC,GenericAgent-anthropic_claude-3.5-sonnet:beta,workarena_l1,0.4.1,2024-10-23_22-30-06,2024-10-23_14-17-40,0.564,0.027,1,330/330,None,Linux (#66-Ubuntu SMP Fri Aug 30 13:56:20 UTC 2024),3.12.7,1.39.0,0.2.3,4cd1e2d4189ddfbeb94129f7b0c9a00c3400ebac,,0.9.0,f25bdcd6b946fc4a79cdbee5fbcad53548af8724,
2727
ThibaultLSDC,GenericAgent-meta-llama_llama-3.1-70b-instruct,workarena_l1,0.4.1,2024-10-23_22-30-06,2024-10-23_14-17-40,0.279,0.025,0,330/330,None,Linux (#66-Ubuntu SMP Fri Aug 30 13:56:20 UTC 2024),3.12.7,1.39.0,0.2.3,4cd1e2d4189ddfbeb94129f7b0c9a00c3400ebac,,0.9.0,f25bdcd6b946fc4a79cdbee5fbcad53548af8724,
2828
ThibaultLSDC,GenericAgent-openai_o1-mini-2024-09-12,workarena_l1,0.4.1,2024-10-23_22-30-06,2024-10-23_14-17-40,0.567,0.027,4,330/330,None,Linux (#66-Ubuntu SMP Fri Aug 30 13:56:20 UTC 2024),3.12.7,1.39.0,0.2.3,4cd1e2d4189ddfbeb94129f7b0c9a00c3400ebac,,0.9.0,f25bdcd6b946fc4a79cdbee5fbcad53548af8724,
29+
recursix,GenericAgent-anthropic_claude-3.5-sonnet:beta,webarena,0.11.3,2024-11-02_23-50-17,22a9d3f5-9d86-455e-b451-3ea17690ce8a,0.329,0.016,0,812/812,None,Linux (#66-Ubuntu SMP Fri Aug 30 13:56:20 UTC 2024),3.12.6,1.39.0,0.2.3,418a05d90c74800cd66371b7846ef861185b8c47,,0.11.3,160167ff0d2631826f0131e8e30b92ef448d6881,
2930
ThibaultLSDC,GenericAgent-gpt-4o-mini,workarena_l2_agent_curriculum_eval,0.4.1,2024-10-24_17-08-53,2024-10-23_17-10-46,0.013,0.007,2,235/235,None,Linux (#66-Ubuntu SMP Fri Aug 30 13:56:20 UTC 2024),3.12.7,1.39.0,0.2.3,827d847995f19dc337f3899427340bdddbd81cd5,,0.10.0,None,
3031
ThibaultLSDC,GenericAgent-gpt-4o,workarena_l2_agent_curriculum_eval,0.4.1,2024-10-24_17-08-53,2024-10-23_17-10-46,0.085,0.018,3,233/235,None,Linux (#66-Ubuntu SMP Fri Aug 30 13:56:20 UTC 2024),3.12.7,1.39.0,0.2.3,827d847995f19dc337f3899427340bdddbd81cd5,,0.10.0,None,
3132
ThibaultLSDC,GenericAgent-anthropic_claude-3.5-sonnet:beta,workarena_l2_agent_curriculum_eval,0.4.1,2024-10-24_17-08-53,2024-10-23_17-10-46,0.391,0.032,3,235/235,None,Linux (#66-Ubuntu SMP Fri Aug 30 13:56:20 UTC 2024),3.12.7,1.39.0,0.2.3,827d847995f19dc337f3899427340bdddbd81cd5,,0.10.0,None,

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,4 @@ gradio>=5
2020
gitpython # for the reproducibility script
2121
requests
2222
matplotlib
23+
ray[default]

src/agentlab/analyze/inspect_results.ipynb

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,11 @@
151151
"metadata": {},
152152
"outputs": [],
153153
"source": [
154-
"print(inspect_results.error_report(result_df, max_stack_trace=1))"
154+
"from IPython.display import Markdown, display\n",
155+
"\n",
156+
"report = inspect_results.error_report(result_df, max_stack_trace=2, use_log=True)\n",
157+
"# display(Markdown(report))\n",
158+
"print(report)"
155159
]
156160
},
157161
{
@@ -166,7 +170,7 @@
166170
],
167171
"metadata": {
168172
"kernelspec": {
169-
"display_name": "ui-copilot",
173+
"display_name": "Python 3",
170174
"language": "python",
171175
"name": "python3"
172176
},

src/agentlab/analyze/inspect_results.py

Lines changed: 45 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -581,10 +581,12 @@ def set_wrap_style(df):
581581
# ------------
582582

583583

584-
def map_err_key(err_msg):
584+
def map_err_key(err_msg: str):
585585
if err_msg is None:
586586
return err_msg
587587

588+
# remove logs from the message if any
589+
err_msg = err_msg[: err_msg.find("=== logs ===")].rstrip()
588590
regex_replacements = [
589591
(
590592
r"your messages resulted in \d+ tokens",
@@ -601,7 +603,7 @@ def map_err_key(err_msg):
601603
return err_msg
602604

603605

604-
def error_report(df: pd.DataFrame, max_stack_trace=10):
606+
def error_report(df: pd.DataFrame, max_stack_trace=10, use_log=False):
605607
"""Report the error message for each agent."""
606608

607609
if "err_key" not in df:
@@ -611,35 +613,62 @@ def error_report(df: pd.DataFrame, max_stack_trace=10):
611613
report = []
612614
for err_key, count in unique_counts.items():
613615
report.append("-------------------")
614-
report.append(f"{count}x : {err_key}\n")
616+
report.append(f"## {count}x : " + err_key.replace("\n", "<br>") + "\n")
617+
615618
# find sub_df with this error message
616619
sub_df = df[df["err_key"] == err_key]
617620
idx = 0
618621

619622
exp_result_list = [get_exp_result(row.exp_dir) for _, row in sub_df.iterrows()]
620-
task_names = [exp_result.exp_args.env_args.task_name for exp_result in exp_result_list]
621-
622-
# count unique using numpy
623-
unique_task_names, counts = np.unique(task_names, return_counts=True)
624-
task_and_count = sorted(zip(unique_task_names, counts), key=lambda x: x[1], reverse=True)
625-
for task_name, count in task_and_count:
626-
report.append(f"{count:2d} {task_name}")
623+
exp_result_list = sorted(exp_result_list, key=lambda x: x.exp_args.env_args.task_name)
624+
for exp_result in exp_result_list:
625+
report.append(
626+
f"* {exp_result.exp_args.env_args.task_name} seed: {exp_result.exp_args.env_args.task_seed}"
627+
)
627628

628629
report.append(f"\nShowing Max {max_stack_trace} stack traces:\n")
629630
for exp_result in exp_result_list:
630631
if idx >= max_stack_trace:
631632
break
632-
# print task name and stack trace
633-
stack_trace = exp_result.summary_info.get("stack_trace", "")
634-
report.append(f"Task Name: {exp_result.exp_args.env_args.task_name}\n")
635-
report.append(f"exp_dir: {exp_result.exp_dir}\n")
636-
report.append(f"Stack Trace: \n {stack_trace}\n")
637-
report.append("\n")
633+
634+
if not use_log:
635+
# print task name and stack trace
636+
stack_trace = exp_result.summary_info.get("stack_trace", "")
637+
report.append(f"Task Name: {exp_result.exp_args.env_args.task_name}\n")
638+
report.append(f"exp_dir: {exp_result.exp_dir}\n")
639+
report.append(f"Stack Trace: \n {stack_trace}\n")
640+
report.append("\n")
641+
else:
642+
report.append(f"```bash\n{_format_log(exp_result)}\n```")
643+
638644
idx += 1
639645

640646
return "\n".join(report)
641647

642648

649+
def _format_log(exp_result: ExpResult, head_lines=10, tail_lines=50):
650+
"""Extract head and tail of the log. Try to find the traceback."""
651+
log = exp_result.logs
652+
if log is None:
653+
return "No log found"
654+
655+
log_lines = log.split("\n")
656+
if len(log_lines) <= head_lines + tail_lines:
657+
return log
658+
659+
# first 10 lines:
660+
log_head = "\n".join(log_lines[:head_lines])
661+
662+
try:
663+
traceback_idx = log.rindex("Traceback (most recent call last):")
664+
tail_idx = log.rindex("action:", 0, traceback_idx)
665+
log_tail = log[tail_idx:]
666+
except ValueError:
667+
log_tail = "\n".join(log_lines[-tail_lines:])
668+
669+
return log_head + "\n...\n...truncated middle of the log\n...\n" + log_tail
670+
671+
643672
def categorize_error(row):
644673
if pd.isna(row.get("err_msg", None)):
645674
return None

src/agentlab/experiments/exp_utils.py

Lines changed: 146 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,12 @@
44
from tqdm import tqdm
55
import logging
66
from browsergym.experiments.loop import ExpArgs
7+
from contextlib import contextmanager
8+
import signal
9+
import sys
10+
from time import time, sleep
11+
12+
logger = logging.getLogger(__name__) # Get logger based on module name
713

814

915
# TODO move this to a more appropriate place
@@ -19,8 +25,148 @@
1925
RESULTS_DIR.mkdir(parents=True, exist_ok=True)
2026

2127

28+
def run_exp(exp_arg: ExpArgs, *dependencies, avg_step_timeout=60):
29+
"""Run exp_args.run() with a timeout and handle dependencies."""
30+
episode_timeout = _episode_timeout(exp_arg, avg_step_timeout=avg_step_timeout)
31+
with timeout_manager(seconds=episode_timeout):
32+
return exp_arg.run()
33+
34+
35+
def _episode_timeout(exp_arg: ExpArgs, avg_step_timeout=60):
36+
"""Some logic to determine the episode timeout."""
37+
max_steps = getattr(exp_arg.env_args, "max_steps", None)
38+
if max_steps is None:
39+
episode_timeout_global = 10 * 60 * 60 # 10 hours
40+
else:
41+
episode_timeout_global = exp_arg.env_args.max_steps * avg_step_timeout
42+
43+
episode_timeout_exp = getattr(exp_arg, "episode_timeout", episode_timeout_global)
44+
45+
return min(episode_timeout_global, episode_timeout_exp)
46+
47+
48+
@contextmanager
49+
def timeout_manager(seconds: int = None):
50+
"""Context manager to handle timeouts."""
51+
52+
if isinstance(seconds, float):
53+
seconds = max(1, int(seconds)) # make sure seconds is at least 1
54+
55+
if seconds is None or sys.platform == "win32":
56+
try:
57+
logger.warning("Timeouts are not supported on Windows.")
58+
yield
59+
finally:
60+
pass
61+
return
62+
63+
def alarm_handler(signum, frame):
64+
65+
logger.warning(
66+
f"Operation timed out after {seconds}s, sending SIGINT and raising TimeoutError."
67+
)
68+
# send sigint
69+
os.kill(os.getpid(), signal.SIGINT)
70+
71+
# Still raise TimeoutError for immediate handling
72+
raise TimeoutError(f"Operation timed out after {seconds} seconds")
73+
74+
previous_handler = signal.signal(signal.SIGALRM, alarm_handler)
75+
signal.alarm(seconds)
76+
77+
try:
78+
yield
79+
finally:
80+
signal.alarm(0)
81+
signal.signal(signal.SIGALRM, previous_handler)
82+
83+
84+
def add_dependencies(exp_args_list: list[ExpArgs], task_dependencies: dict[str, list[str]] = None):
85+
"""Add dependencies to a list of ExpArgs.
86+
87+
Args:
88+
exp_args_list: list[ExpArgs]
89+
A list of experiments to run.
90+
task_dependencies: dict
91+
A dictionary mapping task names to a list of task names that they
92+
depend on. If None or empty, no dependencies are added.
93+
94+
Returns:
95+
list[ExpArgs]
96+
The modified exp_args_list with dependencies added.
97+
"""
98+
99+
if task_dependencies is None or all([len(dep) == 0 for dep in task_dependencies.values()]):
100+
# nothing to be done
101+
return exp_args_list
102+
103+
for exp_args in exp_args_list:
104+
exp_args.make_id() # makes sure there is an exp_id
105+
106+
exp_args_map = {exp_args.env_args.task_name: exp_args for exp_args in exp_args_list}
107+
if len(exp_args_map) != len(exp_args_list):
108+
raise ValueError(
109+
(
110+
"Task names are not unique in exp_args_map, "
111+
"you can't run multiple seeds with task dependencies."
112+
)
113+
)
114+
115+
for task_name in exp_args_map.keys():
116+
if task_name not in task_dependencies:
117+
raise ValueError(f"Task {task_name} is missing from task_dependencies")
118+
119+
# turn dependencies from task names to exp_ids
120+
for task_name, exp_args in exp_args_map.items():
121+
exp_args.depends_on = tuple(
122+
exp_args_map[dep_name].exp_id for dep_name in task_dependencies[task_name]
123+
)
124+
125+
return exp_args_list
126+
127+
128+
# Mock implementation of the ExpArgs class with timestamp checks for unit testing
129+
class MockedExpArgs:
130+
def __init__(self, exp_id, depends_on=None):
131+
self.exp_id = exp_id
132+
self.depends_on = depends_on if depends_on else []
133+
self.start_time = None
134+
self.end_time = None
135+
self.env_args = None
136+
137+
def run(self):
138+
self.start_time = time()
139+
140+
# # simulate playright code, (this was causing issues due to python async loop)
141+
# import playwright.sync_api
142+
143+
# pw = playwright.sync_api.sync_playwright().start()
144+
# pw.selectors.set_test_id_attribute("mytestid")
145+
sleep(3) # Simulate task execution time
146+
self.end_time = time()
147+
return self
148+
149+
150+
def make_seeds(n, offset=42):
151+
raise DeprecationWarning("This function will be removed. Comment out this error if needed.")
152+
return [seed + offset for seed in range(n)]
153+
154+
155+
def order(exp_args_list: list[ExpArgs]):
156+
raise DeprecationWarning("This function will be removed. Comment out this error if needed.")
157+
"""Store the order of the list of experiments to be able to sort them back.
158+
159+
This is important for progression or ablation studies.
160+
"""
161+
for i, exp_args in enumerate(exp_args_list):
162+
exp_args.order = i
163+
return exp_args_list
164+
165+
166+
# This was an old function for filtering some issue with the experiments.
22167
def hide_some_exp(base_dir, filter: callable, just_test):
23168
"""Move all experiments that match the filter to a new name."""
169+
raise DeprecationWarning("This function will be removed. Comment out this error if needed.")
24170
exp_list = list(yield_all_exp_results(base_dir, progress_fn=None))
25171

26172
msg = f"Searching {len(exp_list)} experiments to move to _* expriments where `filter(exp_args)` is True."
@@ -38,17 +184,3 @@ def hide_some_exp(base_dir, filter: callable, just_test):
38184
_move_old_exp(exp.exp_dir)
39185
filtered_out.append(exp)
40186
return filtered_out
41-
42-
43-
def make_seeds(n, offset=42):
44-
return [seed + offset for seed in range(n)]
45-
46-
47-
def order(exp_args_list: list[ExpArgs]):
48-
"""Store the order of the list of experiments to be able to sort them back.
49-
50-
This is important for progression or ablation studies.
51-
"""
52-
for i, exp_args in enumerate(exp_args_list):
53-
exp_args.order = i
54-
return exp_args_list
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
import ray
2+
3+
context = ray.init(address="auto", ignore_reinit_error=True)
4+
5+
print(context)

0 commit comments

Comments
 (0)