Skip to content

Commit 55378a4

Browse files
committed
Remaining fixes; eval results now match the old tapeagents evals
1 parent e98a0c2 commit 55378a4

File tree

4 files changed

+38
-31
lines changed

4 files changed

+38
-31
lines changed

src/agentlab/agents/tapeagent/agent.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -26,12 +26,12 @@ class ExtendedMetadata(TapeMetadata):
2626

2727

2828
class Tape(BaseTape):
29-
metadata: ExtendedMetadata = Field(default_factory=ExtendedMetadata)
29+
metadata: ExtendedMetadata = Field(default_factory=ExtendedMetadata) # type: ignore
3030

3131

3232
@dataclass
3333
class TapeAgentArgs(AgentArgs):
34-
agent_name: str
34+
agent_name: str = "tape_agent"
3535

3636
def make_agent(self) -> bgym.Agent:
3737
with hydra.initialize(config_path="conf", version_base="1.1"):

src/agentlab/benchmarks/gaia.py

Lines changed: 24 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -1,17 +1,17 @@
1+
import fcntl
12
import logging
23
import os
34
import re
45
import shutil
56
import string
67
from dataclasses import dataclass
7-
from math import exp
88
from pathlib import Path
99
from typing import Any, Literal
1010

1111
import datasets
1212
from pdf2image import convert_from_path
1313
from pydantic import Field
14-
from tapeagents.core import Action, Observation, Step, StopStep, Thought
14+
from tapeagents.core import Action, Observation, StopStep, Thought
1515
from tapeagents.environment import ContainerExecutor, StatefulTool, Tool
1616
from tapeagents.steps import ImageObservation
1717
from tapeagents.tools.browser import Browser
@@ -78,7 +78,7 @@ def make_env(self, exp_dir: str | Path, action_mapping=None) -> GaiaGym:
7878
exp_dir = str(exp_dir)
7979
logger.info(f"Init gaia env with directory {exp_dir}")
8080
os.environ["TAPEAGENTS_SQLITE_DB"] = os.path.join(exp_dir, "tapedata.sqlite")
81-
self.init_code_sandbox(exp_dir)
81+
init_code_sandbox(exp_dir)
8282
tools = [
8383
WebSearch(),
8484
VideoReader(exp_path=exp_dir),
@@ -88,34 +88,40 @@ def make_env(self, exp_dir: str | Path, action_mapping=None) -> GaiaGym:
8888
env = GaiaGym(tools=tools, task=self.task, exp_dir=exp_dir)
8989
return env
9090

91-
def init_code_sandbox(self, exp_dir: str) -> None:
92-
# Use a common code directory for all tasks in the experiment, which is mounted in the container
93-
root_exp_dir = Path(exp_dir).parent
94-
code_path = os.path.join(root_exp_dir, "shared_code")
95-
os.makedirs(code_path, exist_ok=True)
9691

97-
container_name = "gaia_code_shared"
98-
os.environ["COMPUTER_CONTAINER_NAME"] = container_name
92+
def init_code_sandbox(exp_dir: str) -> None:
93+
# Use a common code directory for all tasks in the experiment, which is mounted in the container
94+
root_exp_dir = Path(exp_dir).parent
95+
code_path = os.path.join(root_exp_dir, "shared_code")
96+
os.makedirs(code_path, exist_ok=True)
9997

100-
# symlink task code to the shared code directory
101-
task_code_path = os.path.join(exp_dir, "code")
102-
if not os.path.exists(task_code_path):
103-
os.symlink(code_path, task_code_path)
98+
container_name = "gaia_code_shared"
99+
os.environ["COMPUTER_CONTAINER_NAME"] = container_name
104100

101+
# symlink task code to the shared code directory
102+
task_code_path = os.path.join(exp_dir, "code")
103+
if not os.path.exists(task_code_path):
104+
os.symlink(code_path, task_code_path)
105+
106+
try:
105107
ContainerExecutor(container_name=container_name, work_dir=code_path, no_deps=True)
108+
except Exception as e:
109+
logger.warning(f"Failed to initialize container executor: {e}")
106110

107111

108112
class GaiaBenchmark(AbstractBenchmark):
109113
name: str = "gaia"
110114
split: Literal["test", "validation"]
111115
level: Literal["1", "2", "3", "all"] = "all"
112-
env_args_list: list[GaiaGymArgs] = None
113-
dataset: dict = Field(default_factory=dict)
116+
env_args_list: list[GaiaGymArgs] = None # type: ignore
117+
dataset: dict = None # type: ignore
114118

115119
def model_post_init(self, __context: Any) -> None:
116120
if not self.dataset:
117121
self.dataset = datasets.load_dataset(
118-
"gaia-benchmark/GAIA", "2023_all", trust_remote_code=True
122+
path="gaia-benchmark/GAIA",
123+
name="2023_all",
124+
trust_remote_code=True,
119125
) # type: ignore
120126
self.env_args_list = []
121127
number = 0
@@ -134,7 +140,7 @@ class ExtractedFacts(Thought):
134140
Thought that contains the list of facts extracted from the document
135141
"""
136142

137-
kind: Literal["extracted_facts_thought"] = "extracted_facts_thought"
143+
kind: Literal["extracted_facts_thought"] = "extracted_facts_thought" # type: ignore
138144
extracted_facts: list[str] | dict[str, Any] | str = Field(
139145
description="facts extracted from the observation"
140146
)

src/agentlab/benchmarks/multitool_gym.py

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -11,12 +11,15 @@
1111

1212

1313
class MultiToolGym(AbstractEnv):
14-
def __init__(self, tools: list[Tool | StatefulTool]):
14+
def __init__(self, tools: list[Tool | StatefulTool], max_turns: int = 50):
1515
self._env = ToolCollectionEnvironment(tools)
1616
self._actions = self._env.actions()
17+
self.max_turns = max_turns
18+
self._turns = 0
1719

1820
def reset(self):
1921
self._env.reset()
22+
self._turns = 0
2023

2124
def step(self, action: Action) -> tuple[Observation, float, bool, bool, dict]:
2225
logger.info(f"Gym {self.__class__.__name__} step called with action {type(action)}")
@@ -28,11 +31,13 @@ def step(self, action: Action) -> tuple[Observation, float, bool, bool, dict]:
2831
observation = Observation() # empty observation
2932
else:
3033
observation = self._env.step(action)
34+
terminated = isinstance(observation, StopStep)
3135
action_exec_stop = time.time()
36+
self._turns += 1
3237

3338
reward = self.calculate_reward(action)
3439

35-
truncated = False
40+
truncated = self._turns >= self.max_turns
3641

3742
env_info = {
3843
"step_metadata": observation.metadata,

src/agentlab/experiments/loop.py

Lines changed: 5 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -23,17 +23,12 @@
2323
from browsergym.experiments.utils import count_messages_token, count_tokens
2424
from dataclasses_json import DataClassJsonMixin
2525
from PIL import Image
26-
from tapeagents.core import Step, StepMetadata, TapeMetadata
26+
from tapeagents.core import Step, StepMetadata
2727
from tapeagents.dialog_tape import AssistantStep, AssistantThought
2828
from tapeagents.io import save_json_tape, save_tape_images
2929
from tqdm import tqdm
3030

31-
from agentlab.agents.tapeagent.agent import (
32-
DictObservation,
33-
ExtendedMetadata,
34-
Tape,
35-
TapeAgent,
36-
)
31+
from agentlab.agents.tapeagent.agent import DictObservation, Tape, TapeAgent
3732

3833
logger = logging.getLogger(__name__)
3934

@@ -237,9 +232,10 @@ def run(self):
237232
self._set_logger()
238233

239234
# log python environment info
240-
save_package_versions(self.exp_dir)
235+
save_package_versions(Path(self.exp_dir))
241236

242237
episode_info = []
238+
agent = None
243239
env, step_info, err_msg, stack_trace = None, None, None, None
244240
try:
245241
logger.info(f"Running experiment {self.exp_name} in:\n {self.exp_dir}")
@@ -255,7 +251,7 @@ def run(self):
255251
step_info = StepInfo(step=0)
256252
episode_info = [step_info]
257253
step_info.from_reset(
258-
env, seed=self.env_args.task_seed, obs_preprocessor=agent.obs_preprocessor
254+
env, seed=self.env_args.task_seed or 0, obs_preprocessor=agent.obs_preprocessor
259255
)
260256
logger.debug("Environment reset.")
261257

0 commit comments

Comments
 (0)