Skip to content

Commit 7e629bd

Browse files
committed
universal tape agent that can load any agent from config
1 parent 0791f2d commit 7e629bd

File tree

2 files changed

+27
-131
lines changed

2 files changed

+27
-131
lines changed
Lines changed: 26 additions & 129 deletions
Original file line numberDiff line numberDiff line change
@@ -1,152 +1,49 @@
11
import logging
22
from dataclasses import dataclass
3-
from pathlib import Path
43
from typing import Any
54

65
import bgym
6+
import hydra
7+
from tapeagents.agent import Agent
8+
from tapeagents.core import Action, Observation, Tape, Thought
79

810
from agentlab.agents.agent_args import AgentArgs
9-
from agentlab.llm.chat_api import BaseModelArgs
10-
from agentlab.llm.tracking import cost_tracker_decorator
11-
12-
##############################
13-
# TODO: replace this hacky imports after releasing tapeagents and tapeagents[examples] to pypi
14-
try:
15-
from tapeagents.llms import LiteLLM
16-
from tapeagents.tools.gym_browser import flatten_axtree
17-
except ImportError as e:
18-
print("Please run install_tapeagents.sh to install tapeagents first.")
19-
raise e
20-
21-
import sys
22-
23-
sys.path.append(str(Path(__file__).parent.resolve() / "TapeAgents"))
24-
##############################
25-
26-
from examples.workarena.agent import WorkArenaAgent
27-
from examples.workarena.steps import (
28-
WorkArenaAction,
29-
ClickAction,
30-
GoBackAction,
31-
GoForwardAction,
32-
GotoPageAction,
33-
HoverAction,
34-
InputTextAction,
35-
PageObservation,
36-
PressAction,
37-
SelectOptionAction,
38-
ScrollAction,
39-
WorkArenaTape,
40-
WorkArenaTask,
41-
StopStep,
42-
)
43-
4411

4512
logger = logging.getLogger(__name__)
4613
logger.setLevel(logging.INFO)
4714

4815

4916
@dataclass
5017
class TapeAgentArgs(AgentArgs):
51-
agent_name: str = "WorkarenaTapeAgent"
52-
chat_model_args: BaseModelArgs = None
18+
config_name: str
5319

5420
def make_agent(self) -> bgym.Agent:
55-
llm = LiteLLM(
56-
model_name=self.chat_model_args.model_name,
57-
use_cache=False,
58-
context_size=self.chat_model_args.max_total_tokens,
59-
parameters={"temperature": self.chat_model_args.temperature},
60-
)
61-
return WorkarenaTapeAgent(llm)
62-
63-
def set_reproducibility_mode(self):
64-
self.chat_model_args.temperature = 0
21+
with hydra.initialize(config_path="./conf"):
22+
config = hydra.compose(config_name=self.config_name)
23+
agent: Agent = hydra.utils.instantiate(config)
24+
return TapeAgent(agent=agent, tape=Tape(steps=[]))
6525

66-
def prepare(self):
67-
return self.chat_model_args.prepare_server()
6826

69-
def close(self):
70-
return self.chat_model_args.close_server()
27+
class TapeAgent(bgym.Agent):
28+
agent: Agent
29+
tape: Tape
7130

72-
73-
class WorkarenaTapeAgent(bgym.Agent):
74-
tape: WorkArenaTape
75-
76-
def __init__(self, llm: LiteLLM):
77-
self.tapeagent = WorkArenaAgent.create(llm)
78-
self.tape = WorkArenaTape()
79-
80-
def obs_preprocessor(self, obs: dict) -> dict:
81-
axtree = obs.pop("axtree_object")
82-
obs["axtree_txt"] = flatten_axtree(axtree)
31+
def obs_preprocessor(self, obs: dict) -> Any:
32+
logger.info(f"Preprocessing observation: {obs}")
8333
return obs
8434

