Skip to content

Commit c59e60b

Browse files
committed
working gaia bench and gym, with test
1 parent 2ef8500 commit c59e60b

File tree

4 files changed

+49
-67
lines changed

4 files changed

+49
-67
lines changed

src/agentlab/benchmarks/abstract_env.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,23 @@
11
from abc import ABC, abstractmethod
22

3-
import gym
3+
import gymnasium as gym
44
from pydantic import BaseModel
55

66

7+
class AbstractBenchmark(BaseModel):
8+
name: str
9+
env_args_list: list = None
10+
11+
def get_version(self) -> int:
12+
return "1"
13+
14+
def prepare_backends(self):
15+
pass
16+
17+
def dependency_graph_over_tasks(self) -> dict[str, list[str]]:
18+
return {}
19+
20+
721
class AbstractEnvArgs(BaseModel):
822
"""Easily serializable class to store the arguments of an environment"""
923

src/agentlab/benchmarks/gaia.py

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import shutil
33
from typing import Any, Literal
44

5-
import bgym
65
import datasets
76
from pydantic import Field
87
from tapeagents.core import Observation, StopStep, Thought
@@ -12,24 +11,16 @@
1211
from tapeagents.tools.media_reader import VideoReader
1312
from tapeagents.tools.web_search import WebSearch
1413

15-
from agentlab.benchmarks.abstract_env import AbstractEnvArgs
14+
from agentlab.benchmarks.abstract_env import AbstractBenchmark, AbstractEnvArgs
1615
from agentlab.benchmarks.multitool_gym import MultiToolGym
1716

1817

19-
class GaiaBenchmark(bgym.Benchmark):
20-
name = "gaia"
21-
split: Literal["test", "validation"]
18+
class GaiaBenchmark(AbstractBenchmark):
2219
exp_dir: str
20+
name: str = "gaia"
21+
split: Literal["test", "validation"]
2322

24-
high_level_action_set_args = None
25-
is_multi_tab = False
26-
supports_parallel_seeds = False
27-
backends = ["gaia"]
28-
env_args_list = None
29-
task_metadata = None
30-
31-
def __post_init__(self):
32-
super().__post_init__()
23+
def model_post_init(self, __context: Any) -> None:
3324
self.env_args_list = []
3425
dataset = datasets.load_dataset("gaia-benchmark/GAIA", "2023_all")[self.split]
3526
for task in dataset:
@@ -45,7 +36,6 @@ class GaiaGym(MultiToolGym):
4536

4637
class GaiaGymArgs(AbstractEnvArgs):
4738
task: dict[str, Any]
48-
split: Literal["test", "validation"]
4939
exp_dir: str
5040
viewport_chars: int = 64000
5141

src/agentlab/benchmarks/multitool_gym.py

Lines changed: 3 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,58 +1,12 @@
1-
from typing import Any, Literal, Union
2-
3-
from pydantic import Annotated, Field, TypeAdapter
41
from tapeagents.core import Action, Observation, Tape
5-
from tapeagents.environment import ToolCollectionEnvironment
6-
from tapeagents.tools.base import Multitool, Tool
72

83
from agentlab.benchmarks.abstract_env import AbstractEnv
94

105
EnvTape = Tape[None, Action | Observation]
116

127

