
Commit b0d4a99

Merge remote-tracking branch 'origin' into osworld
2 parents: 2afb28b + 7b8d24e

8 files changed: +1064 -569 lines

main_workarena_debug.py

Lines changed: 77 additions & 0 deletions
@@ -0,0 +1,77 @@
+"""
+Note: This script is a convenience script to launch experiments instead of using
+the command line.
+
+Copy this script and modify at will, but don't push your changes to the
+repository.
+"""
+
+import logging
+from copy import deepcopy
+
+import bgym
+
+from agentlab.agents.tool_use_agent.tool_use_agent import (
+    DEFAULT_PROMPT_CONFIG,
+    GPT_4_1,
+    ToolUseAgentArgs,
+)
+from agentlab.experiments.study import Study
+
+logging.getLogger().setLevel(logging.INFO)
+
+config = deepcopy(DEFAULT_PROMPT_CONFIG)
+# config.keep_last_n_obs = 1
+config.obs.use_som = True
+
+
+agent_configs = [
+    ToolUseAgentArgs(
+        model_args=GPT_4_1,
+        config=config,
+    ),
+    # ToolUseAgentArgs(
+    #     model_args=GPT_4_1,
+    #     config=config,
+    # ),
+]
+
+for agent_config in agent_configs:
+    agent_config.config.action_subsets = ("workarena",)  # use the workarena action set
+
+
+# ## select the benchmark to run on
+# benchmark = "miniwob_tiny_test"
+benchmark = "workarena_l1"
+
+
+benchmark = bgym.DEFAULT_BENCHMARKS[benchmark](n_repeats=4)  # type: bgym.Benchmark
+benchmark = benchmark.subset_from_glob("task_name", "*create*")
+
+# for env_args in benchmark.env_args_list:
+#     print(env_args.task_name)
+#     env_args.max_steps = 15
+
+relaunch = False
+
+## Number of parallel jobs
+n_jobs = 10  # Make sure to use 1 job when debugging in VSCode
+parallel_backend = "ray"
+# parallel_backend = "sequential"  # activate sequential backend for debugging in VSCode
+
+if __name__ == "__main__":  # necessary for dask backend
+
+    if relaunch:
+        # relaunch an existing study
+        study = Study.load_most_recent(contains=None)
+        study.find_incomplete(include_errors=True)
+
+    else:
+        study = Study(agent_configs, benchmark, logging_level_stdout=logging.WARNING)
+
+    study.run(
+        n_jobs=n_jobs,
+        parallel_backend=parallel_backend,  # "ray", "joblib" or "sequential"
+        strict_reproducibility=False,
+        n_relaunch=3,
+    )
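
The script reads top to bottom: build the agent configuration(s), restrict them to the workarena action set, narrow the benchmark to tasks matching "*create*", then launch a new Study or relaunch the most recent one. For stepping through in a debugger, the script's own comments suggest a two-line variant; a minimal sketch of that suggestion (not part of the commit):

n_jobs = 1  # a single job, so breakpoints are hit in the main process
parallel_backend = "sequential"  # run each task inline instead of through ray

With those settings, study.run() executes everything in-process, so ordinary VSCode breakpoints work.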

src/agentlab/agents/tool_use_agent/tool_use_agent.py

Lines changed: 93 additions & 49 deletions
@@ -26,11 +26,14 @@
 from agentlab.llm.base_api import BaseModelArgs
 from agentlab.llm.llm_utils import image_to_png_base64_url
 from agentlab.llm.response_api import (
+    APIPayload,
     ClaudeResponseModelArgs,
     LLMOutput,
     MessageBuilder,
     OpenAIChatModelArgs,
     OpenAIResponseModelArgs,
+    OpenRouterModelArgs,
+    ToolCalls,
 )
 from agentlab.llm.tracking import cost_tracker_decorator

@@ -101,7 +104,8 @@ def flatten(self) -> list[MessageBuilder]:
             messages.extend(group.messages)
             # Mark all summarized messages for caching
             if i == len(self.groups) - keep_last_n_obs:
-                messages[i].mark_all_previous_msg_for_caching()
+                if not isinstance(messages[i], ToolCalls):
+                    messages[i].mark_all_previous_msg_for_caching()
         return messages

     def set_last_summary(self, summary: MessageBuilder):
@@ -130,8 +134,10 @@ class Goal(Block):

     goal_as_system_msg: bool = True

-    def apply(self, llm, discussion: StructuredDiscussion, obs: dict) -> dict:
-        system_message = llm.msg.system().add_text(SYS_MSG)
+    def apply(
+        self, llm, discussion: StructuredDiscussion, obs: dict, sys_msg: str = SYS_MSG
+    ) -> dict:
+        system_message = llm.msg.system().add_text(sys_msg)
         discussion.append(system_message)

         if self.goal_as_system_msg:
@@ -164,18 +170,16 @@ class Obs(Block):
     use_dom: bool = False
     use_som: bool = False
     use_tabs: bool = False
-    add_mouse_pointer: bool = False
+    # add_mouse_pointer: bool = False
     use_zoomed_webpage: bool = False
     skip_preprocessing: bool = False

     def apply(
         self, llm, discussion: StructuredDiscussion, obs: dict, last_llm_output: LLMOutput
     ) -> dict:
-        if last_llm_output.tool_calls is None:
-            obs_msg = llm.msg.user()  # type: MessageBuilder
-        else:
-            obs_msg = llm.msg.tool(last_llm_output.raw_response)  # type: MessageBuilder

+        obs_msg = llm.msg.user()
+        tool_calls = last_llm_output.tool_calls
         if self.use_last_error:
             if obs["last_action_error"] != "":
                 obs_msg.add_text(f"Last action error:\n{obs['last_action_error']}")
@@ -186,13 +190,12 @@ def apply(
         else:
             screenshot = obs["screenshot"]

-        if self.add_mouse_pointer:
-            # TODO this mouse pointer should be added at the browsergym level
-            screenshot = np.array(
-                agent_utils.add_mouse_pointer_from_action(
-                    Image.fromarray(obs["screenshot"]), obs["last_action"]
-                )
-            )
+        # if self.add_mouse_pointer:
+        #     screenshot = np.array(
+        #         agent_utils.add_mouse_pointer_from_action(
+        #             Image.fromarray(obs["screenshot"]), obs["last_action"]
+        #         )
+        #     )

         obs_msg.add_image(image_to_png_base64_url(screenshot))
         if self.use_axtree:
@@ -203,6 +206,13 @@ def apply(
             obs_msg.add_text(_format_tabs(obs))

         discussion.append(obs_msg)
+
+        if tool_calls:
+            for call in tool_calls:
+                call.response_text("See Observation")
+            tool_response = llm.msg.add_responded_tool_calls(tool_calls)
+            discussion.append(tool_response)
+
         return obs_msg


@@ -253,8 +263,8 @@ def apply(self, llm, discussion: StructuredDiscussion) -> dict:
         msg = llm.msg.user().add_text("""Summarize\n""")

         discussion.append(msg)
-        # TODO need to make sure we don't force tool use here
-        summary_response = llm(messages=discussion.flatten(), tool_choice="none")
+
+        summary_response = llm(APIPayload(messages=discussion.flatten()))

         summary_msg = llm.msg.assistant().add_text(summary_response.think)
         discussion.append(summary_msg)
@@ -319,24 +329,6 @@ def apply(self, llm, discussion: StructuredDiscussion, task_name: str) -> dict:
         discussion.append(msg)


-class ToolCall(Block):
-    def __init__(self, tool_server):
-        self.tool_server = tool_server
-
-    def apply(self, llm, messages: list[MessageBuilder], obs: dict) -> dict:
-        # build the message by adding components to obs
-        response: LLMOutput = llm(messages=self.messages)
-
-        messages.append(response.assistant_message)  # this is tool call
-
-        tool_answer = self.tool_server.call_tool(response)
-        tool_msg = llm.msg.tool()  # type: MessageBuilder
-        tool_msg.add_tool_id(response.last_computer_call_id)
-        tool_msg.update_last_raw_response(response)
-        tool_msg.add_text(str(tool_answer))
-        messages.append(tool_msg)
-
-
 @dataclass
 class PromptConfig:
     tag_screenshot: bool = True  # Whether to tag the screenshot with the last action.
@@ -401,7 +393,7 @@ def __init__(

         self.call_ids = []

-        self.llm = model_args.make_model(extra_kwargs={"tools": self.tools})
+        self.llm = model_args.make_model()
         self.msg_builder = model_args.get_message_builder()
         self.llm.msg = self.msg_builder

@@ -451,7 +443,13 @@ def get_action(self, obs: Any) -> float:
         self.llm.reset_stats()
         if not self.discussion.is_goal_set():
             self.discussion.new_group("goal")
-            self.config.goal.apply(self.llm, self.discussion, obs)
+
+            if self.config.multiaction:
+                sys_msg = SYS_MSG + "\nYou can take multiple actions in a single step, if needed."
+            else:
+                sys_msg = SYS_MSG + "\nYou can only take one action at a time."
+            self.config.goal.apply(self.llm, self.discussion, obs, sys_msg)
+
             self.config.summarizer.apply_init(self.llm, self.discussion)
             self.config.general_hints.apply(self.llm, self.discussion)
             self.task_hint.apply(self.llm, self.discussion, self.task_name)
@@ -464,21 +462,23 @@ def get_action(self, obs: Any) -> float:

         messages = self.discussion.flatten()
         response: LLMOutput = self.llm(
-            messages=messages,
-            tool_choice="any",
-            cache_tool_definition=True,
-            cache_complete_prompt=False,
-            use_cache_breakpoints=True,
+            APIPayload(
+                messages=messages,
+                tools=self.tools,  # You can update the available tools now.
+                tool_choice="any",
+                cache_tool_definition=True,
+                cache_complete_prompt=False,
+                use_cache_breakpoints=True,
+            )
         )
-
         action = response.action
         think = response.think
         last_summary = self.discussion.get_last_summary()
         if last_summary is not None:
             think = last_summary.content[0]["text"] + "\n" + think

         self.discussion.new_group()
-        self.discussion.append(response.tool_calls)
+        # self.discussion.append(response.tool_calls)  # No need to append tool calls anymore.

         self.last_response = response
         self._responses.append(response)  # may be useful for debugging
@@ -488,8 +488,11 @@ def get_action(self, obs: Any) -> float:
         tools_msg = MessageBuilder("tool_description").add_text(tools_str)

         # Adding these extra messages to visualize in gradio
-        messages.insert(0, tools_msg)  # insert at the beginning of the messages
-        messages.append(response.tool_calls)
+        messages.insert(0, tools_msg)  # insert at the beginning of the message
+        # This avoids the assertion error with self.llm.user().add_responded_tool_calls(tool_calls)
+        msg = self.llm.msg("tool")
+        msg.responded_tool_calls = response.tool_calls
+        messages.append(msg)

         agent_info = bgym.AgentInfo(
             think=think,
@@ -499,7 +502,7 @@ def get_action(self, obs: Any) -> float:
         return action, agent_info


-OPENAI_MODEL_CONFIG = OpenAIResponseModelArgs(
+GPT_4_1 = OpenAIResponseModelArgs(
     model_name="gpt-4.1",
     max_total_tokens=200_000,
     max_input_tokens=200_000,
@@ -535,6 +538,32 @@ def get_action(self, obs: Any) -> float:
     vision_support=True,
 )

+O3_RESPONSE_MODEL = OpenAIResponseModelArgs(
+    model_name="o3-2025-04-16",
+    max_total_tokens=200_000,
+    max_input_tokens=200_000,
+    max_new_tokens=2_000,
+    temperature=None,  # O3 does not support temperature
+    vision_support=True,
+)
+O3_CHATAPI_MODEL = OpenAIChatModelArgs(
+    model_name="o3-2025-04-16",
+    max_total_tokens=200_000,
+    max_input_tokens=200_000,
+    max_new_tokens=2_000,
+    temperature=None,
+    vision_support=True,
+)
+
+GPT4_1_OPENROUTER_MODEL = OpenRouterModelArgs(
+    model_name="openai/gpt-4.1",
+    max_total_tokens=200_000,
+    max_input_tokens=200_000,
+    max_new_tokens=2_000,
+    temperature=None,
+    vision_support=True,
+)
+
 DEFAULT_PROMPT_CONFIG = PromptConfig(
     tag_screenshot=True,
     goal=Goal(goal_as_system_msg=True),
@@ -549,8 +578,8 @@ def get_action(self, obs: Any) -> float:
     summarizer=Summarizer(do_summary=True),
     general_hints=GeneralHints(use_hints=False),
     task_hint=TaskHint(use_task_hint=True),
-    keep_last_n_obs=None,  # keep only the last observation in the discussion
-    multiaction=False,  # whether to use multi-action or not
+    keep_last_n_obs=None,
+    multiaction=True,  # whether to use multi-action or not
     # action_subsets=("bid",),
     action_subsets=("coord"),
     # action_subsets=("coord", "bid"),
@@ -561,6 +590,21 @@ def get_action(self, obs: Any) -> float:
     config=DEFAULT_PROMPT_CONFIG,
 )

+OAI_AGENT = ToolUseAgentArgs(
+    model_args=GPT_4_1,
+    config=DEFAULT_PROMPT_CONFIG,
+)
+
+OAI_CHATAPI_AGENT = ToolUseAgentArgs(
+    model_args=O3_CHATAPI_MODEL,
+    config=DEFAULT_PROMPT_CONFIG,
+)
+
+OAI_OPENROUTER_AGENT = ToolUseAgentArgs(
+    model_args=GPT4_1_OPENROUTER_MODEL,
+    config=DEFAULT_PROMPT_CONFIG,
+)
+
 OSWORLD_CLAUDE = ToolUseAgentArgs(
     model_args=CLAUDE_MODEL_CONFIG,
     config=PromptConfig(
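
The central interface change in this file: per-call options that used to be keyword arguments to the model, and tools that used to be bound once via make_model(extra_kwargs={"tools": ...}), now travel together in a single APIPayload, so the tool list can change from one call to the next. A minimal sketch of the new calling convention, reusing GPT_4_1, APIPayload, make_model, and get_message_builder from the diff above; the prompt text and the empty tool list are placeholder assumptions:

from agentlab.agents.tool_use_agent.tool_use_agent import GPT_4_1
from agentlab.llm.response_api import APIPayload

llm = GPT_4_1.make_model()  # tools are no longer fixed at construction time
llm.msg = GPT_4_1.get_message_builder()

messages = [llm.msg.system().add_text("You are a web agent.")]  # placeholder prompt
tools = []  # hypothetical: the agent normally derives these from its action subsets

response = llm(
    APIPayload(
        messages=messages,
        tools=tools,  # may differ between calls
        tool_choice="any",
    )
)
print(response.action, response.think)

The Summarizer shows the payoff: instead of forcing tool_choice="none" as before, it now simply omits tools from the payload with llm(APIPayload(messages=discussion.flatten())).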

src/agentlab/analyze/agent_xray.py

Lines changed: 5 additions & 1 deletion
@@ -26,6 +26,7 @@
 from agentlab.llm.llm_utils import BaseMessage as AgentLabBaseMessage
 from agentlab.llm.llm_utils import Discussion
 from agentlab.llm.response_api import MessageBuilder
+from agentlab.llm.response_api import ToolCalls

 select_dir_instructions = "Select Experiment Directory"
 AGENT_NAME_KEY = "agent.agent_name"
@@ -673,6 +674,9 @@ def dict_to_markdown(d: dict):
         str: A markdown-formatted string representation of the dictionary.
     """
     if not isinstance(d, dict):
+        if isinstance(d, ToolCalls):
+            # ToolCalls rendered by to_markdown method.
+            return ""
         warning(f"Expected dict, got {type(d)}")
         return repr(d)
     if not d:
@@ -1164,7 +1168,7 @@ def get_directory_contents(results_dir: Path):
         most_recent_summary = max(summary_files, key=os.path.getctime)
         summary_df = pd.read_csv(most_recent_summary)

-        if len(summary_df) == 0 or summary_df["avg_reward"].isna().all():
+        if len(summary_df) == 0:
             continue  # skip if all avg_reward are NaN

         # get row with max avg_reward
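
The last hunk also widens what get_directory_contents keeps: a summary whose avg_reward column is entirely NaN used to be skipped, and now only truly empty CSVs are. A small pandas illustration of the difference (hypothetical data, not from the repository):

import pandas as pd

summary_df = pd.DataFrame({"avg_reward": [float("nan"), float("nan")]})

old_skip = len(summary_df) == 0 or summary_df["avg_reward"].isna().all()  # True: skipped
new_skip = len(summary_df) == 0  # False: kept
print(old_skip, new_skip)

So studies that ran without ever producing a reward now still show up in the experiment listing.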