85-
@cost_tracker_decorator
86-
def get_action(self, obs: Any) -> tuple[str, bgym.AgentInfo]:
87-
self.update_tape(obs)
88-
# run agent and collect thoughts and last action
89-
tape_segment = []
90-
action = None
91-
logger.info(f"Run tape with {len(self.tape)} steps")
92-
for event in self.tapeagent.run(self.tape):
35+
def get_action(self, obs: Observation) -> tuple[str, bgym.AgentInfo]:
36+
self.tape = self.tape.append(obs)
37+
thoughts = []
38+
for event in self.agent.run(self.tape):
9339
if not event.step:
9440
continue
95-
step = event.step
96-
tape_segment.append(step)
97-
logger.info(f"Generated step: {step.llm_view()}")
98-
if isinstance(step, WorkArenaAction):
99-
action = self.step_to_action(step)
100-
self.tape += tape_segment
101-
102-
logger.info(f"Action string: {action}")
103-
return (
104-
action,
105-
bgym.AgentInfo(
106-
extra_info={"tape_segment": [step.model_dump() for step in tape_segment]},
107-
stats={},
108-
),
109-
)
110-
111-
def update_tape(self, obs: dict):
112-
"""
113-
Update tape with new observation
114-
"""
115-
obs_step = PageObservation(text=obs["axtree_txt"], current_page=1, total_pages=1)
116-
self.tape = self.tape.append(obs_step)
117-
if len(self.tape) == 1: # first observation
118-
logger.info("First observation, adding goal to tape")
119-
self.tape = self.tape.append(WorkArenaTask(task=obs["goal"]))
120-
121-
def step_to_action(self, action: WorkArenaAction) -> str | None:
122-
"""
123-
Convert action step to an action string with function call
124-
"""
125-
action_str = ""
126-
if isinstance(action, GotoPageAction):
127-
action_str = f"goto('{action.url}')"
128-
elif isinstance(action, ClickAction):
129-
action_str = (
130-
f"click('{action.bid}', button='{action.button}', modifiers={action.modifiers})"
131-
)
132-
elif isinstance(action, SelectOptionAction):
133-
action_str = f"select_option('{action.bid}', '{action.option}')"
134-
elif isinstance(action, HoverAction):
135-
action_str = f"hover('{action.bid}')"
136-
elif isinstance(action, InputTextAction):
137-
text = action.text.replace("'", "\\'")
138-
action_str = f"fill('{action.bid}', '{text}')"
139-
elif isinstance(action, PressAction):
140-
f"press('{action.bid}', '{action.key_comb}')"
141-
elif isinstance(action, GoBackAction):
142-
action_str = "go_back()"
143-
elif isinstance(action, GoForwardAction):
144-
action_str = "go_forward()"
145-
elif isinstance(action, StopStep):
146-
logger.info("Stopping the loop")
147-
action_str = None
148-
elif isinstance(action, ScrollAction):
149-
action_str = "noop()" # TODO: implement scroll action
150-
else:
151-
raise ValueError(f"Unknown action type: {action}")
152-
return action_str
41+
self.tape = self.tape.append(event.step)
42+
if isinstance(event.step, Thought):
43+
thoughts.append(event.step.llm_view())
44+
logger.info(f"Thought: {event.step.llm_view()}")
45+
elif isinstance(event.step, Action):
46+
action = event.step.llm_view()
47+
logger.info(f"Action: {action}")
48+
break # we stop at the first action
49+
return (action, bgym.AgentInfo(think="\n".join(thoughts), stats={}))

src/agentlab/benchmarks/gaia.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from tapeagents.environment import ContainerExecutor
66
from tapeagents.tools.browser import Browser
77
from tapeagents.tools.code_executor import CodeExecutor
8-
from tapeagents.tools.container_executor import init_code_sandbox
98
from tapeagents.tools.media_reader import VideoReader
109
from tapeagents.tools.web_search import WebSearch
1110

@@ -25,7 +24,7 @@ class GaiaGymArgs(AbstractEnvArgs):
2524
viewport_chars: int = 64000
2625

2726
def make_env(self) -> GaiaGym:
28-
init_code_sandbox(self.exp_dir)
27+
self.init_code_sandbox()
2928
dataset = datasets.load_dataset("gaia-benchmark/GAIA", "2023_all")
3029
tasks_by_id = {task["task_id"]: task for task in dataset[self.split]}
3130
task = tasks_by_id[self.task_id]

0 commit comments

Comments
 (0)