Skip to content

Commit 418764f

Browse files
committed
test gym creation and reset
1 parent 460febf commit 418764f

File tree

4 files changed

+79
-34
lines changed

4 files changed

+79
-34
lines changed

src/agentlab/benchmarks/abstract_env.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,6 @@
44
from pydantic import BaseModel
55

66

7-
class AbstractBenchmark(BaseModel):
8-
name: str
9-
env_args_list: list = None
10-
11-
def get_version(self) -> int:
12-
return "1"
13-
14-
def prepare_backends(self):
15-
pass
16-
17-
def dependency_graph_over_tasks(self) -> dict[str, list[str]]:
18-
return {}
19-
20-
217
class AbstractEnvArgs(BaseModel):
228
"""Easily serialiazable class to store the arguments of an environment"""
239

@@ -36,6 +22,20 @@ def make_env(self, action_mapping, exp_dir, exp_task_kwargs) -> "AbstractEnv":
3622
"""
3723

3824

25+
class AbstractBenchmark(BaseModel):
26+
name: str
27+
env_args_list: list[AbstractEnvArgs]
28+
29+
def get_version(self) -> int:
30+
return "1"
31+
32+
def prepare_backends(self):
33+
pass
34+
35+
def dependency_graph_over_tasks(self) -> dict[str, list[str]]:
36+
return {}
37+
38+
3939
class AbstractEnv(gym.Env, ABC):
4040
@abstractmethod
4141
def reset(self, seed: int = None) -> tuple[dict[str, any], dict[str, any]]:

src/agentlab/benchmarks/gaia.py

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import datasets
66
from pydantic import Field
77
from tapeagents.core import Observation, StopStep, Thought
8-
from tapeagents.environment import ContainerExecutor
8+
from tapeagents.environment import ContainerExecutor, StatefulTool, Tool
99
from tapeagents.steps import ImageObservation
1010
from tapeagents.tools.browser import Browser
1111
from tapeagents.tools.code_executor import CodeExecutor
@@ -16,29 +16,22 @@
1616
from agentlab.benchmarks.multitool_gym import MultiToolGym
1717

1818

19-
class GaiaBenchmark(AbstractBenchmark):
20-
exp_dir: str
21-
name: str = "gaia"
22-
split: Literal["test", "validation"]
23-
24-
def model_post_init(self, __context: Any) -> None:
25-
self.env_args_list = []
26-
dataset = datasets.load_dataset("gaia-benchmark/GAIA", "2023_all")[self.split]
27-
for task in dataset:
28-
task_dir = os.path.join(self.name, task["task_id"])
29-
env_args = GaiaGymArgs(task=task, exp_dir=task_dir)
30-
self.env_args_list.append(env_args)
31-
32-
3319
class GaiaGym(MultiToolGym):
3420
task: dict
3521
exp_dir: str
3622

23+
def __init__(self, tools: list[Tool | StatefulTool], task: dict, exp_dir: str):
24+
super().__init__(tools=tools)
25+
self.task = task
26+
self.exp_dir = exp_dir
27+
3728
def reset(self) -> tuple[list[Observation], dict]:
3829
super().reset()
30+
print("task:", self.task)
3931
question = GaiaQuestion.from_task(self.task)
4032
steps = [question]
4133
if image_obs := with_image(question):
34+
print("image_obs:", image_obs)
4235
steps.append(image_obs)
4336
return steps
4437

@@ -52,9 +45,9 @@ def make_env(self) -> GaiaGym:
5245
self.init_code_sandbox()
5346
tools = [
5447
WebSearch(),
55-
VideoReader(self.exp_dir),
56-
Browser(self.exp_dir, viewport_chars=self.viewport_chars),
57-
CodeExecutor(self.exp_dir),
48+
VideoReader(exp_path=self.exp_dir),
49+
Browser(exp_path=self.exp_dir, viewport_chars=self.viewport_chars),
50+
CodeExecutor(exp_path=self.exp_dir),
5851
]
5952
env = GaiaGym(tools=tools, task=self.task, exp_dir=self.exp_dir)
6053
return env
@@ -72,6 +65,21 @@ def init_code_sandbox(self) -> None:
7265
)
7366

7467

68+
class GaiaBenchmark(AbstractBenchmark):
69+
exp_dir: str
70+
name: str = "gaia"
71+
split: Literal["test", "validation"]
72+
env_args_list: list[GaiaGymArgs] = None
73+
74+
def model_post_init(self, __context: Any) -> None:
75+
self.env_args_list = []
76+
dataset = datasets.load_dataset("gaia-benchmark/GAIA", "2023_all")[self.split]
77+
for task in dataset:
78+
task_dir = os.path.join(self.name, task["task_id"])
79+
env_args = GaiaGymArgs(task=task, exp_dir=task_dir)
80+
self.env_args_list.append(env_args)
81+
82+
7583
class ExtractedFacts(Thought):
7684
"""
7785
Thought that contains the list of facts extracted from the document

