Commit 09ef9a6

address review comments
1 parent ef9eb19 commit 09ef9a6

File tree

4 files changed: +49 −63 lines changed

src/agentlab/benchmarks/abstract_env.py
src/agentlab/benchmarks/gaia.py
src/agentlab/benchmarks/multitool_gym.py
src/agentlab/experiments/loop.py

src/agentlab/benchmarks/abstract_env.py

Lines changed: 1 addition & 14 deletions
@@ -1,12 +1,11 @@
 from abc import ABC, abstractmethod
-from dataclasses import dataclass
 
 import gymnasium as gym
 from dataclasses_json import DataClassJsonMixin
 from pydantic import BaseModel
 
 
-class AbstractEnvArgs(ABC):
+class AbstractEnvArgs(DataClassJsonMixin):
     @abstractmethod
     def make_env(self, action_mapping, exp_dir, exp_task_kwargs) -> "AbstractEnv":
         """Create an instance of the environment with the arguments stored in this object.
@@ -22,14 +21,6 @@ def make_env(self, action_mapping, exp_dir, exp_task_kwargs) -> "AbstractEnv":
         """
 
 
-@dataclass
-class SerializableEnvArgs(AbstractEnvArgs, DataClassJsonMixin):
-    """Easily serialiazable class to store the arguments of an environment"""
-
-    task_seed: int = 0
-    task_name: str = ""
-
-
 class AbstractBenchmark(BaseModel):
     name: str
     env_args_list: list
@@ -80,7 +71,3 @@ def step(self, action: str):
     @abstractmethod
     def close(self):
         """Close any resources used by the environment"""
-
-
-    @abstractmethod
-    def calculate_reward(self) -> float:
-        """Calculate the reward obtained in the last step"""

src/agentlab/benchmarks/gaia.py

Lines changed: 7 additions & 8 deletions
@@ -3,6 +3,7 @@
 import re
 import shutil
 import string
+from dataclasses import dataclass
 from pathlib import Path
 from typing import Any, Literal
 
@@ -16,7 +17,7 @@
 from tapeagents.tools.media_reader import VideoReader
 from tapeagents.tools.web_search import WebSearch
 
-from agentlab.benchmarks.abstract_env import AbstractBenchmark, SerializableEnvArgs
+from agentlab.benchmarks.abstract_env import AbstractBenchmark, AbstractEnvArgs
 from agentlab.benchmarks.multitool_gym import MultiToolGym
 
 logger = logging.getLogger(__name__)
@@ -33,19 +34,16 @@ def __init__(self, tools: list[Tool | StatefulTool], task: dict, exp_dir: str):
         os.makedirs(".cache", exist_ok=True)
 
     def reset(self, seed=None) -> tuple[list[Observation], dict]:
+        """
+        Reset the state of all the tools and prepare initial observations from the task again
+        """
         super().reset()
         question = GaiaQuestion.from_task(self.task)
         steps = [question]
         if image_obs := with_image(question):
             steps.append(image_obs)
         return steps, {}
 
-    def step(self, action: Action) -> tuple[Observation, float, bool, bool, dict]:
-        logger.info(f"Gym step called with action {type(action)}")
-        observation, reward, terminated, truncated, env_info = super().step(action)
-        logger.info(f"Gym observation: {observation.short_view()}")
-        return observation, reward, terminated, truncated, env_info
-
     def calculate_reward(self, action: Action) -> float:
         if isinstance(action, GaiaAnswer):
             model_answer = action.answer
@@ -62,7 +60,8 @@ def calculate_reward(self, action: Action) -> float:
         return reward
 
 
-class GaiaGymArgs(SerializableEnvArgs):
+@dataclass
+class GaiaGymArgs(AbstractEnvArgs):
     task: dict[str, Any]
     viewport_chars: int
     task_seed: int
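
The diff above leaves GaiaGym with no step override: per-step logging now lives in the base MultiToolGym.step, and the subclass supplies only the reward. A self-contained sketch of that split, assuming nothing beyond the standard library (BaseGymSketch, AnswerGymSketch, and the answer attribute are illustrative stand-ins, not the real tapeagents/agentlab classes, and the equality check stands in for GAIA's actual answer scoring):

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class BaseGymSketch:
    def calculate_reward(self, action) -> float:
        # mirrors MultiToolGym's default: no reward logic, just a warning
        logger.warning("Reward calculation is not implemented, returning 0")
        return 0.0


class AnswerGymSketch(BaseGymSketch):
    def __init__(self, golden_answer: str):
        self.golden_answer = golden_answer

    def calculate_reward(self, action) -> float:
        # mirrors GaiaGym: score only a final-answer action against the ground truth
        answer = getattr(action, "answer", None)
        return 1.0 if answer == self.golden_answer else 0.0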

src/agentlab/benchmarks/multitool_gym.py

Lines changed: 6 additions & 5 deletions
@@ -1,28 +1,26 @@
+import logging
 import time
-from typing import Annotated, Union
 
-from pydantic import Field, TypeAdapter
 from tapeagents.core import Action, Observation, StopStep, Tape
 from tapeagents.environment import ToolCollectionEnvironment
 from tapeagents.tools.base import StatefulTool, Tool
 
 from agentlab.benchmarks.abstract_env import AbstractEnv
 
+logger = logging.getLogger(__name__)
 EnvTape = Tape[None, Action | Observation]
 
 
 class MultiToolGym(AbstractEnv):
     def __init__(self, tools: list[Tool | StatefulTool]):
         self._env = ToolCollectionEnvironment(tools)
         self._actions = self._env.actions()
-        self._actions_parser: TypeAdapter = TypeAdapter(
-            Annotated[Union[self._actions], Field(discriminator="kind")]
-        )
 
     def reset(self):
         self._env.reset()
 
     def step(self, action: Action) -> tuple[Observation, float, bool, bool, dict]:
+        logger.info(f"Gym {self.__class__.__name__} step called with action {type(action)}")
         assert isinstance(action, Action)
 
         action_exec_start = time.time()
@@ -43,9 +41,12 @@ def step(self, action: Action) -> tuple[Observation, float, bool, bool, dict]:
             "action_exec_stop": action_exec_stop,
             "action_exec_timeout": 0.0,
         }
+        obs_view = observation.short_view() if isinstance(observation, Observation) else observation
+        logger.info(f"Gym {self.__class__.__name__} observation: {obs_view}")
         return observation, reward, terminated, truncated, env_info
 
     def calculate_reward(self, action: Action) -> float:
+        logger.warning("Reward calculation is not implemented, returning 0")
        return 0.0
 
     def close(self):
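
The new observation log line guards short_view() with an isinstance check because step() may hand back values that are not Observation instances. A minimal standalone sketch of that guard (Obs and describe are hypothetical stand-ins for tapeagents' Observation and the logging call):

class Obs:  # hypothetical stand-in for tapeagents.core.Observation
    def short_view(self) -> str:
        return "Obs(truncated view)"


def describe(observation):
    # same guard as the added logging line in MultiToolGym.step
    return observation.short_view() if isinstance(observation, Obs) else observation


print(describe(Obs()))        # -> Obs(truncated view)
print(describe("raw value"))  # -> raw value (short_view never called)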

src/agentlab/experiments/loop.py

Lines changed: 35 additions & 36 deletions
@@ -23,10 +23,7 @@
 from browsergym.experiments.utils import count_messages_token, count_tokens
 from dataclasses_json import DataClassJsonMixin
 from PIL import Image
-from tapeagents.core import (
-    StepMetadata,
-    Tape,
-)
+from tapeagents.core import StepMetadata, Tape
 from tapeagents.dialog_tape import AssistantStep, AssistantThought
 from tqdm import tqdm
 
@@ -315,9 +312,8 @@ def run(self):
                 err_msg = f"Exception uncaught by agent or environment in task {self.env_args.task_name}.\n{type(e).__name__}:\n{e}"
             logger.info("Saving experiment info.")
             _save_summary_info(episode_info, self.exp_dir, err_msg, stack_trace)
-            self.save_tape(
-                agent.final_tape if isinstance(agent, TapeAgent) else self.as_tape(episode_info)
-            )
+            tape = agent.final_tape if isinstance(agent, TapeAgent) else as_tape(episode_info)
+            self.save_tape(tape)
         except Exception as e:
             logger.exception(f"Error while saving experiment info: {e}")
         try:
@@ -330,36 +326,11 @@ def run(self):
         except Exception as e:
             logger.exception(f"Error while unsetting the logger: {e}")
 
-    def as_tape(self, steps_info: list["StepInfo"]) -> Tape:
-        """
-        Create a Tape object from the steps info.
-
-        Returns:
-            Tape: a Tape object containing the steps and metadata.
-        """
-        tape: Tape = []
-        for step_info in steps_info:
-            step_metadata = StepMetadata(
-                result=dict(
-                    reward=step_info.reward,
-                    raw_reward=step_info.raw_reward,
-                    terminated=step_info.terminated,
-                    truncated=step_info.truncated,
-                    agent_info=step_info.agent_info,
-                    stats=step_info.stats,
-                )
-            )
-            steps = [DictObservation(content=step_info.obs)]
-            if thought := step_info.agent_info.get("think"):
-                steps.append(AssistantThought(content=thought))
-            steps.append(AssistantStep(content=step_info.action, metadata=step_metadata))
-            tape += steps
-        return tape
-
     def save_tape(self, tape: Tape, filename: str = "tape.json"):
-        if os.path.exists(self.exp_dir / filename):
-            raise FileExistsError(f"{filename} already exists in {self.exp_dir}")
-        with open(self.exp_dir / filename, "w") as f:
+        tape_path = Path(self.exp_dir) / filename
+        if tape_path.exists():
+            raise FileExistsError(f"{tape_path} already exists")
+        with open(tape_path, "w") as f:
             json.dump(tape.model_dump(), f, indent=2, ensure_ascii=False)
 
     def _set_logger(self):
@@ -951,3 +922,31 @@ def _flatten_dict(d, parent_key="", sep="."):
         else:
             items.append((new_key, v))
     return dict(items)
+
+
+def as_tape(steps_info: list) -> Tape:
+    """
+    Create a Tape object from the steps info.
+
+    Returns:
+        Tape: a Tape object containing the steps and metadata.
+    """
+    tape: Tape = []
+    for step_info in steps_info:
+        step_metadata = StepMetadata(
+            other=dict(
+                reward=step_info.reward,
+                raw_reward=step_info.raw_reward,
+                terminated=step_info.terminated,
+                truncated=step_info.truncated,
+                agent_info=step_info.agent_info,
+                stats=step_info.stats,
+            )
+        )
+        steps = [DictObservation(content=step_info.obs)]
+        if thought := step_info.agent_info.get("think"):
+            steps.append(AssistantThought(content=thought))
+        if step_info.action is not None:
+            steps.append(AssistantStep(content=step_info.action, metadata=step_metadata))
+        tape += steps
+    return tape
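
The reworked save_tape normalizes exp_dir through pathlib.Path, so it accepts either a str or a Path, and it refuses to overwrite an existing tape. A standalone sketch of the same write-once pattern (save_json_once is a hypothetical helper, not repo code):

import json
from pathlib import Path


def save_json_once(exp_dir, payload: dict, filename: str = "tape.json") -> Path:
    path = Path(exp_dir) / filename  # works whether exp_dir is a str or a Path
    if path.exists():  # fail loudly instead of clobbering an earlier run's tape
        raise FileExistsError(f"{path} already exists")
    with open(path, "w") as f:
        json.dump(payload, f, indent=2, ensure_ascii=False)
    return path


# usage: a second call with the same exp_dir raises FileExistsError
# save_json_once("/tmp/exp", {"steps": []})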
