|
1 | 1 | import logging |
2 | 2 | from dataclasses import dataclass |
3 | | -from pathlib import Path |
4 | 3 | from typing import Any |
5 | 4 |
|
6 | 5 | import bgym |
| 6 | +import hydra |
| 7 | +from tapeagents.agent import Agent |
| 8 | +from tapeagents.core import Action, Observation, Tape, Thought |
7 | 9 |
|
8 | 10 | from agentlab.agents.agent_args import AgentArgs |
9 | | -from agentlab.llm.chat_api import BaseModelArgs |
10 | | -from agentlab.llm.tracking import cost_tracker_decorator |
11 | | - |
12 | | -############################## |
13 | | -# TODO: replace this hacky imports after releasing tapeagents and tapeagents[examples] to pypi |
14 | | -try: |
15 | | - from tapeagents.llms import LiteLLM |
16 | | - from tapeagents.tools.gym_browser import flatten_axtree |
17 | | -except ImportError as e: |
18 | | - print("Please run install_tapeagents.sh to install tapeagents first.") |
19 | | - raise e |
20 | | - |
21 | | -import sys |
22 | | - |
23 | | -sys.path.append(str(Path(__file__).parent.resolve() / "TapeAgents")) |
24 | | -############################## |
25 | | - |
26 | | -from examples.workarena.agent import WorkArenaAgent |
27 | | -from examples.workarena.steps import ( |
28 | | - WorkArenaAction, |
29 | | - ClickAction, |
30 | | - GoBackAction, |
31 | | - GoForwardAction, |
32 | | - GotoPageAction, |
33 | | - HoverAction, |
34 | | - InputTextAction, |
35 | | - PageObservation, |
36 | | - PressAction, |
37 | | - SelectOptionAction, |
38 | | - ScrollAction, |
39 | | - WorkArenaTape, |
40 | | - WorkArenaTask, |
41 | | - StopStep, |
42 | | -) |
43 | | - |
44 | 11 |
|
45 | 12 | logger = logging.getLogger(__name__) |
46 | 13 | logger.setLevel(logging.INFO) |
47 | 14 |
|
48 | 15 |
|
49 | 16 | @dataclass |
50 | 17 | class TapeAgentArgs(AgentArgs): |
51 | | - agent_name: str = "WorkarenaTapeAgent" |
52 | | - chat_model_args: BaseModelArgs = None |
| 18 | + config_name: str |
53 | 19 |
|
54 | 20 | def make_agent(self) -> bgym.Agent: |
55 | | - llm = LiteLLM( |
56 | | - model_name=self.chat_model_args.model_name, |
57 | | - use_cache=False, |
58 | | - context_size=self.chat_model_args.max_total_tokens, |
59 | | - parameters={"temperature": self.chat_model_args.temperature}, |
60 | | - ) |
61 | | - return WorkarenaTapeAgent(llm) |
62 | | - |
63 | | - def set_reproducibility_mode(self): |
64 | | - self.chat_model_args.temperature = 0 |
| 21 | + with hydra.initialize(config_path="./conf"): |
| 22 | + config = hydra.compose(config_name=self.config_name) |
| 23 | + agent: Agent = hydra.utils.instantiate(config) |
| 24 | + return TapeAgent(agent=agent, tape=Tape(steps=[])) |
65 | 25 |
|
66 | | - def prepare(self): |
67 | | - return self.chat_model_args.prepare_server() |
68 | 26 |
|
69 | | - def close(self): |
70 | | - return self.chat_model_args.close_server() |
| 27 | +class TapeAgent(bgym.Agent): |
| 28 | + agent: Agent |
| 29 | + tape: Tape |
71 | 30 |
|
72 | | - |
73 | | -class WorkarenaTapeAgent(bgym.Agent): |
74 | | - tape: WorkArenaTape |
75 | | - |
76 | | - def __init__(self, llm: LiteLLM): |
77 | | - self.tapeagent = WorkArenaAgent.create(llm) |
78 | | - self.tape = WorkArenaTape() |
79 | | - |
80 | | - def obs_preprocessor(self, obs: dict) -> dict: |
81 | | - axtree = obs.pop("axtree_object") |
82 | | - obs["axtree_txt"] = flatten_axtree(axtree) |
| 31 | + def obs_preprocessor(self, obs: dict) -> Any: |
| 32 | + logger.info(f"Preprocessing observation: {obs}") |
83 | 33 | return obs |
84 | 34 |
|
85 | | - @cost_tracker_decorator |
86 | | - def get_action(self, obs: Any) -> tuple[str, bgym.AgentInfo]: |
87 | | - self.update_tape(obs) |
88 | | - # run agent and collect thoughts and last action |
89 | | - tape_segment = [] |
90 | | - action = None |
91 | | - logger.info(f"Run tape with {len(self.tape)} steps") |
92 | | - for event in self.tapeagent.run(self.tape): |
| 35 | + def get_action(self, obs: Observation) -> tuple[str, bgym.AgentInfo]: |
| 36 | + self.tape = self.tape.append(obs) |
| 37 | + thoughts = [] |
| 38 | + for event in self.agent.run(self.tape): |
93 | 39 | if not event.step: |
94 | 40 | continue |
95 | | - step = event.step |
96 | | - tape_segment.append(step) |
97 | | - logger.info(f"Generated step: {step.llm_view()}") |
98 | | - if isinstance(step, WorkArenaAction): |
99 | | - action = self.step_to_action(step) |
100 | | - self.tape += tape_segment |
101 | | - |
102 | | - logger.info(f"Action string: {action}") |
103 | | - return ( |
104 | | - action, |
105 | | - bgym.AgentInfo( |
106 | | - extra_info={"tape_segment": [step.model_dump() for step in tape_segment]}, |
107 | | - stats={}, |
108 | | - ), |
109 | | - ) |
110 | | - |
111 | | - def update_tape(self, obs: dict): |
112 | | - """ |
113 | | - Update tape with new observation |
114 | | - """ |
115 | | - obs_step = PageObservation(text=obs["axtree_txt"], current_page=1, total_pages=1) |
116 | | - self.tape = self.tape.append(obs_step) |
117 | | - if len(self.tape) == 1: # first observation |
118 | | - logger.info("First observation, adding goal to tape") |
119 | | - self.tape = self.tape.append(WorkArenaTask(task=obs["goal"])) |
120 | | - |
121 | | - def step_to_action(self, action: WorkArenaAction) -> str | None: |
122 | | - """ |
123 | | - Convert action step to an action string with function call |
124 | | - """ |
125 | | - action_str = "" |
126 | | - if isinstance(action, GotoPageAction): |
127 | | - action_str = f"goto('{action.url}')" |
128 | | - elif isinstance(action, ClickAction): |
129 | | - action_str = ( |
130 | | - f"click('{action.bid}', button='{action.button}', modifiers={action.modifiers})" |
131 | | - ) |
132 | | - elif isinstance(action, SelectOptionAction): |
133 | | - action_str = f"select_option('{action.bid}', '{action.option}')" |
134 | | - elif isinstance(action, HoverAction): |
135 | | - action_str = f"hover('{action.bid}')" |
136 | | - elif isinstance(action, InputTextAction): |
137 | | - text = action.text.replace("'", "\\'") |
138 | | - action_str = f"fill('{action.bid}', '{text}')" |
139 | | - elif isinstance(action, PressAction): |
140 | | - f"press('{action.bid}', '{action.key_comb}')" |
141 | | - elif isinstance(action, GoBackAction): |
142 | | - action_str = "go_back()" |
143 | | - elif isinstance(action, GoForwardAction): |
144 | | - action_str = "go_forward()" |
145 | | - elif isinstance(action, StopStep): |
146 | | - logger.info("Stopping the loop") |
147 | | - action_str = None |
148 | | - elif isinstance(action, ScrollAction): |
149 | | - action_str = "noop()" # TODO: implement scroll action |
150 | | - else: |
151 | | - raise ValueError(f"Unknown action type: {action}") |
152 | | - return action_str |
| 41 | + self.tape = self.tape.append(event.step) |
| 42 | + if isinstance(event.step, Thought): |
| 43 | + thoughts.append(event.step.llm_view()) |
| 44 | + logger.info(f"Thought: {event.step.llm_view()}") |
| 45 | + elif isinstance(event.step, Action): |
| 46 | + action = event.step.llm_view() |
| 47 | + logger.info(f"Action: {action}") |
| 48 | + break # we stop at the first action |
| 49 | + return (action, bgym.AgentInfo(think="\n".join(thoughts), stats={})) |
0 commit comments