Skip to content

Commit 76958ee

Browse files
committed
gaia benchmark class and entrypoint script
1 parent 7691e49 commit 76958ee

File tree

3 files changed

+42
-7
lines changed

3 files changed

+42
-7
lines changed

.gitignore

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,4 +167,6 @@ _sandbox.py
167167
results/
168168

169169
# gradio
170-
.gradio/
170+
.gradio/
171+
172+
outputs/

scripts/run_gaia.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from agentlab.agents.tapeagent import TapeAgentArgs
2+
from agentlab.benchmarks.gaia import GaiaBenchmark
3+
from agentlab.experiments.study import make_study
4+
5+
exp_dir = "./outputs/gaia/debug1"
6+
agent_args = TapeAgentArgs("gaia_agent")
7+
study = make_study(
8+
benchmark=GaiaBenchmark(split="validation", exp_dir=exp_dir),
9+
agent_args=[agent_args],
10+
comment="Gaia eval",
11+
)
12+
13+
study.run(n_jobs=1)

src/agentlab/benchmarks/gaia.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os
2-
from typing import Literal
2+
from typing import Any, Literal
33

4+
import bgym
45
import datasets
56
from tapeagents.environment import ContainerExecutor
67
from tapeagents.tools.browser import Browser
@@ -12,29 +13,48 @@
1213
from agentlab.benchmarks.multitool_gym import MultiToolGym
1314

1415

16+
class GaiaBenchmark(bgym.Benchmark):
17+
name = "gaia"
18+
split: Literal["test", "validation"]
19+
exp_dir: str
20+
21+
high_level_action_set_args = None
22+
is_multi_tab = False
23+
supports_parallel_seeds = False
24+
backends = ["gaia"]
25+
env_args_list = None
26+
task_metadata = None
27+
28+
def __post_init__(self):
29+
super().__post_init__()
30+
self.env_args_list = []
31+
dataset = datasets.load_dataset("gaia-benchmark/GAIA", "2023_all")[self.split]
32+
for task in dataset:
33+
task_dir = os.path.join(self.name, task["task_id"])
34+
env_args = GaiaGymArgs(task=task, exp_dir=task_dir)
35+
self.env_args_list.append(env_args)
36+
37+
1538
class GaiaGym(MultiToolGym):
1639
task: dict
1740
exp_dir: str
1841

1942

2043
class GaiaGymArgs(AbstractEnvArgs):
21-
task_id: str
44+
task: dict[str, Any]
2245
split: Literal["test", "validation"]
2346
exp_dir: str
2447
viewport_chars: int = 64000
2548

2649
def make_env(self) -> GaiaGym:
2750
self.init_code_sandbox()
28-
dataset = datasets.load_dataset("gaia-benchmark/GAIA", "2023_all")
29-
tasks_by_id = {task["task_id"]: task for task in dataset[self.split]}
30-
task = tasks_by_id[self.task_id]
3151
tools = [
3252
WebSearch(),
3353
VideoReader(self.exp_dir),
3454
Browser(self.exp_dir, viewport_chars=self.viewport_chars),
3555
CodeExecutor(self.exp_dir),
3656
]
37-
env = GaiaGym(tools=tools, task=task, exp_dir=self.exp_dir)
57+
env = GaiaGym(tools=tools, task=self.task, exp_dir=self.exp_dir)
3858
return env
3959

4060
def init_code_sandbox(self) -> None:

0 commit comments

Comments
 (0)