2626from agentlab .llm .base_api import BaseModelArgs
2727from agentlab .llm .llm_utils import image_to_png_base64_url
2828from agentlab .llm .response_api import (
29+ APIPayload ,
2930 ClaudeResponseModelArgs ,
3031 LLMOutput ,
3132 MessageBuilder ,
3233 OpenAIChatModelArgs ,
3334 OpenAIResponseModelArgs ,
35+ OpenRouterModelArgs ,
36+ ToolCalls ,
3437)
3538from agentlab .llm .tracking import cost_tracker_decorator
3639
@@ -101,7 +104,8 @@ def flatten(self) -> list[MessageBuilder]:
101104 messages .extend (group .messages )
102105 # Mark all summarized messages for caching
103106 if i == len (self .groups ) - keep_last_n_obs :
104- messages [i ].mark_all_previous_msg_for_caching ()
107+ if not isinstance (messages [i ], ToolCalls ):
108+ messages [i ].mark_all_previous_msg_for_caching ()
105109 return messages
106110
107111 def set_last_summary (self , summary : MessageBuilder ):
@@ -130,8 +134,10 @@ class Goal(Block):
130134
131135 goal_as_system_msg : bool = True
132136
133- def apply (self , llm , discussion : StructuredDiscussion , obs : dict ) -> dict :
134- system_message = llm .msg .system ().add_text (SYS_MSG )
137+ def apply (
138+ self , llm , discussion : StructuredDiscussion , obs : dict , sys_msg : str = SYS_MSG
139+ ) -> dict :
140+ system_message = llm .msg .system ().add_text (sys_msg )
135141 discussion .append (system_message )
136142
137143 if self .goal_as_system_msg :
@@ -164,18 +170,16 @@ class Obs(Block):
164170 use_dom : bool = False
165171 use_som : bool = False
166172 use_tabs : bool = False
167- add_mouse_pointer : bool = False
173+ # add_mouse_pointer: bool = False
168174 use_zoomed_webpage : bool = False
169175 skip_preprocessing : bool = False
170176
171177 def apply (
172178 self , llm , discussion : StructuredDiscussion , obs : dict , last_llm_output : LLMOutput
173179 ) -> dict :
174- if last_llm_output .tool_calls is None :
175- obs_msg = llm .msg .user () # type: MessageBuilder
176- else :
177- obs_msg = llm .msg .tool (last_llm_output .raw_response ) # type: MessageBuilder
178180
181+ obs_msg = llm .msg .user ()
182+ tool_calls = last_llm_output .tool_calls
179183 if self .use_last_error :
180184 if obs ["last_action_error" ] != "" :
181185 obs_msg .add_text (f"Last action error:\n { obs ['last_action_error' ]} " )
@@ -186,13 +190,12 @@ def apply(
186190 else :
187191 screenshot = obs ["screenshot" ]
188192
189- if self .add_mouse_pointer :
190- # TODO this mouse pointer should be added at the browsergym level
191- screenshot = np .array (
192- agent_utils .add_mouse_pointer_from_action (
193- Image .fromarray (obs ["screenshot" ]), obs ["last_action" ]
194- )
195- )
193+ # if self.add_mouse_pointer:
194+ # screenshot = np.array(
195+ # agent_utils.add_mouse_pointer_from_action(
196+ # Image.fromarray(obs["screenshot"]), obs["last_action"]
197+ # )
198+ # )
196199
197200 obs_msg .add_image (image_to_png_base64_url (screenshot ))
198201 if self .use_axtree :
@@ -203,6 +206,13 @@ def apply(
203206 obs_msg .add_text (_format_tabs (obs ))
204207
205208 discussion .append (obs_msg )
209+
210+ if tool_calls :
211+ for call in tool_calls :
212+ call .response_text ("See Observation" )
213+ tool_response = llm .msg .add_responded_tool_calls (tool_calls )
214+ discussion .append (tool_response )
215+
206216 return obs_msg
207217
208218
@@ -253,8 +263,8 @@ def apply(self, llm, discussion: StructuredDiscussion) -> dict:
253263 msg = llm .msg .user ().add_text ("""Summarize\n """ )
254264
255265 discussion .append (msg )
256- # TODO need to make sure we don't force tool use here
257- summary_response = llm (messages = discussion .flatten (), tool_choice = "none" )
266+
267+ summary_response = llm (APIPayload ( messages = discussion .flatten ()) )
258268
259269 summary_msg = llm .msg .assistant ().add_text (summary_response .think )
260270 discussion .append (summary_msg )
@@ -319,24 +329,6 @@ def apply(self, llm, discussion: StructuredDiscussion, task_name: str) -> dict:
319329 discussion .append (msg )
320330
321331
322- class ToolCall (Block ):
323- def __init__ (self , tool_server ):
324- self .tool_server = tool_server
325-
326- def apply (self , llm , messages : list [MessageBuilder ], obs : dict ) -> dict :
327- # build the message by adding components to obs
328- response : LLMOutput = llm (messages = self .messages )
329-
330- messages .append (response .assistant_message ) # this is tool call
331-
332- tool_answer = self .tool_server .call_tool (response )
333- tool_msg = llm .msg .tool () # type: MessageBuilder
334- tool_msg .add_tool_id (response .last_computer_call_id )
335- tool_msg .update_last_raw_response (response )
336- tool_msg .add_text (str (tool_answer ))
337- messages .append (tool_msg )
338-
339-
340332@dataclass
341333class PromptConfig :
342334 tag_screenshot : bool = True # Whether to tag the screenshot with the last action.
@@ -401,7 +393,7 @@ def __init__(
401393
402394 self .call_ids = []
403395
404- self .llm = model_args .make_model (extra_kwargs = { "tools" : self . tools } )
396+ self .llm = model_args .make_model ()
405397 self .msg_builder = model_args .get_message_builder ()
406398 self .llm .msg = self .msg_builder
407399
@@ -451,7 +443,13 @@ def get_action(self, obs: Any) -> float:
451443 self .llm .reset_stats ()
452444 if not self .discussion .is_goal_set ():
453445 self .discussion .new_group ("goal" )
454- self .config .goal .apply (self .llm , self .discussion , obs )
446+
447+ if self .config .multiaction :
448+ sys_msg = SYS_MSG + "\n You can take multiple actions in a single step, if needed."
449+ else :
450+ sys_msg = SYS_MSG + "\n You can only take one action at a time."
451+ self .config .goal .apply (self .llm , self .discussion , obs , sys_msg )
452+
455453 self .config .summarizer .apply_init (self .llm , self .discussion )
456454 self .config .general_hints .apply (self .llm , self .discussion )
457455 self .task_hint .apply (self .llm , self .discussion , self .task_name )
@@ -464,21 +462,23 @@ def get_action(self, obs: Any) -> float:
464462
465463 messages = self .discussion .flatten ()
466464 response : LLMOutput = self .llm (
467- messages = messages ,
468- tool_choice = "any" ,
469- cache_tool_definition = True ,
470- cache_complete_prompt = False ,
471- use_cache_breakpoints = True ,
465+ APIPayload (
466+ messages = messages ,
467+ tools = self .tools , # You can update tools available tools now.
468+ tool_choice = "any" ,
469+ cache_tool_definition = True ,
470+ cache_complete_prompt = False ,
471+ use_cache_breakpoints = True ,
472+ )
472473 )
473-
474474 action = response .action
475475 think = response .think
476476 last_summary = self .discussion .get_last_summary ()
477477 if last_summary is not None :
478478 think = last_summary .content [0 ]["text" ] + "\n " + think
479479
480480 self .discussion .new_group ()
481- self .discussion .append (response .tool_calls )
481+ # self.discussion.append(response.tool_calls) # No need to append tool calls anymore.
482482
483483 self .last_response = response
484484 self ._responses .append (response ) # may be useful for debugging
@@ -488,8 +488,11 @@ def get_action(self, obs: Any) -> float:
488488 tools_msg = MessageBuilder ("tool_description" ).add_text (tools_str )
489489
490490 # Adding these extra messages to visualize in gradio
491- messages .insert (0 , tools_msg ) # insert at the beginning of the messages
492- messages .append (response .tool_calls )
491+ messages .insert (0 , tools_msg ) # insert at the beginning of the message
492+ # This avoids the assertion error with self.llm.user().add_responded_tool_calls(tool_calls)
493+ msg = self .llm .msg ("tool" )
494+ msg .responded_tool_calls = response .tool_calls
495+ messages .append (msg )
493496
494497 agent_info = bgym .AgentInfo (
495498 think = think ,
@@ -499,7 +502,7 @@ def get_action(self, obs: Any) -> float:
499502 return action , agent_info
500503
501504
502- OPENAI_MODEL_CONFIG = OpenAIResponseModelArgs (
505+ GPT_4_1 = OpenAIResponseModelArgs (
503506 model_name = "gpt-4.1" ,
504507 max_total_tokens = 200_000 ,
505508 max_input_tokens = 200_000 ,
@@ -535,6 +538,32 @@ def get_action(self, obs: Any) -> float:
535538 vision_support = True ,
536539)
537540
541+ O3_RESPONSE_MODEL = OpenAIResponseModelArgs (
542+ model_name = "o3-2025-04-16" ,
543+ max_total_tokens = 200_000 ,
544+ max_input_tokens = 200_000 ,
545+ max_new_tokens = 2_000 ,
546+ temperature = None , # O3 does not support temperature
547+ vision_support = True ,
548+ )
549+ O3_CHATAPI_MODEL = OpenAIChatModelArgs (
550+ model_name = "o3-2025-04-16" ,
551+ max_total_tokens = 200_000 ,
552+ max_input_tokens = 200_000 ,
553+ max_new_tokens = 2_000 ,
554+ temperature = None ,
555+ vision_support = True ,
556+ )
557+
558+ GPT4_1_OPENROUTER_MODEL = OpenRouterModelArgs (
559+ model_name = "openai/gpt-4.1" ,
560+ max_total_tokens = 200_000 ,
561+ max_input_tokens = 200_000 ,
562+ max_new_tokens = 2_000 ,
563+ temperature = None , # O3 does not support temperature
564+ vision_support = True ,
565+ )
566+
538567DEFAULT_PROMPT_CONFIG = PromptConfig (
539568 tag_screenshot = True ,
540569 goal = Goal (goal_as_system_msg = True ),
@@ -549,8 +578,8 @@ def get_action(self, obs: Any) -> float:
549578 summarizer = Summarizer (do_summary = True ),
550579 general_hints = GeneralHints (use_hints = False ),
551580 task_hint = TaskHint (use_task_hint = True ),
552- keep_last_n_obs = None , # keep only the last observation in the discussion
553- multiaction = False , # whether to use multi-action or not
581+ keep_last_n_obs = None ,
582+ multiaction = True , # whether to use multi-action or not
554583 # action_subsets=("bid",),
555584 action_subsets = ("coord" ),
556585 # action_subsets=("coord", "bid"),
@@ -561,6 +590,21 @@ def get_action(self, obs: Any) -> float:
561590 config = DEFAULT_PROMPT_CONFIG ,
562591)
563592
593+ OAI_AGENT = ToolUseAgentArgs (
594+ model_args = GPT_4_1 ,
595+ config = DEFAULT_PROMPT_CONFIG ,
596+ )
597+
598+ OAI_CHATAPI_AGENT = ToolUseAgentArgs (
599+ model_args = O3_CHATAPI_MODEL ,
600+ config = DEFAULT_PROMPT_CONFIG ,
601+ )
602+
603+ OAI_OPENROUTER_AGENT = ToolUseAgentArgs (
604+ model_args = GPT4_1_OPENROUTER_MODEL ,
605+ config = DEFAULT_PROMPT_CONFIG ,
606+ )
607+
564608OSWORLD_CLAUDE = ToolUseAgentArgs (
565609 model_args = CLAUDE_MODEL_CONFIG ,
566610 config = PromptConfig (
0 commit comments