|
1 | 1 | import os |
2 | | -from typing import Literal |
| 2 | +from typing import Any, Literal |
3 | 3 |
|
| 4 | +import bgym |
4 | 5 | import datasets |
5 | 6 | from tapeagents.environment import ContainerExecutor |
6 | 7 | from tapeagents.tools.browser import Browser |
|
12 | 13 | from agentlab.benchmarks.multitool_gym import MultiToolGym |
13 | 14 |
|
14 | 15 |
|
class GaiaBenchmark(bgym.Benchmark):
    """Benchmark definition that exposes every GAIA task of one split.

    After initialization, ``env_args_list`` holds one ``GaiaGymArgs`` per
    task found in the selected split of the "gaia-benchmark/GAIA" dataset
    (configuration "2023_all").
    """

    name = "gaia"
    # Dataset split whose tasks are turned into environment arguments.
    split: Literal["test", "validation"]
    # NOTE(review): declared here but not read in this block; per-task
    # directories below are derived from ``name`` instead — confirm intent.
    exp_dir: str

    high_level_action_set_args = None
    is_multi_tab = False
    supports_parallel_seeds = False
    backends = ["gaia"]
    env_args_list = None
    task_metadata = None

    def __post_init__(self):
        """Load the selected split and build one ``GaiaGymArgs`` per task."""
        super().__post_init__()
        dataset = datasets.load_dataset("gaia-benchmark/GAIA", "2023_all")[self.split]
        # ``split`` is declared on GaiaGymArgs without a default, so it is
        # passed explicitly rather than left unset.
        self.env_args_list = [
            GaiaGymArgs(
                task=task,
                split=self.split,
                exp_dir=os.path.join(self.name, task["task_id"]),
            )
            for task in dataset
        ]
| 36 | + |
| 37 | + |
15 | 38 | class GaiaGym(MultiToolGym): |
16 | 39 | task: dict |
17 | 40 | exp_dir: str |
18 | 41 |
|
19 | 42 |
|
20 | 43 | class GaiaGymArgs(AbstractEnvArgs): |
21 | | - task_id: str |
| 44 | + task: dict[str, Any] |
22 | 45 | split: Literal["test", "validation"] |
23 | 46 | exp_dir: str |
24 | 47 | viewport_chars: int = 64000 |
25 | 48 |
|
def make_env(self) -> GaiaGym:
    """Build the GaiaGym environment for the task stored on this object.

    Calls ``init_code_sandbox()`` first, then creates the tool set
    (web search, video reader, browser, code executor) rooted at
    ``exp_dir`` and hands it to ``GaiaGym`` together with ``task``.
    """
    self.init_code_sandbox()
    workdir = self.exp_dir
    return GaiaGym(
        tools=[
            WebSearch(),
            VideoReader(workdir),
            Browser(workdir, viewport_chars=self.viewport_chars),
            CodeExecutor(workdir),
        ],
        task=self.task,
        exp_dir=workdir,
    )
39 | 59 |
|
40 | 60 | def init_code_sandbox(self) -> None: |
|
0 commit comments