Skip to content

Commit c747915

Browse files
committed
config-driven gym with tools and bench
1 parent 55378a4 commit c747915

File tree

9 files changed

+145
-60
lines changed

9 files changed

+145
-60
lines changed

.vscode/launch.json

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
{
2+
// Use IntelliSense to learn about possible attributes.
3+
// Hover to view descriptions of existing attributes.
4+
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
5+
"version": "0.2.0",
6+
"configurations": [
7+
{
8+
"name": "Python Debugger: Current File",
9+
"type": "debugpy",
10+
"request": "launch",
11+
"program": "${file}",
12+
"console": "integratedTerminal",
13+
"justMyCode": false,
14+
"env": {
15+
"AGENTLAB_DEBUG": "1"
16+
}
17+
}
18+
]
19+
}

scripts/run_gaia.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,25 @@
11
import logging
2+
import os
23

3-
from agentlab.agents.tapeagent.agent import TapeAgentArgs
4-
from agentlab.benchmarks.gaia import GaiaBenchmark
4+
from agentlab.agents.tapeagent.agent import TapeAgentArgs, load_config
5+
from agentlab.benchmarks.gaia import GaiaBenchmark, stop_old_sandbox
56
from agentlab.experiments.study import make_study
67

78
fmt = "%(asctime)s - %(levelname)s - %(name)s:%(lineno)d - %(funcName)s() - %(message)s"
89
logging.basicConfig(level=logging.INFO, force=True, format=fmt, handlers=[logging.StreamHandler()])
910

1011
if __name__ == "__main__":
12+
config = load_config("gaia_l1")
1113
study = make_study(
12-
benchmark=GaiaBenchmark(split="validation", level="1"), # type: ignore
13-
agent_args=TapeAgentArgs("gaia_agent"),
14-
comment="Gaia eval",
14+
benchmark=GaiaBenchmark.from_config(config), # type: ignore
15+
agent_args=TapeAgentArgs(agent_name=config.name, config=config),
16+
comment=config.comment,
1517
logging_level=logging.INFO,
1618
logging_level_stdout=logging.INFO,
1719
)
18-
# study.exp_args_list = study.exp_args_list[:3]
19-
# study.run(n_jobs=1, n_relaunch=1, parallel_backend="sequential")
20-
study.run(n_jobs=8, n_relaunch=1, parallel_backend="ray")
20+
stop_old_sandbox()
21+
if os.environ.get("AGENTLAB_DEBUG"):
22+
study.exp_args_list = study.exp_args_list[:3]
23+
study.run(n_jobs=1, n_relaunch=1, parallel_backend="sequential")
24+
else:
25+
study.run(n_jobs=config.n_jobs, n_relaunch=1, parallel_backend=config.parallel_backend)

src/agentlab/agents/tapeagent/agent.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import bgym
66
import hydra
7+
from omegaconf import DictConfig
78
from pydantic import Field
89
from tapeagents.agent import Agent
910
from tapeagents.core import Action, Observation, TapeMetadata, Thought
@@ -29,28 +30,32 @@ class Tape(BaseTape):
2930
metadata: ExtendedMetadata = Field(default_factory=ExtendedMetadata) # type: ignore
3031

3132

33+
def load_config(config_name: str) -> DictConfig:
34+
with hydra.initialize(config_path="conf", version_base="1.1"):
35+
config = hydra.compose(config_name=config_name)
36+
return config
37+
38+
3239
@dataclass
3340
class TapeAgentArgs(AgentArgs):
34-
agent_name: str = "tape_agent"
41+
config: DictConfig = None # type: ignore
3542

3643
def make_agent(self) -> bgym.Agent:
37-
with hydra.initialize(config_path="conf", version_base="1.1"):
38-
config = hydra.compose(config_name=self.agent_name)
39-
agent: Agent = hydra.utils.instantiate(config)
44+
agent: Agent = hydra.utils.instantiate(self.config.agent)
4045
return TapeAgent(agent=agent)
4146

