Skip to content

Commit 5d72e8b

Browse files
authored
Tapeagent for Workarena benchmark (#113)
* guided tapeagent example, initial commit * remove unused prompts * reformat * use tape agent directly from the tapeagents repo examples * working version of the workarena tape agent * remove empty ignored roles * stopping the loop by issuing none action * fixes * use flatten from tapeagents * fix * full run * fix folder check in installation script * fix comment
1 parent 096cb59 commit 5d72e8b

File tree

4 files changed

+184
-0
lines changed

4 files changed

+184
-0
lines changed
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
TapeAgents/
2+
tapedata.sqlite
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
#!/bin/bash
2+
3+
if [ ! -d "$(dirname "$0")/TapeAgents" ]; then
4+
# Clone the repository to this directory
5+
git clone https://github.com/ServiceNow/TapeAgents.git "$(dirname "$0")/TapeAgents"
6+
# Install the package in editable mode
7+
pip install -e "$(dirname "$0")/TapeAgents"
8+
else
9+
echo "TapeAgents directory already exists. Skipping installation."
10+
fi
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from agentlab.agents.tapeagent.tapeagent import TapeAgentArgs
2+
from agentlab.experiments import study_generators
3+
from agentlab.llm.llm_configs import CHAT_MODEL_ARGS_DICT
4+
5+
6+
def main(benchmark: str, n_jobs: int, reproducibility: bool):
7+
agent_args = TapeAgentArgs(
8+
chat_model_args=CHAT_MODEL_ARGS_DICT["openai/gpt-4o-mini-2024-07-18"]
9+
)
10+
if reproducibility:
11+
agent_args.set_reproducibility_mode()
12+
study = study_generators.run_agents_on_benchmark(agent_args, benchmark)
13+
study.run(n_jobs=n_jobs, parallel_backend="joblib", strict_reproducibility=reproducibility)
14+
study.append_to_journal(strict_reproducibility=reproducibility)
15+
16+
17+
if __name__ == "__main__": # necessary for dask backend
18+
n_jobs = 8 # 1 when debugging in VSCode, -1 to use all available cores
19+
benchmark = "workarena.l1"
20+
main(benchmark, n_jobs, reproducibility=True)
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
import logging
2+
from dataclasses import dataclass
3+
from pathlib import Path
4+
from typing import Any
5+
6+
import bgym
7+
8+
from agentlab.agents.agent_args import AgentArgs
9+
from agentlab.llm.chat_api import BaseModelArgs
10+
from agentlab.llm.tracking import cost_tracker_decorator
11+
12+
##############################
13+
# TODO: replace this hacky imports after releasing tapeagents and tapeagents[examples] to pypi
14+
try:
15+
from tapeagents.llms import LiteLLM
16+
from tapeagents.tools.gym_browser import flatten_axtree
17+
except ImportError as e:
18+
print("Please run install_tapeagents.sh to install tapeagents first.")
19+
raise e
20+
21+
import sys
22+
23+
sys.path.append(str(Path(__file__).parent.resolve() / "TapeAgents"))
24+
##############################
25+
26+
from examples.workarena.agent import WorkArenaAgent
27+
from examples.workarena.steps import (
28+
WorkArenaAction,
29+
ClickAction,
30+
GoBackAction,
31+
GoForwardAction,
32+
GotoPageAction,
33+
HoverAction,
34+
InputTextAction,
35+
PageObservation,
36+
PressAction,
37+
SelectOptionAction,
38+
ScrollAction,
39+
WorkArenaTape,
40+
WorkArenaTask,
41+
StopStep,
42+
)
43+
44+
45+
logger = logging.getLogger(__name__)
46+
logger.setLevel(logging.INFO)
47+
48+
49+
@dataclass
50+
class TapeAgentArgs(AgentArgs):
51+
agent_name: str = "WorkarenaTapeAgent"
52+
chat_model_args: BaseModelArgs = None
53+
54+
def make_agent(self) -> bgym.Agent:
55+
llm = LiteLLM(
56+
model_name=self.chat_model_args.model_name,
57+
use_cache=False,
58+
context_size=self.chat_model_args.max_total_tokens,
59+
parameters={"temperature": self.chat_model_args.temperature},
60+
)
61+
return WorkarenaTapeAgent(llm)
62+
63+
def set_reproducibility_mode(self):
64+
self.chat_model_args.temperature = 0
65+
66+
def prepare(self):
67+
return self.chat_model_args.prepare_server()
68+
69+
def close(self):
70+
return self.chat_model_args.close_server()
71+
72+
73+
class WorkarenaTapeAgent(bgym.Agent):
74+
tape: WorkArenaTape
75+
76+
def __init__(self, llm: LiteLLM):
77+
self.tapeagent = WorkArenaAgent.create(llm)
78+
self.tape = WorkArenaTape()
79+
80+
def obs_preprocessor(self, obs: dict) -> dict:
81+
axtree = obs.pop("axtree_object")
82+
obs["axtree_txt"] = flatten_axtree(axtree)
83+
return obs
84+
85+
@cost_tracker_decorator
86+
def get_action(self, obs: Any) -> tuple[str, bgym.AgentInfo]:
87+
self.update_tape(obs)
88+
# run agent and collect thoughts and last action
89+
tape_segment = []
90+
action = None
91+
logger.info(f"Run tape with {len(self.tape)} steps")
92+
for event in self.tapeagent.run(self.tape):
93+
if not event.step:
94+
continue
95+
step = event.step
96+
tape_segment.append(step)
97+
logger.info(f"Generated step: {step.llm_view()}")
98+
if isinstance(step, WorkArenaAction):
99+
action = self.step_to_action(step)
100+
self.tape += tape_segment
101+
102+
logger.info(f"Action string: {action}")
103+
return (
104+
action,
105+
bgym.AgentInfo(
106+
extra_info={"tape_segment": [step.model_dump() for step in tape_segment]},
107+
stats={},
108+
),
109+
)
110+
111+
def update_tape(self, obs: dict):
112+
"""
113+
Update tape with new observation
114+
"""
115+
obs_step = PageObservation(text=obs["axtree_txt"], current_page=1, total_pages=1)
116+
self.tape = self.tape.append(obs_step)
117+
if len(self.tape) == 1: # first observation
118+
logger.info("First observation, adding goal to tape")
119+
self.tape = self.tape.append(WorkArenaTask(task=obs["goal"]))
120+
121+
def step_to_action(self, action: WorkArenaAction) -> str | None:
122+
"""
123+
Convert action step to an action string with function call
124+
"""
125+
action_str = ""
126+
if isinstance(action, GotoPageAction):
127+
action_str = f"goto('{action.url}')"
128+
elif isinstance(action, ClickAction):
129+
action_str = (
130+
f"click('{action.bid}', button='{action.button}', modifiers={action.modifiers})"
131+
)
132+
elif isinstance(action, SelectOptionAction):
133+
action_str = f"select_option('{action.bid}', '{action.option}')"
134+
elif isinstance(action, HoverAction):
135+
action_str = f"hover('{action.bid}')"
136+
elif isinstance(action, InputTextAction):
137+
text = action.text.replace("'", "\\'")
138+
action_str = f"fill('{action.bid}', '{text}')"
139+
elif isinstance(action, PressAction):
140+
f"press('{action.bid}', '{action.key_comb}')"
141+
elif isinstance(action, GoBackAction):
142+
action_str = "go_back()"
143+
elif isinstance(action, GoForwardAction):
144+
action_str = "go_forward()"
145+
elif isinstance(action, StopStep):
146+
logger.info("Stopping the loop")
147+
action_str = None
148+
elif isinstance(action, ScrollAction):
149+
action_str = "noop()" # TODO: implement scroll action
150+
else:
151+
raise ValueError(f"Unknown action type: {action}")
152+
return action_str

0 commit comments

Comments
 (0)