Skip to content

Commit 1909cb5

Browse files
authored
Merge pull request #69 from OpenManus/fix/gaia-output-format
revise gaia code to make it align with React output format + add test…
2 parents 25ff2d4 + 0229366 commit 1909cb5

File tree

2 files changed

+177
-51
lines changed

2 files changed

+177
-51
lines changed

openmanus_rl/agentgym/agentenv-gaia/agentenv_gaia/environment.py

Lines changed: 90 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,14 @@
1717

1818
# List of default tools to include in every environment
1919
DEFAULT_TOOLS = ["web_search", "bash", "python_execute", "browser_use", "terminate"]
20-
20+
# Define valid action types
21+
VALID_ACTIONS = {
22+
"web_search": {"param": "query"},
23+
"bash": {"param": "command"},
24+
"python_execute": {"param": "code"},
25+
"browser_use": {"param": "input"},
26+
"terminate": {"param": "answer"}
27+
}
2128

2229

2330
# GaiaEnvServer class to manage environment instances
@@ -239,32 +246,45 @@ def step(self, env_id: str, action: str):
239246
with self.env_locks[env_id]:
240247
env = self.env_instances[env_id]
241248

242-
# Parse and process the action
243-
action_type, action_input = self._parse_action(action)
249+
try:
250+
# Parse and process the action
251+
action_type, action_input = self._parse_action(action)
244252

245-
# Update state
246-
env["state"]["steps_taken"] += 1
253+
# Update state
254+
env["state"]["steps_taken"] += 1
247255

248-
# Process the action using available tools
249-
observation, reward, done = self._process_action(env_id, action_type, action_input)
256+
# Process the action using available tools
257+
observation, reward, done = self._process_action(env_id, action_type, action_input)
250258

251-
# Store action and result in memory
252-
env["state"]["memory"].append({
253-
"action": action,
254-
"result": observation,
255-
"step": env["state"]["steps_taken"]
256-
})
259+
# Store action and result in memory
260+
env["state"]["memory"].append({
261+
"action": action,
262+
"result": observation,
263+
"step": env["state"]["steps_taken"]
264+
})
257265

258-
# Update state with results
259-
env["state"]["reward"] += reward
260-
env["state"]["done"] = done
266+
# Update state with results
267+
env["state"]["reward"] += reward
268+
env["state"]["done"] = done
261269

262-
info = {
263-
"steps_taken": env["state"]["steps_taken"],
264-
"action_processed": action,
265-
}
270+
info = {
271+
"steps_taken": env["state"]["steps_taken"],
272+
"action_processed": action,
273+
}
266274

267-
return observation, reward, done, info
275+
return observation, reward, done, info
276+
277+
except ValueError as e:
278+
# Handle validation errors
279+
error_msg = str(e)
280+
observation = f"Error: {error_msg}"
281+
reward = -0.1 # Penalty for invalid action
282+
done = False
283+
info = {
284+
"steps_taken": env["state"]["steps_taken"],
285+
"error": error_msg
286+
}
287+
return observation, reward, done, info
268288

269289
def _parse_action(self, action: str) -> Tuple[str, Dict[str, Any]]:
270290
"""
@@ -275,53 +295,71 @@ def _parse_action(self, action: str) -> Tuple[str, Dict[str, Any]]:
275295
276296
Returns:
277297
tuple: (action_type, action_parameters)
298+
299+
Raises:
300+
ValueError: If action format is invalid
278301
"""
279302
try:
303+
# Split by newlines and find the action line
304+
lines = action.split('\n')
305+
action_line = None
306+
for line in lines:
307+
if line.strip().startswith('Action:'):
308+
action_line = line.strip()
309+
break
310+
311+
if not action_line:
312+
raise ValueError("No valid action found in the input")
313+
280314
# Parse actions in format "Action: action_type with Action Input: action_input"
281-
if "Action:" in action and "Action Input:" in action:
282-
action_parts = action.split("Action Input:", 1)
315+
if "Action:" in action_line and "Action Input:" in action_line:
316+
action_parts = action_line.split("Action Input:", 1)
283317
action_type = action_parts[0].replace("Action:", "").strip()
284318
action_input = action_parts[1].strip()
285-
return action_type, {"query": action_input} if action_type == "web_search" else {
286-
"command": action_input} if action_type == "bash" else {
287-
"code": action_input} if action_type == "python_execute" else {"input": action_input}
319+
320+
# Validate action type
321+
if action_type not in VALID_ACTIONS:
322+
raise ValueError(f"Invalid action type: {action_type}. Valid types are: {', '.join(VALID_ACTIONS.keys())}")
323+
324+
# Return with appropriate parameter name
325+
param_name = VALID_ACTIONS[action_type]["param"]
326+
return action_type, {param_name: action_input}
288327

289328
# Parse JSON format with tool_name and parameters
290-
elif action.strip().startswith("{") and action.strip().endswith("}"):
329+
elif action_line.strip().startswith("{") and action_line.strip().endswith("}"):
291330
try:
292-
action_json = json.loads(action)
293-
if "tool_name" in action_json:
294-
tool_name = action_json.pop("tool_name")
295-
return tool_name, action_json
331+
action_json = json.loads(action_line)
332+
if "tool_name" not in action_json:
333+
raise ValueError("JSON format must include 'tool_name' field")
334+
335+
tool_name = action_json.pop("tool_name")
336+
if tool_name not in VALID_ACTIONS:
337+
raise ValueError(f"Invalid tool name: {tool_name}. Valid tools are: {', '.join(VALID_ACTIONS.keys())}")
338+
339+
return tool_name, action_json
296340
except json.JSONDecodeError:
297-
pass
341+
raise ValueError("Invalid JSON format")
298342

