Skip to content

Commit 8461301

Browse files
committed
add gaia scorer for reward in gym, fix envargs serialization
1 parent 809ad00 commit 8461301

File tree

6 files changed

+149
-26
lines changed

6 files changed

+149
-26
lines changed

scripts/run_gaia.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616
benchmark=GaiaBenchmark(split="validation"),
1717
agent_args=TapeAgentArgs("gaia_agent"),
1818
comment="Gaia eval",
19-
logging_level=logging.DEBUG,
20-
logging_level_stdout=logging.DEBUG,
19+
logging_level=logging.INFO,
20+
logging_level_stdout=logging.INFO,
2121
)
2222
print(f"Exp args list len: {len(study.exp_args_list)}")
2323
study.exp_args_list = study.exp_args_list[:1]

src/agentlab/agents/tapeagent/agent.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def __init__(self, agent: Agent, tape: Tape):
3939
self.tape = tape
4040

4141
def obs_preprocessor(self, obs: Any) -> Any:
42-
logger.info(f"Observation: {obs}")
42+
logger.info(f"Observations: {type(obs)}")
4343
return obs
4444

4545
def get_action(self, obs: Observation | list[Observation]) -> tuple[str, TapeAgentInfo]:
@@ -52,10 +52,6 @@ def get_action(self, obs: Observation | list[Observation]) -> tuple[str, TapeAge
5252
action = None
5353
while not action:
5454
for event in self.agent.run(self.tape):
55-
if event.final_tape:
56-
logger.info(
57-
f"agent run final tape state: {[type(s).__name__ for s in self.tape]}"
58-
)
5955
if not event.step:
6056
continue
6157
self.tape = self.tape.append(event.step)
@@ -64,7 +60,7 @@ def get_action(self, obs: Observation | list[Observation]) -> tuple[str, TapeAge
6460
logger.info(f"Thought: {event.step.llm_view()}")
6561
elif isinstance(event.step, Action) and not action:
6662
action = event.step
67-
logger.info(f"Action: {action}")
63+
logger.info(f"Action: {action.llm_view()}")
6864
# we stop at the first action
6965
else:
7066
logger.info(f"Other step: {type(event.step)}")

src/agentlab/benchmarks/abstract_env.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,12 @@
11
from abc import ABC, abstractmethod
2+
from dataclasses import dataclass
23

34
import gymnasium as gym
5+
from dataclasses_json import DataClassJsonMixin
46
from pydantic import BaseModel
57

68

7-
class AbstractEnvArgs(BaseModel, frozen=True):
8-
"""Easily serialiazable class to store the arguments of an environment"""
9-
10-
task_seed: int = 0
11-
task_name: str = ""
12-
9+
class AbstractEnvArgs(ABC):
1310
@abstractmethod
1411
def make_env(self, action_mapping, exp_dir, exp_task_kwargs) -> "AbstractEnv":
1512
"""Create an instance of the environment with the arguments stored in this object.
@@ -25,9 +22,17 @@ def make_env(self, action_mapping, exp_dir, exp_task_kwargs) -> "AbstractEnv":
2522
"""
2623

2724

25+
@dataclass
26+
class SerializableEnvArgs(AbstractEnvArgs, DataClassJsonMixin):
27+
"""Easily serialiazable class to store the arguments of an environment"""
28+
29+
task_seed: int = 0
30+
task_name: str = ""
31+
32+
2833
class AbstractBenchmark(BaseModel):
2934
name: str
30-
env_args_list: list[AbstractEnvArgs]
35+
env_args_list: list
3136

3237
def get_version(self) -> int:
3338
return "1"

src/agentlab/benchmarks/gaia.py

Lines changed: 130 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,22 @@
11
import logging
22
import os
3+
import re
34
import shutil
5+
import string
46
from pathlib import Path
57
from typing import Any, Literal
68

79
import datasets
810
from pydantic import Field
9-
from tapeagents.core import Observation, StopStep, Thought
11+
from tapeagents.core import Action, Observation, StopStep, Thought
1012
from tapeagents.environment import ContainerExecutor, StatefulTool, Tool
1113
from tapeagents.steps import ImageObservation
1214
from tapeagents.tools.browser import Browser
1315
from tapeagents.tools.code_executor import CodeExecutor
1416
from tapeagents.tools.media_reader import VideoReader
1517
from tapeagents.tools.web_search import WebSearch
1618

17-
from agentlab.benchmarks.abstract_env import AbstractBenchmark, AbstractEnvArgs
19+
from agentlab.benchmarks.abstract_env import AbstractBenchmark, SerializableEnvArgs
1820
from agentlab.benchmarks.multitool_gym import MultiToolGym
1921

2022
logger = logging.getLogger(__name__)
@@ -38,14 +40,41 @@ def reset(self, seed=None) -> tuple[list[Observation], dict]:
3840
steps.append(image_obs)
3941
return steps, {}
4042

41-
def step(self, action: str) -> tuple[Observation, float, bool, bool, dict]:
42-
logger.info(f"env step called with action {type(action)}")
43-
return super().step(action)
43+
def step(self, action: Action) -> tuple[Observation, float, bool, bool, dict]:
44+
logger.info(f"Gym step called with action {type(action)}")
45+
observation, reward, terminated, truncated, env_info = super().step(action)
46+
logger.info(f"Gym observation: {observation.short_view()}")
47+
return observation, reward, terminated, truncated, env_info
4448

49+
def calculate_reward(self, action: Action) -> float:
50+
if isinstance(action, GaiaAnswer):
51+
model_answer = action.answer
52+
ground_truth = self.task["Final answer"]
53+
reward = 1.0 if question_scorer(model_answer, ground_truth) else 0.0
54+
else:
55+
reward = 0.0
4556

46-
class GaiaGymArgs(AbstractEnvArgs, frozen=True):
57+
if reward == 1.0:
58+
logger.info(f"Task {self.task['task_id']} solved.")
59+
else:
60+
logger.info(f"Task {self.task['task_id']} failed.")
61+
62+
return reward
63+
64+
65+
class GaiaGymArgs(SerializableEnvArgs):
4766
task: dict[str, Any]
48-
viewport_chars: int = 64000
67+
viewport_chars: int
68+
task_seed: int
69+
task_name: str
70+
71+
def __init__(
72+
self, task_name: str, task: dict[str, Any], viewport_chars: int = 64000, task_seed: int = 0
73+
):
74+
self.task_name = task_name
75+
self.task = task
76+
self.viewport_chars = viewport_chars
77+
self.task_seed = task_seed
4978

5079
def make_env(self, exp_dir: str | Path, action_mapping=None) -> GaiaGym:
5180
exp_dir = str(exp_dir)
@@ -80,7 +109,7 @@ def model_post_init(self, __context: Any) -> None:
80109
self.env_args_list = []
81110
dataset = datasets.load_dataset("gaia-benchmark/GAIA", "2023_all")[self.split]
82111
for task in dataset:
83-
env_args = GaiaGymArgs(task_name="gaia_" + task["task_id"], task=task)
112+
env_args = GaiaGymArgs(task_name="gaia." + task["task_id"], task=task)
84113
self.env_args_list.append(env_args)
85114

86115

@@ -143,3 +172,96 @@ class GaiaAnswer(StopStep):
143172
)
144173
answer: Any = Field(description="Short final answer")
145174
long_answer: str = Field(description="Detailed final answer not restricted by format rules")
175+
176+
177+
def normalize_number_str(number_str: str) -> float:
178+
# we replace these common units and commas to allow
179+
# conversion to float
180+
for char in ["$", "%", ","]:
181+
number_str = number_str.replace(char, "")
182+
try:
183+
return float(number_str)
184+
except ValueError:
185+
logger.info(f"String {number_str} cannot be normalized to number str.")
186+
return float("inf")
187+
188+
189+
def split_string(
190+
s: str,
191+
char_list: list[str] = [",", ";"],
192+
) -> list[str]:
193+
pattern = f"[{''.join(char_list)}]"
194+
return re.split(pattern, s)
195+
196+
197+
def question_scorer(
198+
model_answer: str,
199+
ground_truth: str,
200+
) -> bool:
201+
def is_float(element: any) -> bool:
202+
try:
203+
float(element)
204+
return True
205+
except ValueError:
206+
return False
207+
208+
# if gt is a number
209+
if is_float(ground_truth):
210+
logger.info(f"Evaluating {model_answer} as a number.")
211+
normalized_answer = normalize_number_str(model_answer)
212+
return normalized_answer == float(ground_truth)
213+
214+
# if gt is a list
215+
elif any(char in ground_truth for char in [",", ";"]):
216+
logger.info(f"Evaluating {model_answer} as a comma separated list.")
217+
# question with the fish: normalization removes punct
218+
219+
gt_elems = split_string(ground_truth)
220+
ma_elems = split_string(model_answer)
221+
222+
# check length is the same
223+
if len(gt_elems) != len(ma_elems):
224+
logger.warning("Answer lists have different lengths, returning False.", UserWarning)
225+
return False
226+
227+
# compare each element as float or str
228+
comparisons = []
229+
for ma_elem, gt_elem in zip(ma_elems, gt_elems):
230+
if is_float(gt_elem):
231+
normalized_ma_elem = normalize_number_str(ma_elem)
232+
comparisons.append(normalized_ma_elem == float(gt_elem))
233+
else:
234+
# we do not remove punct since comparisons can include punct
235+
comparisons.append(
236+
normalize_str(ma_elem, remove_punct=False)
237+
== normalize_str(gt_elem, remove_punct=False)
238+
)
239+
return all(comparisons)
240+
241+
# if gt is a str
242+
else:
243+
logger.info(f"Evaluating {model_answer} as a string.")
244+
return normalize_str(model_answer) == normalize_str(ground_truth)
245+
246+
247+
def normalize_str(input_str, remove_punct=True) -> str:
248+
"""
249+
Normalize a string by:
250+
- Removing all white spaces
251+
- Optionally removing punctuation (if remove_punct is True)
252+
- Converting to lowercase
253+
Parameters:
254+
- input_str: str, the string to normalize
255+
- remove_punct: bool, whether to remove punctuation (default: True)
256+
Returns:
257+
- str, the normalized string
258+
"""
259+
# Remove all white spaces. Required e.g for seagull vs. sea gull
260+
no_spaces = re.sub(r"\s", "", input_str)
261+
262+
# Remove punctuation, if specified.
263+
if remove_punct:
264+
translator = str.maketrans("", "", string.punctuation)
265+
return no_spaces.lower().translate(translator)
266+
else:
267+
return no_spaces.lower()

src/agentlab/benchmarks/multitool_gym.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def step(self, action: Action) -> tuple[Observation, float, bool, bool, dict]:
3333
observation = self._env.step(action)
3434
action_exec_stop = time.time()
3535

36-
reward = self.calculate_reward()
36+
reward = self.calculate_reward(action)
3737

3838
truncated = False
3939

@@ -45,7 +45,7 @@ def step(self, action: Action) -> tuple[Observation, float, bool, bool, dict]:
4545
}
4646
return observation, reward, terminated, truncated, env_info
4747

48-
def calculate_reward(self) -> float:
48+
def calculate_reward(self, action: Action) -> float:
4949
return 0.0
5050

5151
def close(self):

src/agentlab/experiments/launch_exp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def run_experiments(
9898
logging.info("All jobs are finished. Calling agent_args.close() on all agents...")
9999
for exp_args in exp_args_list:
100100
exp_args.agent_args.close()
101-
logging.info("Experiment finished.")
101+
logging.info(f"Experiment finished and saved in {study_dir}.")
102102

103103

104104
def find_incomplete(study_dir: str | Path, include_errors=True):

0 commit comments

Comments
 (0)