Skip to content

Commit 52e113b

Browse files
authored
Merge pull request #5 from Agent-One-Lab/multimodal
Multimodal Training. This PR officially adds multimodal training, including visual question-answering and GUI agent training. More docs will be added later.
2 parents d36e4fa + 1658b51 commit 52e113b

File tree

17 files changed

+1362
-10
lines changed

17 files changed

+1362
-10
lines changed

agents/agents/agents/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from .react.react_agent import ReactAgent
22
from .specialized.code_agent import CodeAgent
33
from .specialized.think_agent import ThinkAgent
4+
from .specialized.gui_agent import GUIAgent
45

5-
__all__ = ["ReactAgent", "CodeAgent", "ThinkAgent"]
6+
__all__ = ["ReactAgent", "CodeAgent", "ThinkAgent", "GUIAgent"]

agents/agents/agents/agent_base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ def _init_llm_engine(self, model_name_or_path: str, backend: str):
157157

158158
def set_llm_engine(self, llm_engine: Any, tokenizer: Any, processor: Any):
159159
assert self.backend == "async_verl", "Only async verl backend is supported for now"
160+
160161
self.llm_engine.llm_engine = llm_engine
161162
self.tokenizer = tokenizer
162163
self.processor = processor

agents/agents/agents/auto.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
from typing import Any, Callable, Dict, List, Optional, Type, Union
22

33
from .specialized.think_agent import ThinkAgent
4-
from agents.agents.specialized.openai_agent import OpenAIAgent
4+
from .specialized.openai_agent import OpenAIAgent
55
from ..tools import get_tools_from_names
66
from .agent_base import BaseAgent
77
from .react.react_agent import ReactAgent
88
from .specialized.code_agent import CodeAgent
9+
from .specialized.gui_agent import GUIAgent
910
from ..rewards.reward_base import get_reward_from_name
1011

1112

