Skip to content

Commit e0eeaca

Browse files
committed
move loop module from the browsergym
1 parent 55f2ffa commit e0eeaca

File tree

12 files changed

+1105
-68
lines changed

12 files changed

+1105
-68
lines changed

src/agentlab/agents/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ have to specify the type of each field (You can use Any if it is unknown)*
9999
```python
100100
from dataclasses import dataclass
101101
from browsergym.experiment.agent import Agent
102-
from browsergym.experiment.loop import AgentArgs
102+
from agentlab.experiments.loop import AgentArgs
103103

104104

105105
@dataclass
@@ -116,7 +116,7 @@ class CustomAgentArgs(AgentArgs):
116116
To run experiments with your custom agent, define an instance of `ExpArgs` with the required parameters.
117117

118118
```python
119-
from browsergym.experiment.loop import ExpArgs
119+
from agentlab.experiments.loop import ExpArgs
120120

121121
exp_args = ExpArgs(
122122
agent_args=CustomAgentArgs(custom_param="value"),

src/agentlab/agents/generic_agent/reproducibility_agent.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,10 @@
2020

2121
import bgym
2222
from browsergym.experiments.agent import AgentInfo
23-
from browsergym.experiments.loop import ExpArgs, ExpResult, yield_all_exp_results
2423
from bs4 import BeautifulSoup
25-
from langchain.schema import AIMessage, BaseMessage
26-
from langchain_community.adapters.openai import convert_message_to_dict
2724

2825
from agentlab.agents.agent_args import AgentArgs
29-
from agentlab.agents.dynamic_prompting import ActionFlags
26+
from agentlab.experiments.loop import ExpArgs, ExpResult, yield_all_exp_results
3027
from agentlab.experiments.study import Study
3128
from agentlab.llm.chat_api import make_assistant_message
3229
from agentlab.llm.llm_utils import Discussion, messages_to_dict
@@ -65,7 +62,6 @@ def get_stats(self):
6562

6663
@dataclass
6764
class ReproAgentArgs(GenericAgentArgs):
68-
6965
# starting with "_" will prevent from being part of the index in the load_results function
7066
_repro_dir: str = None
7167

@@ -81,7 +77,6 @@ def make_agent(self):
8177

8278

8379
class ReproAgent(GenericAgent):
84-
8580
def __init__(
8681
self,
8782
chat_model_args,
@@ -93,7 +88,6 @@ def __init__(
9388
super().__init__(chat_model_args, flags, max_retry)
9489

9590
def get_action(self, obs):
96-
9791
# replace the chat model with a reproducible chat that will mimic the
9892
# same answers
9993
step = len(self.actions)
@@ -218,7 +212,10 @@ def make_repro_agent(agent_args: AgentArgs, exp_dir: Path | str):
218212

219213
def _make_diff(old_str, new_str):
220214
page = difflib.HtmlDiff().make_file(
221-
old_str.splitlines(), new_str.splitlines(), fromdesc="Old Version", todesc="New Version"
215+
old_str.splitlines(),
216+
new_str.splitlines(),
217+
fromdesc="Old Version",
218+
todesc="New Version",
222219
)
223220
page = page.replace('nowrap="nowrap"', "") # Remove nowrap attribute
224221
page = _set_style(page, DIFF_STYLE)

src/agentlab/analyze/agent_xray.py

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,13 @@
1212
import numpy as np
1313
import pandas as pd
1414
from attr import dataclass
15-
from browsergym.experiments.loop import ExpResult, StepInfo
1615
from langchain.schema import BaseMessage, HumanMessage
1716
from openai import OpenAI
1817
from PIL import Image
1918

2019
from agentlab.analyze import inspect_results
2120
from agentlab.experiments.exp_utils import RESULTS_DIR
21+
from agentlab.experiments.loop import ExpResult, StepInfo
2222
from agentlab.experiments.study import get_most_recent_study
2323
from agentlab.llm.chat_api import make_system_message, make_user_message
2424
from agentlab.llm.llm_utils import BaseMessage as AgentLabBaseMessage
@@ -201,7 +201,6 @@ def run_gradio(results_dir: Path):
201201
"""
202202
)
203203
with gr.Row():
204-
205204
exp_dir_choice = gr.Dropdown(
206205
choices=get_directory_contents(results_dir),
207206
value=select_dir_instructions,
@@ -297,7 +296,10 @@ def run_gradio(results_dir: Path):
297296
state_error = gr.Markdown(label="Next Step Error", elem_classes="my-markdown")
298297

299298
profiling_gr = gr.Image(
300-
label="Profiling", show_label=False, interactive=False, show_download_button=False
299+
label="Profiling",
300+
show_label=False,
301+
interactive=False,
302+
show_download_button=False,
301303
)
302304

303305
gr.HTML(
@@ -418,7 +420,14 @@ def run_gradio(results_dir: Path):
418420
exp_dir_choice.change(
419421
fn=new_exp_dir,
420422
inputs=exp_dir_choice,
421-
outputs=[agent_table, agent_id, constants, variables, global_stats, error_report],
423+
outputs=[
424+
agent_table,
425+
agent_id,
426+
constants,
427+
variables,
428+
global_stats,
429+
error_report,
430+
],
422431
)
423432

424433
agent_table.select(fn=on_select_agent, inputs=agent_table, outputs=[agent_id])
@@ -454,7 +463,8 @@ def run_gradio(results_dir: Path):
454463
screenshot_gallery.select(fn=gallery_step_change, inputs=episode_id, outputs=step_id)
455464
step_id.change(fn=if_active("DOM HTML")(update_html), outputs=html_code)
456465
step_id.change(
457-
fn=if_active("Pruned DOM HTML")(update_pruned_html), outputs=pruned_html_code
466+
fn=if_active("Pruned DOM HTML")(update_pruned_html),
467+
outputs=pruned_html_code,
458468
)
459469
step_id.change(fn=if_active("AXTree")(update_axtree), outputs=axtree_code)
460470
step_id.change(fn=if_active("Chat Messages")(update_chat_messages), outputs=chat_messages)
@@ -475,10 +485,14 @@ def run_gradio(results_dir: Path):
475485
# we need to update them individually when the tab is selected
476486
tab_screenshot.select(fn=update_screenshot, inputs=som_or_not, outputs=screenshot)
477487
tab_screenshot_pair.select(
478-
fn=update_screenshot_pair, inputs=som_or_not, outputs=[screenshot1, screenshot2]
488+
fn=update_screenshot_pair,
489+
inputs=som_or_not,
490+
outputs=[screenshot1, screenshot2],
479491
)
480492
tab_screenshot_gallery.select(
481-
fn=update_screenshot_gallery, inputs=som_or_not, outputs=[screenshot_gallery]
493+
fn=update_screenshot_gallery,
494+
inputs=som_or_not,
495+
outputs=[screenshot_gallery],
482496
)
483497
tab_html.select(fn=update_html, outputs=html_code)
484498
tab_pruned_html.select(fn=update_pruned_html, outputs=pruned_html_code)
@@ -617,7 +631,7 @@ def update_logs():
617631
try:
618632
return f"""{info.exp_result.logs}"""
619633
except FileNotFoundError:
620-
return f"""No Logs"""
634+
return """No Logs"""
621635

622636

623637
def update_stats():
@@ -757,11 +771,11 @@ def get_episode_info(info: Info):
757771

758772
info = f"""\
759773
### {env_args.task_name} (seed: {env_args.task_seed})
760-
### Step {info.step} / {len(steps_info)-1} (Reward: {cum_reward:.1f})
774+
### Step {info.step} / {len(steps_info) - 1} (Reward: {cum_reward:.1f})
761775
762776
**Goal:**
763777
764-
{code(str(AgentLabBaseMessage('', goal)))}
778+
{code(str(AgentLabBaseMessage("", goal)))}
765779
766780
**Task info:**
767781
@@ -770,7 +784,7 @@ def get_episode_info(info: Info):
770784
**exp_dir:**
771785
772786
<small style="line-height: 1; margin: 0; padding: 0;">{code(exp_dir_str)}</small>"""
773-
except Exception as e:
787+
except Exception:
774788
info = f"""\
775789
**Error while getting episode info**
776790
{code(traceback.format_exc())}"""
@@ -942,7 +956,6 @@ def update_error_report():
942956

943957

944958
def new_exp_dir(exp_dir, progress=gr.Progress(), just_refresh=False):
945-
946959
if exp_dir == select_dir_instructions:
947960
return None, None
948961

@@ -1075,7 +1088,6 @@ def add_patch(ax, start, stop, color, label, edge=False):
10751088

10761089

10771090
def plot_profiling(ax, step_info_list: list[StepInfo], summary_info: dict, progress_fn):
1078-
10791091
if len(step_info_list) == 0:
10801092
warning("No step info to plot")
10811093
return None
@@ -1123,7 +1135,13 @@ def plot_profiling(ax, step_info_list: list[StepInfo], summary_info: dict, progr
11231135

11241136
if step_info.action is not None:
11251137
# Blue rectangle for agent_start to agent_stop
1126-
add_patch(ax, prof.agent_start, prof.agent_stop, colors[10], labels.pop("agent", None))
1138+
add_patch(
1139+
ax,
1140+
prof.agent_start,
1141+
prof.agent_stop,
1142+
colors[10],
1143+
labels.pop("agent", None),
1144+
)
11271145

11281146
# Black vertical bar at agent stop
11291147
ax.axvline(prof.agent_stop, color="black", linewidth=3)

src/agentlab/analyze/inspect_results.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,11 @@
1010

1111
import numpy as np
1212
import pandas as pd
13-
from browsergym.experiments.loop import ExpResult, get_exp_result, yield_all_exp_results
1413
from IPython.display import display
1514
from tqdm import tqdm
1615

16+
from agentlab.experiments.loop import ExpResult, get_exp_result, yield_all_exp_results
17+
1718
# TODO find a more portable way to code set_task_category_as_index at least
1819
# handle dynamic imports. We don't want to always import workarena
1920
# from browsergym.workarena import TASK_CATEGORY_MAP
@@ -83,7 +84,7 @@ def set_index_from_variables(
8384
white = any([fnmatch.fnmatch(var, pattern) for pattern in index_white_list])
8485
black = any([fnmatch.fnmatch(var, pattern) for pattern in index_black_list])
8586

86-
if white and (not black) and (not var in index_variables):
87+
if white and (not black) and (var not in index_variables):
8788
index_variables.append(var)
8889

8990
for var in index_variables:
@@ -205,7 +206,7 @@ def report_constant_and_variables(df, show_stack_traces=True):
205206
if i >= 2:
206207
break
207208
if len(unique_counts) > 3:
208-
print(f" ...\n")
209+
print(" ...\n")
209210

210211

211212
def get_std_err(df, metric):
@@ -235,7 +236,7 @@ def get_sample_std_err(df, metric):
235236

236237

237238
def summarize(sub_df):
238-
if not "cum_reward" in sub_df:
239+
if "cum_reward" not in sub_df:
239240
record = dict(
240241
avg_reward=np.nan,
241242
std_err=np.nan,
@@ -745,7 +746,7 @@ def summarize_study(result_df: pd.DataFrame) -> pd.DataFrame:
745746
def split_by_key(df: pd.DataFrame, key):
746747
"""Return a dict of dataframes spearted by the given key."""
747748
# check if key in df
748-
if not (key in df.columns):
749+
if key not in df.columns:
749750
df = df.reset_index(key, inplace=False)
750751

751752
df_dict = {}
@@ -775,7 +776,7 @@ def get_all_summaries(results_dir: Path, skip_hidden=True, ignore_cache=False, i
775776
summary.set_index("study_dir", inplace=True)
776777
summaries.append(summary)
777778

778-
except Exception as e:
779+
except Exception:
779780
traceback.print_exc()
780781
continue
781782

src/agentlab/experiments/exp_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@
66
from pathlib import Path
77
from time import sleep, time
88

9-
from browsergym.experiments.loop import ExpArgs, yield_all_exp_results
109
from tqdm import tqdm
1110

11+
from agentlab.experiments.loop import ExpArgs, yield_all_exp_results
12+
1213
logger = logging.getLogger(__name__) # Get logger based on module name
1314

1415

@@ -63,7 +64,6 @@ def timeout_manager(seconds: int = None):
6364
return
6465

6566
def alarm_handler(signum, frame):
66-
6767
logger.warning(f"Operation timed out after {seconds}s, raising TimeoutError.")
6868
# send sigint
6969
# os.kill(os.getpid(), signal.SIGINT) # this doesn't seem to do much I don't know why
@@ -176,11 +176,11 @@ def hide_some_exp(base_dir, filter: callable, just_test):
176176

177177
msg = f"Searching {len(exp_list)} experiments to move to _* expriments where `filter(exp_args)` is True."
178178
if just_test:
179-
msg += f"\nNote: This is a just a test, no experiments will be moved. Set `just_test=False` to move them."
179+
msg += "\nNote: This is a just a test, no experiments will be moved. Set `just_test=False` to move them."
180180

181181
logging.info(msg)
182182

183-
exp_list = tqdm(exp_list, desc=f"Filtering experiments.")
183+
exp_list = tqdm(exp_list, desc="Filtering experiments.")
184184

185185
filtered_out = []
186186
for exp in exp_list:

src/agentlab/experiments/launch_exp.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
from pathlib import Path
44

55
import bgym
6-
from browsergym.experiments.loop import ExpArgs, yield_all_exp_results
76

87
from agentlab.experiments.exp_utils import run_exp
8+
from agentlab.experiments.loop import ExpArgs, yield_all_exp_results
99

1010

1111
def run_experiments(
@@ -142,8 +142,8 @@ def find_incomplete(study_dir: str | Path, include_errors=True):
142142
else:
143143
logging.info(f"Found {job_count} incomplete experiments in {study_dir}.")
144144

145-
message = f"Make sure the processes that were running are all stopped. Otherwise, "
146-
f"there will be concurrent writing in the same directories.\n"
145+
message = "Make sure the processes that were running are all stopped. Otherwise, "
146+
"there will be concurrent writing in the same directories.\n"
147147

148148
logging.info(message)
149149

@@ -193,7 +193,9 @@ def _hide_completed(exp_result: bgym.ExpResult, include_errors: bool = True):
193193

194194

195195
# TODO remove this function once ray backend is stable
196-
def _split_sequential_exp(exp_args_list: list[ExpArgs]) -> tuple[list[ExpArgs], list[ExpArgs]]:
196+
def _split_sequential_exp(
197+
exp_args_list: list[ExpArgs],
198+
) -> tuple[list[ExpArgs], list[ExpArgs]]:
197199
"""split exp_args that are flagged as sequential from those that are not"""
198200
sequential_exp_args = []
199201
parallel_exp_args = []

0 commit comments

Comments
 (0)