299343
# Parse direct tool calls (e.g., "web_search: query")
300-
elif ":" in action:
301-
tool_name, input_text = action.split(":", 1)
344+
elif ":" in action_line:
345+
tool_name, input_text = action_line.split(":", 1)
302346
tool_name = tool_name.strip()
303347
input_text = input_text.strip()
304348

305-
# Map to appropriate parameter name based on tool
306-
if tool_name.lower() in ["web_search", "search_web"]:
307-
return tool_name, {"query": input_text}
308-
elif tool_name.lower() == "bash":
309-
return tool_name, {"command": input_text}
310-
elif tool_name.lower() == "python_execute":
311-
return tool_name, {"code": input_text}
312-
elif tool_name.lower() == "browser_use":
313-
return tool_name, {"url": input_text}
314-
elif tool_name.lower() == "terminate":
315-
return "terminate", {"answer": input_text}
316-
else:
317-
return tool_name, {"input": input_text}
349+
if tool_name not in VALID_ACTIONS:
350+
raise ValueError(f"Invalid tool name: {tool_name}. Valid tools are: {', '.join(VALID_ACTIONS.keys())}")
318351

319-
# Handle plain text as a default case
320-
return "unknown", {"input": action}
352+
param_name = VALID_ACTIONS[tool_name]["param"]
353+
return tool_name, {param_name: input_text}
354+
355+
else:
356+
raise ValueError("Invalid action format. Must be one of:\n" +
357+
"1. 'Action: tool_name with Action Input: input'\n" +
358+
"2. JSON format with tool_name and parameters\n" +
359+
"3. 'tool_name: input'")
321360

322361
except Exception as e:
323-
print(f"Error parsing action: {e}")
324-
return "unknown", {"input": action}
362+
raise ValueError(f"Error parsing action: {str(e)}")
325363

326364
def _process_action(self, env_id: str, action_type: str, action_input: Dict[str, Any]) -> Tuple[str, float, bool]:
327365
"""
@@ -495,6 +533,7 @@ def _format_new_task(self, env_id: str):
495533
observation += "\nUse tools in the format: Action: tool_name with Action Input: your_input\n"
496534
observation += "\nGive me one action."
497535

536+
498537
return observation
499538

500539
def _format_observation(self, env_id: str):
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
import requests
2+
import json
3+
4+
def test_gaia_react_format():
5+
# 服务器地址
6+
base_url = "http://127.0.0.1:8081"
7+
8+
# 1. 创建新环境
9+
create_response = requests.post(
10+
f"{base_url}/create",
11+
json={
12+
"data_dir": "data/",
13+
"level": "level1",
14+
"dataset": "validation"
15+
}
16+
)
17+
print(f"Create response: {create_response.text}")
18+
env_idx = create_response.text.strip('"') # 移除引号
19+
print(f"Created environment with ID: {env_idx}")
20+
21+
if not env_idx:
22+
print("Failed to create environment")
23+
return
24+
25+
# 2. 获取初始观察
26+
obs_response = requests.get(
27+
f"{base_url}/observation",
28+
params={"env_idx": env_idx}
29+
)
30+
print(f"Observation response: {obs_response.text}")
31+
try:
32+
initial_observation = obs_response.json().get("observation")
33+
except:
34+
initial_observation = obs_response.text
35+
print(f"\nInitial observation:\n{initial_observation}")
36+
37+
# 3. 测试ReAct格式的响应
38+
# 使用完整的ReAct格式,包含Thought和Action
39+
react_response = """Thought: I need to search for information about the capital of France.
40+
Action: web_search Action Input: What is the capital of France?"""
41+
42+
# 4. 执行动作
43+
step_response = requests.post(
44+
f"{base_url}/step",
45+
json={
46+
"env_idx": env_idx,
47+
"action": react_response
48+
}
49+
)
50+
print(f"Step response: {step_response.text}")
51+
try:
52+
result = step_response.json()
53+
except:
54+
result = {"response": step_response.text}
55+
print(f"\nStep result:\n{json.dumps(result, indent=2)}")
56+
57+
# 5. 验证输出格式
58+
print("\nValidating ReAct format:")
59+
print("1. Response contains 'Thought:' section:", "Thought:" in react_response)
60+
print("2. Response contains 'Action:' section:", "Action:" in react_response)
61+
print("3. Action format is correct:", "Action Input:" in react_response)
62+
63+
# 提取并验证action部分
64+
action_line = None
65+
for line in react_response.split('\n'):
66+
if line.strip().startswith('Action:'):
67+
action_line = line.strip()
68+
break
69+
70+
if action_line:
71+
print("4. Action line format is valid:", action_line.startswith('Action:'))
72+
print("5. Tool name is valid:", action_line.split("Action:")[1].split("Action Input:")[0].strip() in ["web_search", "bash", "python_execute", "browser_use", "terminate"])
73+
74+
# 6. 获取可用动作
75+
actions_response = requests.get(
76+
f"{base_url}/available_actions",
77+
params={"env_idx": env_idx}
78+
)
79+
print(f"Available actions response: {actions_response.text}")
80+
try:
81+
available_actions = actions_response.json()
82+
except:
83+
available_actions = {"response": actions_response.text}
84+
print(f"\nAvailable actions:\n{json.dumps(available_actions, indent=2)}")
85+
86+
if __name__ == "__main__":
87+
test_gaia_react_format()

0 commit comments

Comments
 (0)