Skip to content

Commit 0b3c6eb

Browse files
committed
fix: update gaia env
1 parent 75cc995 commit 0b3c6eb

File tree

4 files changed: 241 additions (+241) and 145 deletions (-145)

4 files changed: 241 additions (+241) and 145 deletions (-145)

openmanus_rl/environments/env_package/tool_use/envs.py

Lines changed: 17 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -16,10 +16,10 @@ class ToolUseEnv:
1616
Provides tasks from dataset and handles tool execution.
1717
"""
1818

19-
def __init__(self, tasks_data: List[Dict], available_tools: List[str], seed: int = 42):
19+
def __init__(self, tasks_data: List[Dict], available_tools: List[str], seed: int = 42, model_string: str = None):
2020
self.tasks_data = tasks_data
2121
self.available_tools = available_tools
22-
self.tool_manager = ToolManager(available_tools)
22+
self.tool_manager = ToolManager(available_tools, model_string=model_string)
2323
self.current_task_idx = 0
2424
self.seed = seed
2525
random.seed(seed)
@@ -65,7 +65,7 @@ class ToolUseEnvs:
6565
"""
6666

6767
def __init__(self, tasks_data: List[Dict], available_tools: List[str],
68-
seed: int = 0, env_num: int = 1, group_n: int = 1, is_train: bool = True):
68+
seed: int = 0, env_num: int = 1, group_n: int = 1, is_train: bool = True, model_string: str = None):
6969
self.tasks_data = tasks_data
7070
self.available_tools = available_tools
7171
self.num_processes = env_num * group_n
@@ -75,7 +75,7 @@ def __init__(self, tasks_data: List[Dict], available_tools: List[str],
7575
# Create individual environments
7676
self.envs = []
7777
for i in range(self.num_processes):
78-
env = ToolUseEnv(tasks_data, available_tools, seed + i)
78+
env = ToolUseEnv(tasks_data, available_tools, seed + i, model_string=model_string)
7979
self.envs.append(env)
8080

8181
# Track current task indices for each environment
@@ -121,8 +121,9 @@ def close(self):
121121
class ToolManager:
122122
"""Manages available tools and their execution"""
123123

124-
def __init__(self, tool_names: List[str]):
124+
def __init__(self, tool_names: List[str], model_string: str = None):
125125
self.tool_names = tool_names
126+
self.model_string = model_string
126127
self.available_tools = {}
127128
self._load_tools()
128129

@@ -140,7 +141,7 @@ def _load_tool(self, tool_name: str):
140141
tool_mapping = {
141142
'google_search': 'openmanus_rl.tools.google_search.tool.Google_Search_Tool',
142143
'wikipedia_knowledge_searcher': 'openmanus_rl.tools.wikipedia_knowledge_searcher.tool.Wikipedia_Knowledge_Searcher_Tool',
143-
'arxiv_paper_searcher': 'openmanus_rl.tools.arxiv_paper_searcher.tool.Arxiv_Paper_Searcher_Tool',
144+
'arxiv_paper_searcher': 'openmanus_rl.tools.arxiv_paper_searcher.tool.ArXiv_Paper_Searcher_Tool',
144145
'pubmed_search': 'openmanus_rl.tools.pubmed_search.tool.Pubmed_Search_Tool',
145146
'url_text_extractor': 'openmanus_rl.tools.url_text_extractor.tool.URL_Text_Extractor_Tool',
146147
'python_code_generator': 'openmanus_rl.tools.python_code_generator.tool.Python_Code_Generator_Tool',
@@ -150,13 +151,20 @@ def _load_tool(self, tool_name: str):
150151
print(f"Unknown tool: {tool_name}, skipping...")
151152
return
152153

154+
print(f"Loading tool: {tool_name}")
155+
153156
module_path = tool_mapping[tool_name]
154157
module_name, class_name = module_path.rsplit('.', 1)
155158

156159
# Import and instantiate the tool
157160
module = importlib.import_module(module_name)
158161
tool_class = getattr(module, class_name)
159-
tool_instance = tool_class()
162+
163+
# Check if tool requires LLM engine and pass model_string if available
164+
if hasattr(tool_class, 'require_llm_engine') and tool_class.require_llm_engine and self.model_string:
165+
tool_instance = tool_class(model_string=self.model_string)
166+
else:
167+
tool_instance = tool_class()
160168

161169
self.available_tools[tool_name] = tool_instance
162170

@@ -199,6 +207,6 @@ def execute_tool(self, tool_name: str, params: Dict) -> str:
199207

200208

201209
def build_tool_use_envs(tasks_data: List[Dict], available_tools: List[str],
202-
seed: int, env_num: int, group_n: int, is_train: bool = True):
210+
seed: int, env_num: int, group_n: int, is_train: bool = True, model_string: str = None):
203211
"""Build tool use environments"""
204-
return ToolUseEnvs(tasks_data, available_tools, seed, env_num, group_n, is_train)
212+
return ToolUseEnvs(tasks_data, available_tools, seed, env_num, group_n, is_train, model_string=model_string)

openmanus_rl/environments/env_package/tool_use/manager.py

Lines changed: 71 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,7 @@ def __init__(self, envs, projection_f, config):
2323
self.ground_truths = []
2424
self.step_counts = []
2525
self.task_completed = []
26+
self.task_success = []
2627

2728
def reset(self):
2829
"""Reset environment and get new tasks"""
@@ -36,6 +37,7 @@ def reset(self):
3637
batch_size = len(self.current_tasks)
3738
self.step_counts = [0] * batch_size
3839
self.task_completed = [False] * batch_size
40+
self.task_success = [False] * batch_size
3941

4042
# Initialize memory
4143
self.memory.reset(batch_size=batch_size)
@@ -59,7 +61,7 @@ def step(self, text_actions: List[str]):
5961
for i, (action, valid) in enumerate(zip(actions, valids)):
6062
if self.task_completed[i]:
6163
observations.append("Task completed.")
62-
infos.append({'is_action_valid': True, 'won': True})
64+
infos.append({'is_action_valid': True, 'won': self.task_success[i]})
6365
continue
6466

6567
self.step_counts[i] += 1
@@ -70,14 +72,19 @@ def step(self, text_actions: List[str]):
7072

7173
# Check completion
7274
if self._is_completion_action(action):
75+
is_correct = self._evaluate_answer(action, i)
76+
self.task_success[i] = is_correct
7377
self.task_completed[i] = True
7478
dones[i] = True
79+
obs_feedback = "\n\nEvaluation: final answer matches the ground truth." if is_correct else "\n\nEvaluation: final answer does not match the ground truth."
80+
observations[-1] = obs + obs_feedback
7581
elif self.step_counts[i] >= self.config.env.max_steps:
7682
obs += "\n\nMaximum steps reached. Please provide your final answer in <answer></answer> tags."
7783
dones[i] = True
84+
observations[-1] = obs
7885

7986
info['is_action_valid'] = to_numpy(valid)
80-
info['won'] = self.task_completed[i]
87+
info['won'] = self.task_success[i]
8188
info['step_count'] = self.step_counts[i]
8289
infos.append(info)
8390

@@ -125,28 +132,78 @@ def _is_completion_action(self, action: str) -> bool:
125132
"""Check if action indicates task completion"""
126133
return action.startswith("FINAL_ANSWER:") or "<answer>" in action
127134

135+
def _evaluate_answer(self, action: str, batch_idx: int) -> bool:
136+
"""Compare model answer with ground truth"""
137+
predicted = self._extract_answer_text(action)
138+
ground_truth = self.ground_truths[batch_idx]
139+
return self._normalize_answer(predicted) == self._normalize_answer(ground_truth)
140+
141+
@staticmethod
142+
def _extract_answer_text(action: str) -> str:
143+
"""Extract answer text from action string"""
144+
if action.startswith("FINAL_ANSWER:"):
145+
return action.split("FINAL_ANSWER:", 1)[1].strip()
146+
147+
match = re.search(r"<answer>(.*?)</answer>", action, re.DOTALL)
148+
if match:
149+
return match.group(1).strip()
150+
return action.strip()
151+
152+
@staticmethod
153+
def _normalize_answer(text: str) -> str:
154+
"""Normalize answer string for comparison"""
155+
normalized = re.sub(r"\s+", " ", text).strip().lower()
156+
normalized = normalized.strip(".,!?:;\"")
157+
return normalized
158+
128159
def build_text_obs(self, observations: List[str] = None, init: bool = False) -> List[str]:
129160
"""Build text observations for agent"""
130161
batch_size = len(self.current_tasks)
131162
postprocess_text_obs = []
132-
163+
max_steps = getattr(self.config.env, "max_steps", None)
164+
history_length_cfg = getattr(self.config.env, "history_length", 0)
165+
166+
if not init and history_length_cfg > 0:
167+
memory_contexts, valid_lens = self.memory.fetch(
168+
history_length_cfg,
169+
obs_key="text_obs",
170+
action_key="action",
171+
)
172+
else:
173+
memory_contexts = [""] * batch_size
174+
valid_lens = [0] * batch_size
175+
133176
for i in range(batch_size):
134-
if init or self.config.env.history_length <= 0:
177+
current_obs = observations[i] if observations else "Continue with your task."
178+
should_use_last_step = (
179+
not init
180+
and not self.task_completed[i]
181+
and max_steps is not None
182+
and self.step_counts[i] >= max_steps - 1
183+
)
184+
185+
if init:
135186
obs = TOOL_USE_TEMPLATE_NO_HIS.format(
136187
task_description=self.current_tasks[i],
137188
available_tools=self.tool_metadata,
138189
current_observation="Start working on the task."
139190
)
140-
else:
141-
# Get history
142-
memory_contexts, valid_lens = self.memory.fetch(
143-
self.config.env.history_length,
144-
obs_key="text_obs",
145-
action_key="action"
191+
elif should_use_last_step:
192+
obs = TOOL_USE_TEMPLATE_LAST_STEP.format(
193+
task_description=self.current_tasks[i],
194+
step_count=self.step_counts[i],
195+
history_length=valid_lens[i],
196+
action_history=memory_contexts[i],
197+
current_step=self.step_counts[i] + 1,
198+
current_observation=current_obs,
146199
)
147-
148-
current_obs = observations[i] if observations else "Continue with your task."
149-
200+
elif history_length_cfg <= 0:
201+
obs = TOOL_USE_TEMPLATE_NO_HIS.format(
202+
task_description=self.current_tasks[i],
203+
available_tools=self.tool_metadata,
204+
current_observation=current_obs,
205+
)
206+
else:
150207
obs = TOOL_USE_TEMPLATE.format(
151208
task_description=self.current_tasks[i],
152209
step_count=self.step_counts[i],
@@ -156,7 +213,7 @@ def build_text_obs(self, observations: List[str] = None, init: bool = False) ->
156213
current_observation=current_obs,
157214
available_tools=self.tool_metadata
158215
)
159-
216+
160217
postprocess_text_obs.append(obs)
161218

162219
return postprocess_text_obs

openmanus_rl/environments/prompts/tool_use.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -17,10 +17,10 @@
1717
4. When you have sufficient information, provide your final answer in <answer></answer> tags
1818
1919
Format for tool usage:
20-
<tool_call>
20+
<action>
2121
tool: [tool_name]
2222
parameters: {{"param1": "value1", "param2": "value2"}}
23-
</tool_call>
23+
</action>
2424
2525
Now it's your turn to take an action. You should first reason step-by-step about the current situation. This reasoning process MUST be enclosed within <plan> </plan> tags.
2626
Once you've finished your reasoning, you should either use a tool or provide your final answer within <answer> </answer> tags.
@@ -35,6 +35,8 @@
3535
You are now at step {current_step} and this is the final step.
3636
Current Observation: {current_observation}
3737
You must provide your final answer within <answer> </answer> tags.
38+
Even if the evidence is incomplete, infer the most plausible answer.
39+
Never respond with "unknown", "cannot determine", or similar phrases.
3840
"""
3941

4042
TOOL_USE_TEMPLATE = """
@@ -85,4 +87,3 @@
8587
</action>
8688
8789
"""
88-

0 commit comments

Comments
 (0)