4247

4348
@dataclass
4449
class TapeAgentInfo(bgym.AgentInfo):
45-
thoughts: list[Thought] = None
50+
thoughts: list[Thought] = None # type: ignore
4651

4752

4853
class DictObservation(Observation):
4954
"""
5055
Container for wrapping old dict observation into new Observation class.
5156
"""
5257

53-
kind: Literal["dict_observation"] = "dict_observation"
58+
kind: Literal["dict_observation"] = "dict_observation" # type: ignore
5459
content: str
5560

5661

@@ -70,8 +75,8 @@ def obs_preprocessor(self, obs: Observation | list[Observation]) -> list[Observa
7075
logger.info(f"Observations: {[type(o).__name__ for o in obs]}")
7176
return obs
7277

73-
def get_action(self, obs: Observation | list[Observation]) -> tuple[str, TapeAgentInfo]:
74-
self.tape += obs
78+
def get_action(self, obs: Observation | list[Observation]) -> tuple[Action, TapeAgentInfo]:
79+
self.tape += obs # type: ignore
7580
thoughts: list[Thought] = []
7681
action = None
7782
while not action:

src/agentlab/agents/tapeagent/conf/gaia_agent.yaml renamed to src/agentlab/agents/tapeagent/conf/agent/plan_react.yaml

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
1-
defaults:
2-
- llm@llms.default: gpt4o_mini
3-
- _self_
4-
51
_target_: tapeagents.agent.Agent
62
name : gaia_agent
73
max_iterations: 2
4+
llms:
5+
default: ${llm}
86
tools_description: |
9-
- WebSearch - Performs a search in the web, wikipedia or youtube
10-
- VideoReader - Opens video from a youtube URL. Can access the video content, thumbnail, subtitles and audio.
7+
- WebSearch - Performs web search.
8+
- VideoReader - Opens video from a youtube URL.
119
- Browser - Browser tool that can load web pages and interact with their content.
12-
- CodeExecutor - Executes the python code snippet
10+
- CodeExecutor - Executes the python code snippet.
1311
known_actions:
1412
- _target_: hydra.utils.get_class
1513
path: tapeagents.tools.web_search.SearchAction
@@ -64,18 +62,18 @@ templates:
6462
nodes:
6563
- _target_: tapeagents.nodes.StandardNode
6664
name: plan
67-
system_prompt: ${templates.system_prompt}
65+
system_prompt: ${agent.templates.system_prompt}
6866
guidance: |
6967
Write a concise multi-step plan explaining which steps should be performed to find the answer for the given task.
7068
Remember that you can use web search, browser, python code execution and access the youtube videos to reach your goals.
7169
Be specific about how each step should be performed. Only describe the intended actions here, do not perform them yet.
7270
Consider that next steps may depend on results of previous steps, so include conditional branching using "if" statements where needed.
73-
${templates.thought_format}
74-
steps_prompt: ${templates.allowed_tools}
71+
${agent.templates.thought_format}
72+
steps_prompt: ${agent.templates.allowed_tools}
7573

7674
- _target_: tapeagents.nodes.StandardNode
7775
name: facts_survey
78-
system_prompt: ${templates.system_prompt}
76+
system_prompt: ${agent.templates.system_prompt}
7977
guidance: |
8078
Before we begin executing the plan, please answer the following pre-survey.
8179
Here is the pre-survey:
@@ -84,16 +82,16 @@ nodes:
8482
3. Please list any facts that may need to be derived (e.g., via logical deduction, simulation, or computation)
8583
4. Please list any facts that are recalled from memory, hunches, well-reasoned guesses, etc.
8684
When answering this survey, keep in mind that "facts" will typically be specific names, dates, statistics, etc.
87-
${templates.thought_format}
88-
steps_prompt: ${templates.allowed_tools}
85+
${agent.templates.thought_format}
86+
steps_prompt: ${agent.templates.allowed_tools}
8987

9088
- _target_: tapeagents.nodes.StandardNode
9189
name: act
92-
system_prompt: ${templates.system_prompt}
90+
system_prompt: ${agent.templates.system_prompt}
9391
guidance: |
9492
Produce single next step. If the answer is ready, produce gaia_answer_action.
95-
${templates.format}
96-
steps_prompt: ${templates.allowed_steps}
93+
${agent.templates.format}
94+
steps_prompt: ${agent.templates.allowed_steps}
9795
steps:
9896
- tapeagents.steps.ReasoningThought
9997
- agentlab.benchmarks.gaia.ExtractedFacts
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
tools:
2+
- _target_: tapeagents.tools.web_search.WebSearch
3+
- _target_: tapeagents.tools.media_reader.VideoReader
4+
exp_path: ""
5+
- _target_: tapeagents.tools.browser.Browser
6+
exp_path: ""
7+
viewport_chars: 64000
8+
navigation_only: true
9+
- _target_: tapeagents.tools.code_executor.CodeExecutor
10+
exp_path: ""
11+
reuse_computer_container: true
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
defaults:
2+
- llm: gpt4o_mini
3+
- agent: plan_react
4+
- environment: web_code
5+
- _self_
6+
7+
name: gaia_agent
8+
comment: Gaia L1 val
9+
split: validation
10+
level: "1"
11+
parallel_backend: ray
12+
n_jobs: 10
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
defaults:
2+
- llm: gpt4o_mini
3+
- agent: plan_react
4+
- environment: web_code
5+
- _self_
6+
7+
name: gaia_agent
8+
comment: Gaia val
9+
split: validation
10+
level: "all"
11+
parallel_backend: ray
12+
n_jobs: 10

src/agentlab/benchmarks/gaia.py

Lines changed: 41 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,30 @@
1-
import fcntl
21
import logging
32
import os
43
import re
54
import shutil
65
import string
76
from dataclasses import dataclass
87
from pathlib import Path
9-
from typing import Any, Literal
8+
from typing import Any, Literal, Self
109

1110
import datasets
11+
import hydra
12+
import podman
13+
from omegaconf import DictConfig
1214
from pdf2image import convert_from_path
13-
from pydantic import Field
15+
from pydantic import ConfigDict, Field
1416
from tapeagents.core import Action, Observation, StopStep, Thought
1517
from tapeagents.environment import ContainerExecutor, StatefulTool, Tool
1618
from tapeagents.steps import ImageObservation
17-
from tapeagents.tools.browser import Browser
18-
from tapeagents.tools.code_executor import CodeExecutor
19-
from tapeagents.tools.media_reader import VideoReader
2019
from tapeagents.tools.simple_browser import SimpleTextBrowser
21-
from tapeagents.tools.web_search import WebSearch
2220

2321
from agentlab.benchmarks.abstract_env import AbstractBenchmark, AbstractEnvArgs
2422
from agentlab.benchmarks.multitool_gym import MultiToolGym
2523

2624
logger = logging.getLogger(__name__)
2725

26+
CONTAINER_NAME = "gaia_code_shared"
27+
2828

2929
class GaiaGym(MultiToolGym):
3030
task: dict
@@ -61,30 +61,33 @@ def calculate_reward(self, action: Action) -> float:
6161

6262
@dataclass
6363
class GaiaGymArgs(AbstractEnvArgs):
64+
model_config = ConfigDict(arbitrary_types_allowed=True)
6465
task: dict[str, Any]
65-
viewport_chars: int
6666
task_seed: int
6767
task_name: str
68+
env_config: DictConfig
6869

6970
def __init__(
70-
self, task_name: str, task: dict[str, Any], viewport_chars: int = 64000, task_seed: int = 0
71+
self,
72+
task_name: str,
73+
task: dict[str, Any],
74+
env_config: DictConfig,
75+
task_seed: int = 0,
7176
):
7277
self.task_name = task_name
7378
self.task = task
74-
self.viewport_chars = viewport_chars
7579
self.task_seed = task_seed
80+
self.env_config = env_config
7681

7782
def make_env(self, exp_dir: str | Path, action_mapping=None) -> GaiaGym:
7883
exp_dir = str(exp_dir)
7984
logger.info(f"Init gaia env with directory {exp_dir}")
8085
os.environ["TAPEAGENTS_SQLITE_DB"] = os.path.join(exp_dir, "tapedata.sqlite")
8186
init_code_sandbox(exp_dir)
82-
tools = [
83-
WebSearch(),
84-
VideoReader(exp_path=exp_dir),
85-
Browser(exp_path=exp_dir, viewport_chars=self.viewport_chars, navigation_only=True),
86-
CodeExecutor(exp_path=exp_dir, reuse_computer_container=True),
87-
]
87+
for i in range(len(self.env_config.tools)):
88+
if hasattr(self.env_config.tools[i], "exp_path"):
89+
self.env_config.tools[i].exp_path = exp_dir
90+
tools = hydra.utils.instantiate(self.env_config.tools)
8891
env = GaiaGym(tools=tools, task=self.task, exp_dir=exp_dir)
8992
return env
9093

@@ -94,27 +97,43 @@ def init_code_sandbox(exp_dir: str) -> None:
9497
root_exp_dir = Path(exp_dir).parent
9598
code_path = os.path.join(root_exp_dir, "shared_code")
9699
os.makedirs(code_path, exist_ok=True)
97-
98-
container_name = "gaia_code_shared"
99-
os.environ["COMPUTER_CONTAINER_NAME"] = container_name
100+
os.environ["COMPUTER_CONTAINER_NAME"] = CONTAINER_NAME
100101

101102
# symlink task code to the shared code directory
102103
task_code_path = os.path.join(exp_dir, "code")
103104
if not os.path.exists(task_code_path):
104105
os.symlink(code_path, task_code_path)
105106

106107
try:
107-
ContainerExecutor(container_name=container_name, work_dir=code_path, no_deps=True)
108+
ContainerExecutor(container_name=CONTAINER_NAME, work_dir=code_path, no_deps=True)
108109
except Exception as e:
109110
logger.warning(f"Failed to initialize container executor: {e}")
110111

111112

113+
def stop_old_sandbox():
114+
try:
115+
podman.from_env().containers.get(CONTAINER_NAME).stop()
116+
except Exception as e:
117+
logger.warning(f"Failed to stop old container {CONTAINER_NAME}: {e}")
118+
119+
112120
class GaiaBenchmark(AbstractBenchmark):
121+
model_config = ConfigDict(arbitrary_types_allowed=True)
113122
name: str = "gaia"
114123
split: Literal["test", "validation"]
115124
level: Literal["1", "2", "3", "all"] = "all"
116125
env_args_list: list[GaiaGymArgs] = None # type: ignore
117126
dataset: dict = None # type: ignore
127+
env_config: DictConfig = None # type: ignore
128+
129+
@classmethod
130+
def from_config(cls, config: DictConfig, dataset: dict = None) -> Self:
131+
return cls(
132+
split=config.split,
133+
level=config.level,
134+
env_config=config.environment,
135+
dataset=dataset,
136+
)
118137

119138
def model_post_init(self, __context: Any) -> None:
120139
if not self.dataset:
@@ -130,7 +149,8 @@ def model_post_init(self, __context: Any) -> None:
130149
continue
131150
number += 1
132151
task["number"] = number
133-
env_args = GaiaGymArgs(task_name="gaia." + task["task_id"], task=task)
152+
name = f"gaia.{task['task_id']}"
153+
env_args = GaiaGymArgs(task_name=name, task=task, env_config=self.env_config)
134154
self.env_args_list.append(env_args)
135155
logger.info(f"Loaded {len(self.env_args_list)} tasks from {self.split} split")
136156

0 commit comments

Comments
 (0)