 from agentlab.agents.agent_args import AgentArgs
 from agentlab.llm.llm_utils import image_to_png_base64_url
 from agentlab.llm.response_api import (
+    APIPayload,
     ClaudeResponseModelArgs,
     LLMOutput,
     MessageBuilder,
     OpenAIChatModelArgs,
     OpenAIResponseModelArgs,
+    OpenRouterModelArgs,
+    ToolCalls,
 )
 from agentlab.llm.tracking import cost_tracker_decorator
 
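
The new imports point at the main refactor in this commit: per-call options (messages, tools, tool_choice, caching flags) are now bundled into an APIPayload object instead of being passed to the model as loose keyword arguments. A minimal self-contained sketch of that pattern (toy code, not the agentlab implementation):

    from dataclasses import dataclass
    from typing import Any

    @dataclass
    class Payload:
        # Toy stand-in for APIPayload: one object carries every per-call option.
        messages: list[Any]
        tools: list[dict] | None = None
        tool_choice: str | None = None

    def call_model(payload: Payload) -> str:
        # A real wrapper would forward these fields to the provider API.
        return f"{len(payload.messages)} messages, tool_choice={payload.tool_choice}"

    print(call_model(Payload(messages=["hello"], tool_choice="any")))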
@@ -98,7 +101,8 @@ def flatten(self) -> list[MessageBuilder]:
             messages.extend(group.messages)
             # Mark all summarized messages for caching
             if i == len(self.groups) - keep_last_n_obs:
-                messages[i].mark_all_previous_msg_for_caching()
+                if not isinstance(messages[i], ToolCalls):
+                    messages[i].mark_all_previous_msg_for_caching()
         return messages
 
     def set_last_summary(self, summary: MessageBuilder):
@@ -163,18 +167,15 @@ class Obs(Block):
     use_dom: bool = False
     use_som: bool = False
     use_tabs: bool = False
-    add_mouse_pointer: bool = False
+    # add_mouse_pointer: bool = False
     use_zoomed_webpage: bool = False
 
     def apply(
         self, llm, discussion: StructuredDiscussion, obs: dict, last_llm_output: LLMOutput
     ) -> dict:
 
-        if last_llm_output.tool_calls is None:
-            obs_msg = llm.msg.user()  # type: MessageBuilder
-        else:
-            obs_msg = llm.msg.tool(last_llm_output.raw_response)  # type: MessageBuilder
-
+        obs_msg = llm.msg.user()
+        tool_calls = last_llm_output.tool_calls
         if self.use_last_error:
            if obs["last_action_error"] != "":
                 obs_msg.add_text(f"Last action error:\n{obs['last_action_error']}")
@@ -186,13 +187,12 @@ def apply(
         else:
             screenshot = obs["screenshot"]
 
-        if self.add_mouse_pointer:
-            # TODO this mouse pointer should be added at the browsergym level
-            screenshot = np.array(
-                agent_utils.add_mouse_pointer_from_action(
-                    Image.fromarray(obs["screenshot"]), obs["last_action"]
-                )
-            )
+        # if self.add_mouse_pointer:
+        #     screenshot = np.array(
+        #         agent_utils.add_mouse_pointer_from_action(
+        #             Image.fromarray(obs["screenshot"]), obs["last_action"]
+        #         )
+        #     )
 
         obs_msg.add_image(image_to_png_base64_url(screenshot))
         if self.use_axtree:
@@ -203,6 +203,13 @@ def apply(
             obs_msg.add_text(_format_tabs(obs))
 
         discussion.append(obs_msg)
+
+        if tool_calls:
+            for call in tool_calls:
+                call.response_text("See Observation")
+            tool_response = llm.msg.add_responded_tool_calls(tool_calls)
+            discussion.append(tool_response)
+
         return obs_msg
 
 
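
With this change the observation is always delivered as a plain user message, and any tool calls from the previous step are acknowledged separately with a placeholder answer. The messages appended per step now look roughly like the sketch below (an assumption drawn from the code above, not a verified trace):

    # 1. obs_msg: user message with the screenshot (+ optional AXTree, tabs, last-error text)
    # 2. tool_response: the previous step's tool calls, each answered with
    #    call.response_text("See Observation") and wrapped via
    #    llm.msg.add_responded_tool_calls(tool_calls)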
@@ -254,8 +261,8 @@ def apply(self, llm, discussion: StructuredDiscussion) -> dict:
         msg = llm.msg.user().add_text("""Summarize\n""")
 
         discussion.append(msg)
-        # TODO need to make sure we don't force tool use here
-        summary_response = llm(messages=discussion.flatten(), tool_choice="none")
+
+        summary_response = llm(APIPayload(messages=discussion.flatten()))
 
         summary_msg = llm.msg.assistant().add_text(summary_response.think)
         discussion.append(summary_msg)
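
Dropping tool_choice="none" appears safe here because tools are now supplied per call through APIPayload (see get_action below): a payload built without a tools field gives the summarizer call nothing to force. In sketch form, using only names from this diff:

    # summary_response = llm(APIPayload(messages=discussion.flatten()))
    # no tools / tool_choice in the payload -> a plain text completion for the summary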
@@ -320,25 +327,6 @@ def apply(self, llm, discussion: StructuredDiscussion, task_name: str) -> dict:
         discussion.append(msg)
 
 
-class ToolCall(Block):
-
-    def __init__(self, tool_server):
-        self.tool_server = tool_server
-
-    def apply(self, llm, messages: list[MessageBuilder], obs: dict) -> dict:
-        # build the message by adding components to obs
-        response: LLMOutput = llm(messages=self.messages)
-
-        messages.append(response.assistant_message)  # this is tool call
-
-        tool_answer = self.tool_server.call_tool(response)
-        tool_msg = llm.msg.tool()  # type: MessageBuilder
-        tool_msg.add_tool_id(response.last_computer_call_id)
-        tool_msg.update_last_raw_response(response)
-        tool_msg.add_text(str(tool_answer))
-        messages.append(tool_msg)
-
-
 @dataclass
 class PromptConfig:
     tag_screenshot: bool = True  # Whether to tag the screenshot with the last action.
@@ -394,7 +382,7 @@ def __init__(
 
         self.call_ids = []
 
-        self.llm = model_args.make_model(extra_kwargs={"tools": self.tools})
+        self.llm = model_args.make_model()
         self.msg_builder = model_args.get_message_builder()
         self.llm.msg = self.msg_builder
 
@@ -462,21 +450,23 @@ def get_action(self, obs: Any) -> float:
 
         messages = self.discussion.flatten()
         response: LLMOutput = self.llm(
-            messages=messages,
-            tool_choice="any",
-            cache_tool_definition=True,
-            cache_complete_prompt=False,
-            use_cache_breakpoints=True,
+            APIPayload(
+                messages=messages,
+                tools=self.tools,  # the available tools can now be updated per call
+                tool_choice="any",
+                cache_tool_definition=True,
+                cache_complete_prompt=False,
+                use_cache_breakpoints=True,
+            )
         )
-
         action = response.action
         think = response.think
         last_summary = self.discussion.get_last_summary()
         if last_summary is not None:
             think = last_summary.content[0]["text"] + "\n" + think
 
         self.discussion.new_group()
-        self.discussion.append(response.tool_calls)
+        # self.discussion.append(response.tool_calls)  # No need to append tool calls anymore.
 
         self.last_response = response
         self._responses.append(response)  # may be useful for debugging
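
Because tools now travel on the APIPayload rather than being frozen into make_model(), the toolset can in principle change from one call to the next. A hedged sketch (restricted_tools and some_condition are hypothetical names, not part of this commit):

    # tools = restricted_tools if some_condition else self.tools  # hypothetical
    # response = self.llm(
    #     APIPayload(messages=self.discussion.flatten(), tools=tools, tool_choice="any")
    # )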
@@ -486,8 +476,11 @@ def get_action(self, obs: Any) -> float:
         tools_msg = MessageBuilder("tool_description").add_text(tools_str)
 
         # Adding these extra messages to visualize in gradio
-        messages.insert(0, tools_msg)  # insert at the beginning of the messages
-        messages.append(response.tool_calls)
+        messages.insert(0, tools_msg)  # insert at the beginning of the messages
+        # This avoids the assertion error with self.llm.user().add_responded_tool_calls(tool_calls)
+        msg = self.llm.msg("tool")
+        msg.responded_tool_calls = response.tool_calls
+        messages.append(msg)
 
         agent_info = bgym.AgentInfo(
             think=think,
@@ -533,6 +526,31 @@ def get_action(self, obs: Any) -> float:
     vision_support=True,
 )
 
+O3_RESPONSE_MODEL = OpenAIResponseModelArgs(
+    model_name="o3-2025-04-16",
+    max_total_tokens=200_000,
+    max_input_tokens=200_000,
+    max_new_tokens=2_000,
+    temperature=None,  # O3 does not support temperature
+    vision_support=True,
+)
+O3_CHATAPI_MODEL = OpenAIChatModelArgs(
+    model_name="o3-2025-04-16",
+    max_total_tokens=200_000,
+    max_input_tokens=200_000,
+    max_new_tokens=2_000,
+    temperature=None,
+    vision_support=True,
+)
+
+GPT4_1_OPENROUTER_MODEL = OpenRouterModelArgs(
+    model_name="openai/gpt-4.1",
+    max_total_tokens=200_000,
+    max_input_tokens=200_000,
+    max_new_tokens=2_000,
+    temperature=None,
+    vision_support=True,
+)
 
 DEFAULT_PROMPT_CONFIG = PromptConfig(
     tag_screenshot=True,
@@ -548,8 +566,8 @@ def get_action(self, obs: Any) -> float:
     summarizer=Summarizer(do_summary=True),
     general_hints=GeneralHints(use_hints=False),
     task_hint=TaskHint(use_task_hint=True),
-    keep_last_n_obs=None,  # keep only the last observation in the discussion
-    multiaction=False,  # whether to use multi-action or not
+    keep_last_n_obs=None,
+    multiaction=True,  # whether to use multi-action or not
     # action_subsets=("bid",),
     action_subsets=("coord"),
     # action_subsets=("coord", "bid"),
@@ -559,3 +577,18 @@ def get_action(self, obs: Any) -> float:
     model_args=CLAUDE_MODEL_CONFIG,
     config=DEFAULT_PROMPT_CONFIG,
 )
+
+OAI_AGENT = ToolUseAgentArgs(
+    model_args=GPT_4_1,
+    config=DEFAULT_PROMPT_CONFIG,
+)
+
+OAI_CHATAPI_AGENT = ToolUseAgentArgs(
+    model_args=O3_CHATAPI_MODEL,
+    config=DEFAULT_PROMPT_CONFIG,
+)
+
+OAI_OPENROUTER_AGENT = ToolUseAgentArgs(
+    model_args=GPT4_1_OPENROUTER_MODEL,
+    config=DEFAULT_PROMPT_CONFIG,
+)
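
The three new agent configurations follow the same pattern as the existing Claude-based AGENT_CONFIG, so adding another backend is mostly a matter of declaring a ModelArgs instance and wrapping it in ToolUseAgentArgs. A hedged sketch of one more OpenRouter-backed entry (model name and token limits are illustrative, not part of this commit):

    # QWEN_OPENROUTER_MODEL = OpenRouterModelArgs(
    #     model_name="qwen/qwen2.5-vl-72b-instruct",  # hypothetical model choice
    #     max_total_tokens=128_000,
    #     max_input_tokens=128_000,
    #     max_new_tokens=2_000,
    #     temperature=None,
    #     vision_support=True,
    # )
    # QWEN_OPENROUTER_AGENT = ToolUseAgentArgs(
    #     model_args=QWEN_OPENROUTER_MODEL,
    #     config=DEFAULT_PROMPT_CONFIG,
    # )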