Skip to content

Commit 7b8d24e

Browse files
Merge pull request #257 from ServiceNow/aj/multiaction
Multiaction Support and refactoring
2 parents d92e0bf + a6f5349 commit 7b8d24e

File tree

6 files changed

+929
-572
lines changed

6 files changed

+929
-572
lines changed

src/agentlab/agents/tool_use_agent/tool_use_agent.py

Lines changed: 80 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,14 @@
2222
from agentlab.agents.agent_args import AgentArgs
2323
from agentlab.llm.llm_utils import image_to_png_base64_url
2424
from agentlab.llm.response_api import (
25+
APIPayload,
2526
ClaudeResponseModelArgs,
2627
LLMOutput,
2728
MessageBuilder,
2829
OpenAIChatModelArgs,
2930
OpenAIResponseModelArgs,
31+
OpenRouterModelArgs,
32+
ToolCalls,
3033
)
3134
from agentlab.llm.tracking import cost_tracker_decorator
3235

@@ -98,7 +101,8 @@ def flatten(self) -> list[MessageBuilder]:
98101
messages.extend(group.messages)
99102
# Mark all summarized messages for caching
100103
if i == len(self.groups) - keep_last_n_obs:
101-
messages[i].mark_all_previous_msg_for_caching()
104+
if not isinstance(messages[i], ToolCalls):
105+
messages[i].mark_all_previous_msg_for_caching()
102106
return messages
103107

104108
def set_last_summary(self, summary: MessageBuilder):
@@ -163,18 +167,15 @@ class Obs(Block):
163167
use_dom: bool = False
164168
use_som: bool = False
165169
use_tabs: bool = False
166-
add_mouse_pointer: bool = False
170+
# add_mouse_pointer: bool = False
167171
use_zoomed_webpage: bool = False
168172

169173
def apply(
170174
self, llm, discussion: StructuredDiscussion, obs: dict, last_llm_output: LLMOutput
171175
) -> dict:
172176