@@ -165,4 +166,5 @@ def from_pretrained(
165166
AutoAgent.register_agent("react", ReactAgent)
166167
AutoAgent.register_agent("code", CodeAgent)
167168
AutoAgent.register_agent("openai", OpenAIAgent)
168-
AutoAgent.register_agent("think", ThinkAgent)
169+
AutoAgent.register_agent("think", ThinkAgent)
170+
AutoAgent.register_agent("gui", GUIAgent)

agents/agents/agents/chain/chain_base.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,10 @@ async def _run_single_chain(self,
312312
)
313313
thought_node.is_terminal = new_msg.get("status", "continue") in self.terminal_status
314314
current_node = thought_node
315+
316+
# Check if the thought node is terminal - if so, break the loop
317+
if current_node.is_terminal:
318+
break
315319

316320
# Handle tool calls
317321
if current_node.messages[-1].get("tool_calls"):
Lines changed: 226 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,226 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: Apache-2.0

import json
from typing import List, Any, Tuple, Dict, Optional
from ..agent_base import BaseAgent
from agents.utils.ui_action_parser import parse_action_to_structure_output, IMAGE_FACTOR

# Default image dimensions, used when resolving model coordinates to pixels.
# NOTE(review): assumes 1920x1080 screenshots — confirm against the actual
# images fed to the model at inference/training time.
TEST_IMAGE_HEIGHT = 1080
TEST_IMAGE_WIDTH = 1920

# GUI Agent system prompt. Forces a strict two-line "Thought:/Action:" reply
# format with a three-action space (click / type / scroll); GUIAgent.parse
# depends on this exact format when extracting actions.
GUI_AGENT_SYSTEM_PROMPT = """You are a GUI automation agent. Given a task and screenshot, you must analyze the screen and perform the next action.

## Response Format (REQUIRED)
You MUST always respond with exactly two lines:
Thought: [Describe what you see and what action to take]
Action: [Choose ONE action from the list below]

## Action Space

click(start_box='<|box_start|>(x1,y1)<|box_end|>')
type(content='xxx') # Use escape characters \\', \\", and \\n in content part
scroll(direction='down or up or right or left')

## Examples
Example 1:
Thought: I need to click on the search button at coordinates (100, 200).
Action: click(start_box='<|box_start|>(100,200)<|box_end|>')

Example 2:
Thought: I need to type "hello world" in the text field.
Action: type(content='hello world')

## Note
- Use English in `Thought` and `Action` part
- Always provide both Thought and Action lines
- Coordinates should be in pixel values"""
class GUIAgent(BaseAgent):
    """GUI Agent for interacting with graphical user interfaces.

    Wraps a vision-language model and converts its free-form
    ``Thought:`` / ``Action:`` responses into structured
    ``pyautogui_code_generator`` tool calls. Every parsed turn is marked
    ``terminal`` because the action space has no explicit stop action.
    """

    def __init__(
        self,
        model_name_or_path: str,
        template: str,
        tools: Optional[List] = None,
        **kwargs
    ):
        """
        Initialize GUI Agent.

        Args:
            model_name_or_path: Path to the vision-language model
            template: Template name for formatting prompts
            tools: List of tools available to the agent
            **kwargs: Additional arguments forwarded to BaseAgent
                (``max_length`` defaults to 8192 when not supplied)
        """
        # BUG FIX: the original passed max_length=kwargs.get("max_length", 8192)
        # AND **kwargs to super().__init__, which raises
        # "TypeError: got multiple values for keyword argument 'max_length'"
        # whenever the caller supplies max_length explicitly. setdefault keeps
        # the same default while forwarding the keyword exactly once.
        kwargs.setdefault("max_length", 8192)
        super().__init__(
            model_name_or_path=model_name_or_path,
            template=template,
            system_prompt=GUI_AGENT_SYSTEM_PROMPT,
            tools=tools,
            **kwargs
        )
        self.action_counter = 0  # Track number of actions taken
        self.max_retries = 3  # Maximum retries for empty responses

    def _normalize_response(self, resp: Optional[str]) -> str:
        """Coerce one raw model response into the required Thought/Action form.

        Well-formed responses pass through unchanged. Non-empty responses
        missing the format are wrapped; empty/None responses fall back to a
        click at screen center (and bump ``action_counter``).
        """
        if resp and "Thought:" in resp and "Action:" in resp:
            return resp
        if resp and resp.strip():
            # Try to reformat responses that don't have the expected format
            print(f"[GUIAgent.parse] Response missing format, reformatting: {resp[:100]}")
            if any(action in resp.lower() for action in ('click', 'type', 'scroll')):
                return f"Thought: Executing action based on response.\nAction: {resp.strip()}"
            # Default to click at center if no clear action
            return f"Thought: {resp.strip()}\nAction: click(start_box='<|box_start|>(960,540)<|box_end|>')"
        # Handle empty responses with default click at center
        self.action_counter += 1
        return (
            f"Thought: Processing the screen (attempt {self.action_counter}).\n"
            "Action: click(start_box='<|box_start|>(960,540)<|box_end|>')"
        )

    def _tool_calls_for(self, idx: int, actions: Optional[List[Any]]) -> List[Dict[str, Any]]:
        """Build the single-element tool-call list for one response's actions.

        Uses the first parsed action, or a default center click when the
        parser produced none.
        """
        if actions:
            if len(actions) > 1:
                print(f"[GUIAgent.parse] Warning: Multiple actions found ({len(actions)}), using first one")
            action = actions[0]
        else:
            # If no action was parsed, create a default click action at center
            print(f"[GUIAgent.parse] No action parsed from response, creating default click action")
            action = {
                "action_type": "click",
                "action_inputs": {"start_box": "(960, 540)"},
                "thought": "Clicking at screen center",
                "reflection": None
            }
        return [{
            "id": str(idx),
            "type": "function",
            "function": {
                "name": "pyautogui_code_generator",
                "arguments": json.dumps({"action": action})
            }
        }]

    def parse(self, responses: List[str], tools: List[Any]) -> List[Dict[str, Any]]:
        """
        Parse model responses into structured assistant messages.

        Args:
            responses: List of model response strings
            tools: List of available tools (currently unused here; kept for
                interface parity with BaseAgent)

        Returns:
            List of assistant messages, each with exactly one tool call and
            status "terminal" (no explicit termination action exists in the
            action space, so every turn ends the episode).
        """
        print(f"[GUIAgent.parse] Number of responses: {len(responses)}")
        print(f"[GUIAgent.parse] Raw responses type: {type(responses)}")

        # Normalize every response into the strict Thought/Action format
        responses = [self._normalize_response(resp) for resp in responses]

        # Log first 3 responses for debugging (all non-empty after normalization)
        for idx, resp in enumerate(responses[:3]):
            print(f"[GUIAgent.parse] Response {idx} length: {len(resp)}, preview: {resp[:200]}")

        # Parse actions from responses
        action_list = [
            parse_action_to_structure_output(
                response,
                IMAGE_FACTOR,
                TEST_IMAGE_HEIGHT,
                TEST_IMAGE_WIDTH
            )
            for response in responses
        ]

        # Create messages with tool calls
        new_messages_list = []
        for i, (response, actions) in enumerate(zip(responses, action_list)):
            print(f"[GUIAgent.parse] Processing response {i+1}: response_length={len(response)}, actions={actions}")

            tool_calls = self._tool_calls_for(i, actions)

            # Always terminate after one turn since we only have 3 action types
            # and no explicit termination action
            status = "terminal"
            if actions and isinstance(actions[0], dict):
                action_type = actions[0].get("action_type", "")
                print(f"[GUIAgent.parse] Action type: {action_type}, terminating after one turn")

            message = {
                "role": "assistant",
                "content": [{"type": "text", "text": response}],
                "tool_calls": tool_calls,
                "loss": True,
                "status": status
            }
            print(f"[GUIAgent.parse] Created message with status={status}, tool_calls={len(tool_calls)}, content_length={len(response)}")
            new_messages_list.append(message)

        return new_messages_list

    def format_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Format messages for the vision-language model.

        Args:
            messages: List of messages to format

        Returns:
            Shallow-copied messages carrying only role/content, plus images
            when present on the source message.
        """
        formatted_messages = []

        for msg in messages:
            formatted_msg = {
                "role": msg.get("role"),
                "content": msg.get("content", "")
            }

            # Handle image content if present — passed through untouched
            if "images" in msg:
                formatted_msg["images"] = msg["images"]

            formatted_messages.append(formatted_msg)

        return formatted_messages

agents/agents/rewards/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from .webshop_reward import webshop_reward
55
from .alfworld_reward import alfworld_episode_reward
66
from .scienceworld_reward import scienceworld_reward
7+
from .gui_reward import gui_reward
78

89

9-
__all__ = ["alfworld_episode_reward","qa_f1_reward", "math_reward", "math_reward_tool", "math_reward_think", "RewardFunction", "get_reward_from_name", "get_rewards_from_names", "list_available_rewards", "register_reward", "llm_as_judge_client_math_reward", "webshop_reward", "alfworld_episode_reward"]
10+
__all__ = ["alfworld_episode_reward","qa_f1_reward", "math_reward", "math_reward_tool", "math_reward_think", "RewardFunction", "get_reward_from_name", "get_rewards_from_names", "list_available_rewards", "register_reward", "llm_as_judge_client_math_reward", "webshop_reward", "alfworld_episode_reward", "gui_reward"]

0 commit comments

Comments
 (0)