Skip to content

Commit 0791f2d

Browse files
committed
add GAIA gym
1 parent 3040571 commit 0791f2d

File tree

4 files changed

+56
-3
lines changed

4 files changed

+56
-3
lines changed

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ joblib>=1.2.0
1313
openai>=1.7,<2
1414
langchain_community
1515
tiktoken
16-
tapeagents[converters]~=0.1.4
16+
tapeagents[converters]
1717
huggingface_hub
1818
contexttimer
1919
ipython

src/agentlab/benchmarks/abstract_env.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
from abc import ABC, abstractmethod
22

33
import gym
4+
from pydantic import BaseModel
45

56

6-
class AbstractEnvArgs(ABC):
7+
class AbstractEnvArgs(BaseModel):
78
"""Easily serializable class to store the arguments of an environment"""
89

910
@abstractmethod

src/agentlab/benchmarks/gaia.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import os
2+
from typing import Literal
3+
4+
import datasets
5+
from tapeagents.environment import ContainerExecutor
6+
from tapeagents.tools.browser import Browser
7+
from tapeagents.tools.code_executor import CodeExecutor
8+
from tapeagents.tools.container_executor import init_code_sandbox
9+
from tapeagents.tools.media_reader import VideoReader
10+
from tapeagents.tools.web_search import WebSearch
11+
12+
from agentlab.benchmarks.abstract_env import AbstractEnvArgs
13+
from agentlab.benchmarks.multitool_gym import MultiToolGym
14+
15+
16+
class GaiaGym(MultiToolGym):
    """Multi-tool gym bound to a single GAIA task.

    Attributes:
        task: The GAIA dataset record this environment was created for.
        exp_dir: Experiment directory supplied by the caller.
    """

    task: dict
    exp_dir: str

    def __init__(self, tools: list, task: dict, exp_dir: str):
        """Store the task and experiment directory, then set up the tools.

        Without this initializer the class only inherits the base
        ``__init__(self, tools)``, so constructing it with ``task=`` and
        ``exp_dir=`` keyword arguments raises ``TypeError``.

        Args:
            tools: Tools forwarded unchanged to the base initializer.
            task: GAIA task record to attach to the environment.
            exp_dir: Experiment directory to attach to the environment.
        """
        # Assign before delegating: the base initializer appears to finish by
        # calling reset(), so the attributes should already exist by then.
        self.task = task
        self.exp_dir = exp_dir
        super().__init__(tools)
19+
20+
21+
class GaiaGymArgs(AbstractEnvArgs):
    """Serializable arguments describing one GAIA task environment.

    Attributes:
        task_id: Value of the ``task_id`` field of the GAIA record to load.
        split: Dataset split in which the task is looked up.
        exp_dir: Experiment directory passed to the sandbox and to the tools.
        viewport_chars: Forwarded to ``Browser`` as ``viewport_chars``.
    """

    task_id: str
    split: Literal["test", "validation"]
    exp_dir: str
    viewport_chars: int = 64000

    def make_env(self) -> GaiaGym:
        """Build the ``GaiaGym`` for ``task_id``.

        Returns:
            A ``GaiaGym`` holding the web search, video reader, browser and
            code executor tools together with the selected task record.

        Raises:
            KeyError: If ``task_id`` is not present in the selected split.
        """
        # NOTE: inside a method body this bare name resolves to the
        # module-level function imported from tapeagents, not to the method
        # of the same name defined below.
        init_code_sandbox(self.exp_dir)

        dataset = datasets.load_dataset("gaia-benchmark/GAIA", "2023_all")
        tasks_by_id = {task["task_id"]: task for task in dataset[self.split]}
        try:
            task = tasks_by_id[self.task_id]
        except KeyError:
            # Same exception type as before, but say what was being looked up.
            raise KeyError(
                f"GAIA task {self.task_id!r} not found in split {self.split!r}"
            ) from None

        tools = [
            WebSearch(),
            VideoReader(self.exp_dir),
            Browser(self.exp_dir, viewport_chars=self.viewport_chars),
            CodeExecutor(self.exp_dir),
        ]
        return GaiaGym(tools=tools, task=task, exp_dir=self.exp_dir)

    def init_code_sandbox(self) -> None:
        """Create ``<exp_dir>/code`` and instantiate a ``ContainerExecutor`` on it.

        The container name is ``exp_dir`` with every ``/`` replaced by ``-``.
        """
        code_path = os.path.join(self.exp_dir, "code")
        os.makedirs(code_path, exist_ok=True)
        # NOTE(review): an absolute exp_dir yields a name starting with "-";
        # confirm the container runtime accepts that before relying on it.
        container_name = self.exp_dir.replace("/", "-")
        ContainerExecutor(
            work_dir=code_path,
            container_name=container_name,
            restart_if_exists=False,
            stop_container=False,
            no_deps=True,
        )

src/agentlab/benchmarks/multitool_gym.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,9 @@ def __init__(self, tools: list[Tool | Multitool]):
5151
)
5252
self.reset()
5353

54-
def reset(self, seed=None):
54+
def reset(self):
5555
self._tape: EnvTape = EnvTape(steps=[])
56+
self._env.reset()
5657

5758
def step(self, action: str):
5859
try:

0 commit comments

Comments
 (0)