Skip to content

Commit 16e7526

Browse files
recursixgasse
andauthored
Clean pipeline (#117)
* yet another way to kill timedout jobs * Improve timeout handling in task polling logic * Add method to override max_steps in Study class * add support for tab visibility in observation flags and update related components * fix tests * black * Improve timeout handling in task polling logic * yet another way to kill timedout jobs (#108) * Add method to override max_steps in Study class * add support for tab visibility in observation flags and update related components * fix tests * black * black * Fix sorting bug. improve directory content retrieval with summary statistics * fix test * black * tmp * add error report, add cum cost to summary and ray backend by default * black * fix test (chaing to joblib backend) * black --------- Co-authored-by: Maxime Gasse <[email protected]>
1 parent e695e11 commit 16e7526

File tree

6 files changed

+47
-22
lines changed

6 files changed

+47
-22
lines changed

src/agentlab/agents/dynamic_prompting.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,7 @@
1010

1111
import bgym
1212
from browsergym.core.action.base import AbstractActionSet
13-
from browsergym.utils.obs import (
14-
flatten_axtree_to_str,
15-
flatten_dom_to_str,
16-
overlay_som,
17-
prune_html,
18-
)
13+
from browsergym.utils.obs import flatten_axtree_to_str, flatten_dom_to_str, overlay_som, prune_html
1914

2015
from agentlab.llm.llm_utils import (
2116
BaseMessage,

src/agentlab/analyze/agent_xray.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,10 @@ def filter_agent_id(self, agent_id: list[tuple]):
142142
max-height: 400px;
143143
overflow-y: auto;
144144
}
145+
.error-report {
146+
max-height: 700px;
147+
overflow-y: auto;
148+
}
145149
.my-code-view {
146150
max-height: 300px;
147151
overflow-y: auto;
@@ -284,6 +288,8 @@ def run_gradio(results_dir: Path):
284288
with gr.Tab("Global Stats"):
285289
global_stats = gr.DataFrame(max_height=500, show_label=False, interactive=False)
286290

291+
with gr.Tab("Error Report"):
292+
error_report = gr.Markdown(elem_classes="error-report", show_copy_button=True)
287293
with gr.Row():
288294
episode_info = gr.Markdown(label="Episode Info", elem_classes="my-markdown")
289295
action_info = gr.Markdown(label="Action Info", elem_classes="my-markdown")
@@ -411,7 +417,7 @@ def run_gradio(results_dir: Path):
411417
exp_dir_choice.change(
412418
fn=new_exp_dir,
413419
inputs=exp_dir_choice,
414-
outputs=[agent_table, agent_id, constants, variables, global_stats],
420+
outputs=[agent_table, agent_id, constants, variables, global_stats, error_report],
415421
)
416422

417423
agent_table.select(fn=on_select_agent, inputs=agent_table, outputs=[agent_id])
@@ -918,19 +924,25 @@ def get_agent_report(result_df: pd.DataFrame):
918924

919925

920926
def update_global_stats():
921-
global info
922927
stats = inspect_results.global_report(info.result_df, reduce_fn=inspect_results.summarize_stats)
923928
stats.reset_index(inplace=True)
924929
return stats
925930

926931

932+
def update_error_report():
933+
report_files = list(info.exp_list_dir.glob("error_report*.md"))
934+
if len(report_files) == 0:
935+
return "No error report found"
936+
report_files = sorted(report_files, key=os.path.getctime, reverse=True)
937+
return report_files[0].read_text()
938+
939+
927940
def new_exp_dir(exp_dir, progress=gr.Progress(), just_refresh=False):
928941

929942
if exp_dir == select_dir_instructions:
930943
return None, None
931944

932945
exp_dir = exp_dir.split(" - ")[0]
933-
global info
934946

935947
if len(exp_dir) == 0:
936948
info.exp_list_dir = None
@@ -951,7 +963,14 @@ def new_exp_dir(exp_dir, progress=gr.Progress(), just_refresh=False):
951963
agent_id = info.get_agent_id(agent_report.iloc[0])
952964

953965
constants, variables = format_constant_and_variables()
954-
return agent_report, agent_id, constants, variables, update_global_stats()
966+
return (
967+
agent_report,
968+
agent_id,
969+
constants,
970+
variables,
971+
update_global_stats(),
972+
update_error_report(),
973+
)
955974

956975

957976
def new_agent_id(agent_id: list[tuple]):

src/agentlab/analyze/inspect_results.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,7 @@ def summarize(sub_df, use_bootstrap=False):
297297
n_err=err.sum(skipna=True),
298298
)
299299
if "stats.cum_cost" in sub_df:
300-
record["cum_cost"] = (sub_df["stats.cum_cost"].sum(skipna=True).round(4),)
300+
record["cum_cost"] = sub_df["stats.cum_cost"].sum(skipna=True).round(4)
301301

302302
return pd.Series(record)
303303

src/agentlab/experiments/launch_exp.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,9 @@ def run_experiments(
4040
study_dir = Path(study_dir)
4141
study_dir.mkdir(parents=True, exist_ok=True)
4242

43-
if n_jobs == 1 and parallel_backend != "sequential":
44-
logging.warning("Only 1 job, switching to sequential backend.")
45-
parallel_backend = "sequential"
43+
# if n_jobs == 1 and parallel_backend != "sequential":
44+
# logging.warning("Only 1 job, switching to sequential backend.")
45+
# parallel_backend = "sequential"
4646

4747
logging.info(f"Saving experiments to {study_dir}")
4848
for exp_args in exp_args_list:

src/agentlab/experiments/study.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def set_reproducibility_info(self, strict_reproducibility=False, comment=None):
123123
def run(
124124
self,
125125
n_jobs=1,
126-
parallel_backend="joblib",
126+
parallel_backend="ray",
127127
strict_reproducibility=False,
128128
n_relaunch=3,
129129
relaunch_errors=True,

tests/agents/test_agent.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@ def test_generic_agent():
2525

2626
with tempfile.TemporaryDirectory() as tmp_dir:
2727

28-
launch_exp.run_experiments(1, [exp_args], Path(tmp_dir) / "generic_agent_test")
28+
launch_exp.run_experiments(
29+
1, [exp_args], Path(tmp_dir) / "generic_agent_test", parallel_backend="joblib"
30+
)
2931

3032
result_record = inspect_results.load_result_df(tmp_dir, progress_fn=None)
3133

@@ -144,9 +146,12 @@ def test_generic_agent_parse_retry():
144146
)
145147

146148
with tempfile.TemporaryDirectory() as tmp_dir:
147-
launch_exp.run_experiments(1, [exp_args], Path(tmp_dir) / "generic_agent_test")
149+
# TODO why these tests don't work with ray backend?
150+
launch_exp.run_experiments(
151+
1, [exp_args], Path(tmp_dir) / "generic_agent_test", parallel_backend="joblib"
152+
)
148153
result_record = inspect_results.load_result_df(tmp_dir, progress_fn=None)
149-
154+
print(result_record)
150155
target = {
151156
"stats.cum_n_retry": 2,
152157
"stats.cum_busted_retry": 0,
@@ -169,7 +174,9 @@ def test_bust_parse_retry():
169174
)
170175

171176
with tempfile.TemporaryDirectory() as tmp_dir:
172-
launch_exp.run_experiments(1, [exp_args], Path(tmp_dir) / "generic_agent_test")
177+
launch_exp.run_experiments(
178+
1, [exp_args], Path(tmp_dir) / "generic_agent_test", parallel_backend="joblib"
179+
)
173180
result_record = inspect_results.load_result_df(tmp_dir, progress_fn=None)
174181

175182
target = {
@@ -195,7 +202,9 @@ def test_llm_error_success():
195202
)
196203

197204
with tempfile.TemporaryDirectory() as tmp_dir:
198-
launch_exp.run_experiments(1, [exp_args], Path(tmp_dir) / "generic_agent_test")
205+
launch_exp.run_experiments(
206+
1, [exp_args], Path(tmp_dir) / "generic_agent_test", parallel_backend="joblib"
207+
)
199208
result_record = inspect_results.load_result_df(tmp_dir, progress_fn=None)
200209

201210
target = {
@@ -220,7 +229,9 @@ def test_llm_error_no_success():
220229
)
221230

222231
with tempfile.TemporaryDirectory() as tmp_dir:
223-
launch_exp.run_experiments(1, [exp_args], Path(tmp_dir) / "generic_agent_test")
232+
launch_exp.run_experiments(
233+
1, [exp_args], Path(tmp_dir) / "generic_agent_test", parallel_backend="joblib"
234+
)
224235
result_record = inspect_results.load_result_df(tmp_dir, progress_fn=None)
225236

226237
target = {
@@ -236,4 +247,4 @@ def test_llm_error_no_success():
236247

237248
if __name__ == "__main__":
238249
# test_generic_agent()
239-
test_llm_error_success()
250+
test_generic_agent_parse_retry()

0 commit comments

Comments
 (0)