173-
if last_llm_output.tool_calls is None:
174-
obs_msg = llm.msg.user() # type: MessageBuilder
175-
else:
176-
obs_msg = llm.msg.tool(last_llm_output.raw_response) # type: MessageBuilder
177-
177+
obs_msg = llm.msg.user()
178+
tool_calls = last_llm_output.tool_calls
178179
if self.use_last_error:
179180
if obs["last_action_error"] != "":
180181
obs_msg.add_text(f"Last action error:\n{obs['last_action_error']}")
@@ -186,13 +187,12 @@ def apply(
186187
else:
187188
screenshot = obs["screenshot"]
188189

189-
if self.add_mouse_pointer:
190-
# TODO this mouse pointer should be added at the browsergym level
191-
screenshot = np.array(
192-
agent_utils.add_mouse_pointer_from_action(
193-
Image.fromarray(obs["screenshot"]), obs["last_action"]
194-
)
195-
)
190+
# if self.add_mouse_pointer:
191+
# screenshot = np.array(
192+
# agent_utils.add_mouse_pointer_from_action(
193+
# Image.fromarray(obs["screenshot"]), obs["last_action"]
194+
# )
195+
# )
196196

197197
obs_msg.add_image(image_to_png_base64_url(screenshot))
198198
if self.use_axtree:
@@ -203,6 +203,13 @@ def apply(
203203
obs_msg.add_text(_format_tabs(obs))
204204

205205
discussion.append(obs_msg)
206+
207+
if tool_calls:
208+
for call in tool_calls:
209+
call.response_text("See Observation")
210+
tool_response = llm.msg.add_responded_tool_calls(tool_calls)
211+
discussion.append(tool_response)
212+
206213
return obs_msg
207214

208215

@@ -254,8 +261,8 @@ def apply(self, llm, discussion: StructuredDiscussion) -> dict:
254261
msg = llm.msg.user().add_text("""Summarize\n""")
255262

256263
discussion.append(msg)
257-
# TODO need to make sure we don't force tool use here
258-
summary_response = llm(messages=discussion.flatten(), tool_choice="none")
264+
265+
summary_response = llm(APIPayload(messages=discussion.flatten()))
259266

260267
summary_msg = llm.msg.assistant().add_text(summary_response.think)
261268
discussion.append(summary_msg)
@@ -320,25 +327,6 @@ def apply(self, llm, discussion: StructuredDiscussion, task_name: str) -> dict:
320327
discussion.append(msg)
321328

322329

323-
class ToolCall(Block):
324-
325-
def __init__(self, tool_server):
326-
self.tool_server = tool_server
327-
328-
def apply(self, llm, messages: list[MessageBuilder], obs: dict) -> dict:
329-
# build the message by adding components to obs
330-
response: LLMOutput = llm(messages=self.messages)
331-
332-
messages.append(response.assistant_message) # this is tool call
333-
334-
tool_answer = self.tool_server.call_tool(response)
335-
tool_msg = llm.msg.tool() # type: MessageBuilder
336-
tool_msg.add_tool_id(response.last_computer_call_id)
337-
tool_msg.update_last_raw_response(response)
338-
tool_msg.add_text(str(tool_answer))
339-
messages.append(tool_msg)
340-
341-
342330
@dataclass
343331
class PromptConfig:
344332
tag_screenshot: bool = True # Whether to tag the screenshot with the last action.
@@ -394,7 +382,7 @@ def __init__(
394382

395383
self.call_ids = []
396384

397-
self.llm = model_args.make_model(extra_kwargs={"tools": self.tools})
385+
self.llm = model_args.make_model()
398386
self.msg_builder = model_args.get_message_builder()
399387
self.llm.msg = self.msg_builder
400388

@@ -462,21 +450,23 @@ def get_action(self, obs: Any) -> float:
462450

463451
messages = self.discussion.flatten()
464452
response: LLMOutput = self.llm(
465-
messages=messages,
466-
tool_choice="any",
467-
cache_tool_definition=True,
468-
cache_complete_prompt=False,
469-
use_cache_breakpoints=True,
453+
APIPayload(
454+
messages=messages,
455+
tools=self.tools,  # You can update the available tools now.
456+
tool_choice="any",
457+
cache_tool_definition=True,
458+
cache_complete_prompt=False,
459+
use_cache_breakpoints=True,
460+
)
470461
)
471-
472462
action = response.action
473463
think = response.think
474464
last_summary = self.discussion.get_last_summary()
475465
if last_summary is not None:
476466
think = last_summary.content[0]["text"] + "\n" + think
477467

478468
self.discussion.new_group()
479-
self.discussion.append(response.tool_calls)
469+
# self.discussion.append(response.tool_calls) # No need to append tool calls anymore.
480470

481471
self.last_response = response
482472
self._responses.append(response) # may be useful for debugging
@@ -486,8 +476,11 @@ def get_action(self, obs: Any) -> float:
486476
tools_msg = MessageBuilder("tool_description").add_text(tools_str)
487477

488478
# Adding these extra messages to visualize in gradio
489-
messages.insert(0, tools_msg) # insert at the beginning of the messages
490-
messages.append(response.tool_calls)
479+
messages.insert(0, tools_msg)  # insert at the beginning of the messages list
480+
# This avoids the assertion error with self.llm.user().add_responded_tool_calls(tool_calls)
481+
msg = self.llm.msg("tool")
482+
msg.responded_tool_calls = response.tool_calls
483+
messages.append(msg)
491484

492485
agent_info = bgym.AgentInfo(
493486
think=think,
@@ -533,6 +526,31 @@ def get_action(self, obs: Any) -> float:
533526
vision_support=True,
534527
)
535528

529+
O3_RESPONSE_MODEL = OpenAIResponseModelArgs(
530+
model_name="o3-2025-04-16",
531+
max_total_tokens=200_000,
532+
max_input_tokens=200_000,
533+
max_new_tokens=2_000,
534+
temperature=None, # O3 does not support temperature
535+
vision_support=True,
536+
)
537+
O3_CHATAPI_MODEL = OpenAIChatModelArgs(
538+
model_name="o3-2025-04-16",
539+
max_total_tokens=200_000,
540+
max_input_tokens=200_000,
541+
max_new_tokens=2_000,
542+
temperature=None,
543+
vision_support=True,
544+
)
545+
546+
GPT4_1_OPENROUTER_MODEL = OpenRouterModelArgs(
547+
model_name="openai/gpt-4.1",
548+
max_total_tokens=200_000,
549+
max_input_tokens=200_000,
550+
max_new_tokens=2_000,
551+
temperature=None,  # temperature left unset for this model
552+
vision_support=True,
553+
)
536554

537555
DEFAULT_PROMPT_CONFIG = PromptConfig(
538556
tag_screenshot=True,
@@ -548,8 +566,8 @@ def get_action(self, obs: Any) -> float:
548566
summarizer=Summarizer(do_summary=True),
549567
general_hints=GeneralHints(use_hints=False),
550568
task_hint=TaskHint(use_task_hint=True),
551-
keep_last_n_obs=None, # keep only the last observation in the discussion
552-
multiaction=False, # whether to use multi-action or not
569+
keep_last_n_obs=None,
570+
multiaction=True, # whether to use multi-action or not
553571
# action_subsets=("bid",),
554572
action_subsets=("coord"),
555573
# action_subsets=("coord", "bid"),
@@ -559,3 +577,18 @@ def get_action(self, obs: Any) -> float:
559577
model_args=CLAUDE_MODEL_CONFIG,
560578
config=DEFAULT_PROMPT_CONFIG,
561579
)
580+
581+
OAI_AGENT = ToolUseAgentArgs(
582+
model_args=GPT_4_1,
583+
config=DEFAULT_PROMPT_CONFIG,
584+
)
585+
586+
OAI_CHATAPI_AGENT = ToolUseAgentArgs(
587+
model_args=O3_CHATAPI_MODEL,
588+
config=DEFAULT_PROMPT_CONFIG,
589+
)
590+
591+
OAI_OPENROUTER_AGENT = ToolUseAgentArgs(
592+
model_args=GPT4_1_OPENROUTER_MODEL,
593+
config=DEFAULT_PROMPT_CONFIG,
594+
)

src/agentlab/analyze/agent_xray.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from agentlab.llm.llm_utils import BaseMessage as AgentLabBaseMessage
2727
from agentlab.llm.llm_utils import Discussion
2828
from agentlab.llm.response_api import MessageBuilder
29+
from agentlab.llm.response_api import ToolCalls
2930

3031
select_dir_instructions = "Select Experiment Directory"
3132
AGENT_NAME_KEY = "agent.agent_name"
@@ -673,6 +674,9 @@ def dict_to_markdown(d: dict):
673674
str: A markdown-formatted string representation of the dictionary.
674675
"""
675676
if not isinstance(d, dict):
677+
if isinstance(d, ToolCalls):
678+
# ToolCalls are rendered by their to_markdown method.
679+
return ""
676680
warning(f"Expected dict, got {type(d)}")
677681
return repr(d)
678682
if not d:

0 commit comments

Comments
 (0)