13-
class FunctionCall(Action):
14-
kind: Literal["function_call_action"] = ["function_call_action"]
15-
function_name: str
16-
args: list[Any] | None
17-
kwargs: dict[str, Any] | None
18-
19-
20-
class FunctionCallResult(Observation):
21-
kind: Literal["function_call_result"] = ["function_call_result"]
22-
result: Any
23-
24-
25-
class SimpleFunctionCallTool(Tool):
26-
action = FunctionCall
27-
observation = FunctionCallResult
28-
function: callable
29-
function_name: str = ""
30-
31-
def model_post_init(self, __context):
32-
function_name = getattr(self.function, "__name__", "")
33-
if not function_name and not self.function_name:
34-
raise ValueError("Function has no name, function_name must be provided")
35-
36-
def execute_action(self, action: FunctionCall) -> FunctionCallResult:
37-
if not self.function_name == action.function_name:
38-
raise ValueError(
39-
f"Unexpected function action {action.function_name}, expected {self.function_name}"
40-
)
41-
result = self.function(*action.args, **action.kwargs)
42-
return FunctionCallResult(result=result)
43-
44-
458
class MultiToolGym(AbstractEnv):
46-
def __init__(self, tools: list[Tool | Multitool]):
47-
self._env = ToolCollectionEnvironment(tools)
48-
self._actions = self._env.actions()
49-
self._actions_parser: TypeAdapter = TypeAdapter(
50-
Annotated[Union[self._actions], Field(discriminator="kind")]
51-
)
52-
self.reset()
53-
549
def reset(self):
55-
self._tape: EnvTape = EnvTape(steps=[])
5610
self._env.reset()
5711

5812
def step(self, action: str):
@@ -61,14 +15,12 @@ def step(self, action: str):
6115
except Exception:
6216
raise ValueError("Action must be a valid JSON dict")
6317
assert isinstance(action_step, Action), "{action_step.kind} is not an Action"
64-
self._tape += [action_step]
65-
self._tape = self._env.react(self._tape)
66-
observation_step: Observation = self._tape.steps[-1]
18+
observation = self._env.step(action_step)
6719
reward = self.calculate_reward()
6820
terminated = False
6921
truncated = False
70-
env_info = {"step_metadata": observation_step.metadata}
71-
return observation_step.llm_dict(), reward, terminated, truncated, env_info
22+
env_info = {"step_metadata": observation.metadata}
23+
return observation, reward, terminated, truncated, env_info
7224

7325
def calculate_reward(self) -> float:
7426
return 0.0

tests/agents/test_gaia_agent.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,34 @@
11
from agentlab.agents.tapeagent import TapeAgent, TapeAgentArgs
2+
from agentlab.benchmarks.gaia import GaiaBenchmark
23

34

45
def test_agent_creation():
56
args = TapeAgentArgs(agent_name="gaia_agent")
67
agent = args.make_agent()
78
assert isinstance(agent, TapeAgent)
89
assert agent.agent.name == "gaia_agent"
10+
11+
12+
def test_gaia_bench():
13+
exp_dir = "/tmp/"
14+
bench = GaiaBenchmark(exp_dir=exp_dir, split="validation")
15+
assert bench.name == "gaia"
16+
assert bench.split == "validation"
17+
assert bench.exp_dir == exp_dir
18+
assert len(bench.env_args_list) == 165
19+
20+
assert bench.env_args_list[5].exp_dir == "gaia/32102e3e-d12a-4209-9163-7b3a104efe5d"
21+
assert bench.env_args_list[5].viewport_chars == 64000
22+
task = bench.env_args_list[5].task
23+
question = """The attached spreadsheet shows the inventory for a movie and video game rental store in Seattle, Washington. What is the title of the oldest Blu-Ray recorded in this spreadsheet? Return it as appearing in the spreadsheet."""
24+
steps = """1. Open the attached file.\n2. Compare the years given in the Blu-Ray section to find the oldest year, 2009.\n3. Find the title of the Blu-Ray disc that corresponds to the year 2009: Time-Parking 2: Parallel Universe."""
25+
assert task["task_id"] == "32102e3e-d12a-4209-9163-7b3a104efe5d"
26+
assert task["Question"] == question
27+
assert task["Level"] == "2"
28+
assert task["Final answer"] == "Time-Parking 2: Parallel Universe"
29+
assert task["file_name"] == "32102e3e-d12a-4209-9163-7b3a104efe5d.xlsx"
30+
assert task["Annotator Metadata"]["Steps"] == steps
31+
assert task["Annotator Metadata"]["Number of steps"] == "3"
32+
assert task["Annotator Metadata"]["How long did this take?"] == "1 minute"
33+
assert task["Annotator Metadata"]["Tools"] == "1. Microsoft Excel"
34+
assert task["Annotator Metadata"]["Number of tools"] == "1"

0 commit comments

Comments
 (0)