src/agentlab/benchmarks/multitool_gym.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,27 @@
1+
from typing import Annotated, Union
2+
3+
from pydantic import Field, TypeAdapter
14
from tapeagents.core import Action, Observation, Tape
5+
from tapeagents.environment import ToolCollectionEnvironment
6+
from tapeagents.tools.base import StatefulTool, Tool
27

38
from agentlab.benchmarks.abstract_env import AbstractEnv
49

510
EnvTape = Tape[None, Action | Observation]
611

712

813
class MultiToolGym(AbstractEnv):
14+
def __init__(self, tools: list[Tool | StatefulTool]):
15+
self._env = ToolCollectionEnvironment(tools)
16+
self._actions = self._env.actions()
17+
self._actions_parser: TypeAdapter = TypeAdapter(
18+
Annotated[Union[self._actions], Field(discriminator="kind")]
19+
)
20+
921
def reset(self):
1022
self._env.reset()
1123

12-
def step(self, action: str):
24+
def step(self, action: str) -> tuple[Observation, float, bool, bool, dict]:
1325
try:
1426
action_step = self._actions_parser.validate_json(action)
1527
except Exception:

tests/agents/test_gaia_agent.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
1+
import os
2+
3+
from tapeagents.steps import ImageObservation
4+
15
from agentlab.agents.tapeagent.agent import TapeAgent, TapeAgentArgs
2-
from agentlab.benchmarks.gaia import GaiaBenchmark
6+
from agentlab.benchmarks.gaia import GaiaBenchmark, GaiaQuestion
37

48

59
def test_agent_creation():
@@ -32,3 +36,24 @@ def test_gaia_bench():
3236
assert task["Annotator Metadata"]["How long did this take?"] == "1 minute"
3337
assert task["Annotator Metadata"]["Tools"] == "1. Microsoft Excel"
3438
assert task["Annotator Metadata"]["Number of tools"] == "1"
39+
40+
41+
def test_gaia_gym_reset():
42+
exp_dir = "/tmp/"
43+
bench = GaiaBenchmark(exp_dir=exp_dir, split="validation")
44+
45+
args = bench.env_args_list[5]
46+
env = args.make_env()
47+
steps = env.reset()
48+
assert len(steps) == 1
49+
assert isinstance(steps[0], GaiaQuestion)
50+
assert steps[0].content == args.task["Question"]
51+
52+
args = bench.env_args_list[20]
53+
env = args.make_env()
54+
steps = env.reset()
55+
assert len(steps) == 2
56+
assert isinstance(steps[0], GaiaQuestion)
57+
assert steps[0].content == args.task["Question"]
58+
assert isinstance(steps[1], ImageObservation)
59+
assert os.path.basename(steps[1].image_path) == args.task["file_name"]

0 commit comments

Comments
 (0)