Skip to content

Commit cabc393

Browse files
committed
more info in tape metadata, better tape browser
1 parent 13eec41 commit cabc393

File tree

5 files changed

+83
-41
lines changed

5 files changed

+83
-41
lines changed

src/agentlab/agents/tapeagent/agent.py

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,13 @@
11
import logging
22
from dataclasses import dataclass
3-
from typing import Any, Literal
3+
from typing import Literal
44

55
import bgym
66
import hydra
7+
from pydantic import Field
78
from tapeagents.agent import Agent
8-
from tapeagents.core import Action, Observation, Tape, TapeMetadata, Thought
9+
from tapeagents.core import Action, Observation, TapeMetadata, Thought
10+
from tapeagents.core import Tape as BaseTape
911

1012
from agentlab.agents.agent_args import AgentArgs
1113

@@ -23,6 +25,10 @@ class ExtendedMetadata(TapeMetadata):
2325
other: dict = {}
2426

2527

28+
class Tape(BaseTape):
    """``BaseTape`` variant whose ``metadata`` field is typed as ``ExtendedMetadata``.

    When no metadata is supplied, a new ``ExtendedMetadata`` instance is built
    through the field's default factory.
    """

    metadata: ExtendedMetadata = Field(default_factory=ExtendedMetadata)
30+
31+
2632
@dataclass
2733
class TapeAgentArgs(AgentArgs):
2834
agent_name: str

src/agentlab/analyze/tapes.py

Lines changed: 51 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -4,12 +4,13 @@
44
from collections import defaultdict
55
from pathlib import Path
66

7+
import numpy as np
78
import yaml
8-
from tapeagents.core import Step, StepMetadata, Tape
9+
from tapeagents.core import Step, StepMetadata
910
from tapeagents.renderers.camera_ready_renderer import CameraReadyRenderer
1011
from tapeagents.tape_browser import TapeBrowser
1112

12-
from agentlab.agents.tapeagent.agent import ExtendedMetadata
13+
from agentlab.agents.tapeagent.agent import ExtendedMetadata, Tape
1314

1415
logger = logging.getLogger(__name__)
1516
fmt = "%(asctime)s - %(levelname)s - %(name)s:%(lineno)d - %(funcName)s() - %(message)s"
@@ -20,6 +21,10 @@ class WrapperStep(Step):
2021
content: dict
2122

2223

24+
def pretty_yaml(data: dict) -> str:
    """Dump ``data`` to a YAML string with 2-space indent and unsorted keys.

    Args:
        data: Value to serialise.

    Returns:
        The YAML text, or an empty string when ``data`` is falsy
        (for example ``None`` or an empty dict).
    """
    if not data:
        return ""
    return yaml.dump(data, sort_keys=False, indent=2)
26+
27+
2328
class TapesRender(CameraReadyRenderer):
2429

2530
@property
@@ -31,36 +36,35 @@ def render_step(self, step: WrapperStep, index: int, **kwargs):
3136
step_dict = step.content.copy()
3237
step_dict.pop("metadata", None)
3338
kind = step_dict.pop("kind", "Step")
39+
if kind == "set_next_node":
40+
return ""
3441
# remove empty keys
3542
step_dict = {k: v for k, v in step_dict.items() if v is not None and v != ""}
3643
if len(step_dict) == 1:
3744
content = list(step_dict.values())[0]
3845
elif kind == "page_observation":
39-
content = step_dict["text"]
46+
content = step_dict.get("text", pretty_yaml(step_dict))
4047
if len(content) > 100:
4148
summary = content[:100]
4249
content = f"<details><summary>{summary}</summary>---<br>{content}</details>"
4350
elif kind == "python_code_action":
44-
content = step_dict["code"]
51+
content = step_dict.get("code", pretty_yaml(step_dict))
4552
elif kind == "code_execution_result":
46-
content = yaml.dump(step_dict["result"], sort_keys=False, indent=2)
53+
content = pretty_yaml(step_dict.get("result"))
4754
else:
48-
content = yaml.dump(step_dict, sort_keys=False, indent=2) if step_dict else ""
55+
content = pretty_yaml(step_dict)
4956

50-
if kind.endswith("thought"):
57+
if step_dict.get("error") or step_dict.get("result", {}).get("exit_code"):
58+
class_ = "error"
59+
elif kind.endswith("thought"):
5160
class_ = "thought"
5261
kind = kind[:-8]
5362
elif kind.endswith("action"):
5463
class_ = "action"
5564
kind = kind[:-7]
5665
else:
5766
class_ = "observation"
58-
return (
59-
f"<div class='basic-renderer-box {class_}'>"
60-
f"<h4 class='step-header'>{kind}</h4>"
61-
f"<pre class='step-text'>{content}</pre>"
62-
f"</div>"
63-
)
67+
return f"<div class='basic-renderer-box {class_}'><h4 class='step-header'>{kind}</h4><pre class='step-text'>{content}</pre></div>"
6468

6569

6670
class TapesBrowser(TapeBrowser):
@@ -89,10 +93,21 @@ def get_context(self, tape: Tape) -> list:
8993
return []
9094

9195
def get_tape_name(self, i: int, tape: Tape) -> str:
    """Build the label shown for one tape in the browser's tape list.

    The label has the form ``"<Level>.<number> <marks><preview>"``:

    * ``⚠`` is shown when any step has a truthy ``error`` or a truthy
      ``result["exit_code"]``; otherwise ``✅`` is shown when the tape reward
      is positive.
    * ``📁`` is appended when the task has a truthy ``file_name``.
    * The preview is the first 32 characters of the first step's ``content``
      entry, followed by ``...``.

    Args:
        i: Index of the tape (not used when building the label).
        tape: Tape whose steps expose a ``content`` dict and whose metadata
            carries ``reward`` and ``task``.

    Returns:
        The display name for the tape.
    """

    def step_failed(content: dict) -> bool:
        if content.get("error"):
            return True
        result = content.get("result")
        # "result" can be present but None (or not a mapping at all); only a
        # dict can carry an exit code, so anything else is not a failure.
        return isinstance(result, dict) and bool(result.get("exit_code"))

    task = tape.metadata.task or {}
    if any(step_failed(s.content) for s in tape.steps):
        mark = "⚠ "
    elif (tape.metadata.reward or 0) > 0:
        mark = "✅ "
    else:
        mark = ""
    if task.get("file_name"):
        mark += "📁 "
    n = f"{task.get('Level', '')}.{task.get('number', '')}"
    # An empty tape, or a first step without "content", must not break the
    # whole tape list; fall back to an empty preview instead.
    first = tape.steps[0].content.get("content") if tape.steps else None
    name = str(first or "")[:32] + "..."
    return f"{n} {mark}{name}"
93108

94109
def get_exp_label(self, filename: str, tapes: list[Tape]) -> str:
95-
acc, n_solved = 0, 0 # calculate_accuracy(tapes)
110+
acc, n_solved = self.calculate_accuracy(tapes)
96111
errors = defaultdict(int)
97112
prompt_tokens_num = 0
98113
output_tokens_num = 0
@@ -106,8 +121,10 @@ def get_exp_label(self, filename: str, tapes: list[Tape]) -> str:
106121
prompt_tokens_num += llm_call.prompt_length_tokens
107122
output_tokens_num += llm_call.output_length_tokens
108123
total_cost += llm_call.cost
124+
avg_steps = np.mean([len(tape) for tape in tapes])
125+
std_steps = np.std([len(tape) for tape in tapes])
109126
for tape in tapes:
110-
if tape.metadata.result in ["", None, "None"]:
127+
if not tape.metadata.terminated:
111128
no_result += 1
112129
if tape.metadata.error:
113130
errors["fatal"] += 1
@@ -125,9 +142,9 @@ def get_exp_label(self, filename: str, tapes: list[Tape]) -> str:
125142
if kind.endswith("action"):
126143
actions[kind] += 1
127144
last_action = kind
128-
if kind == "search_results_observation" and not len(step_dict["serp"]):
145+
if kind == "search_results_observation" and not len(step_dict.get("serp")):
129146
errors["search_empty"] += 1
130-
if kind == "page_observation" and step_dict["error"]:
147+
if kind == "page_observation" and step_dict.get("error"):
131148
errors["browser"] += 1
132149
elif kind == "llm_output_parsing_failure_action":
133150
errors["parsing"] += 1
@@ -136,13 +153,15 @@ def get_exp_label(self, filename: str, tapes: list[Tape]) -> str:
136153
errors[f"{last_action}"] += 1
137154
else:
138155
errors["unknown_action_execution_failure"] += 1
139-
elif kind == "code_execution_result" and step_dict["result"]["exit_code"]:
140-
errors["code_execution"] += 1
156+
elif kind == "code_execution_result":
157+
if step_dict.get("result", {}).get("exit_code"):
158+
errors["code_execution"] += 1
141159
timers, timer_counts = self.aggregate_timer_times(tapes)
142160
html = f"<h2>Solved {acc:.2f}%, {n_solved} out of {len(tapes)}</h2>"
143161
if "all" in filename:
144162
html += f"Prompt tokens: {prompt_tokens_num}<br>Output tokens: {output_tokens_num}<br>Cost: {total_cost:.2f} USD<h3>Visible</h3>"
145163
html += f"Prompt tokens: {visible_prompt_tokens_num}<br>Output tokens: {visible_output_tokens_num}<br>Cost: {visible_cost:.2f} USD"
164+
html += f"<h2>Steps per tape: {avg_steps:.1f} ± {std_steps:.1f}</h2>"
146165
if errors:
147166
errors_str = "<br>".join(f"{k}: {v}" for k, v in errors.items())
148167
html += f"<h2>No result: {no_result}</h2>"
@@ -158,6 +177,11 @@ def get_exp_label(self, filename: str, tapes: list[Tape]) -> str:
158177
html += f"<h2>Timings</h2>{timers_str}"
159178
return html
160179

180+
def calculate_accuracy(self, tapes: list[Tape]) -> tuple[float, int]:
    """Summarise the rewards of ``tapes`` as a percentage and a total.

    Args:
        tapes: Tapes whose ``metadata.reward`` values are read.

    Returns:
        A pair ``(accuracy, total)``: ``total`` is the sum of all rewards and
        ``accuracy`` is 100 times the mean reward, or ``0.0`` when ``tapes``
        is empty.
    """
    rewards = [t.metadata.reward for t in tapes]
    total = sum(rewards)
    if rewards:
        mean_reward = total / len(rewards)
    else:
        mean_reward = 0.0
    return 100 * mean_reward, total
184+
161185
def aggregate_timer_times(self, tapes: list[Tape]):
162186
timer_sums = defaultdict(float)
163187
timer_counts = defaultdict(int)
@@ -175,7 +199,7 @@ def aggregate_timer_times(self, tapes: list[Tape]):
175199
return dict(timer_sums), dict(timer_counts)
176200

177201
def load_tapes(self, exp_dir: str) -> list[dict]:
178-
tape_dicts = []
202+
tapes: list[Tape] = []
179203
fpath = Path(self.tapes_folder) / exp_dir
180204
for json_file in fpath.rglob("tape.json"):
181205
if json_file.stat().st_size == 0:
@@ -189,11 +213,14 @@ def load_tapes(self, exp_dir: str) -> list[dict]:
189213
WrapperStep(content=s, metadata=StepMetadata(**s["metadata"]))
190214
for s in tape_dict["steps"]
191215
]
192-
tape_dicts.append(tape)
216+
tapes.append(tape)
193217
except Exception as e:
194218
logger.warning(f"Failed to load {json_file}: {e}")
195-
logger.info(f"Loaded {len(tape_dicts)} tapes from {exp_dir}")
196-
return tape_dicts
219+
logger.info(f"Loaded {len(tapes)} tapes from {exp_dir}")
220+
return sorted(
221+
tapes,
222+
key=lambda x: f"{x.metadata.task.get('Level', '')}{x.metadata.task.get('number', 0):03d}",
223+
)
197224

198225
def save_annotation(self, step: int, annotation: str, tape_id: int):
199226
pass

src/agentlab/benchmarks/gaia.py

Lines changed: 5 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -77,6 +77,7 @@ def __init__(
7777

7878
def make_env(self, exp_dir: str | Path, action_mapping=None) -> GaiaGym:
7979
exp_dir = str(exp_dir)
80+
logger.info(f"Init gaia env with directory {exp_dir}")
8081
self.init_code_sandbox(exp_dir)
8182
tools = [
8283
WebSearch(),
@@ -90,15 +91,9 @@ def make_env(self, exp_dir: str | Path, action_mapping=None) -> GaiaGym:
9091
def init_code_sandbox(self, exp_dir: str) -> None:
    """Prepare the code directory for this task and create its container executor.

    Creates ``<exp_dir>/code`` if needed, exports the container name through
    the ``COMPUTER_CONTAINER_NAME`` environment variable, and instantiates a
    ``ContainerExecutor`` bound to that directory.

    Args:
        exp_dir: Experiment directory under which the ``code`` folder lives.
    """
    sandbox_dir = os.path.join(exp_dir, "code")
    os.makedirs(sandbox_dir, exist_ok=True)
    # The container name carries the first 8 characters of the task id.
    task_prefix = self.task["task_id"][:8]
    name = f"gaia_code_{task_prefix}"
    os.environ["COMPUTER_CONTAINER_NAME"] = name
    ContainerExecutor(container_name=name, work_dir=sandbox_dir, no_deps=True)
10297

10398

10499
class GaiaBenchmark(AbstractBenchmark):
@@ -112,9 +107,10 @@ def model_post_init(self, __context: Any) -> None:
112107
if not self.dataset:
113108
self.dataset = datasets.load_dataset("gaia-benchmark/GAIA", "2023_all")
114109
self.env_args_list = []
115-
for task in self.dataset[self.split]:
110+
for i, task in enumerate(self.dataset[self.split]):
116111
if self.level != "all" and task["Level"] != self.level:
117112
continue
113+
task["number"] = i
118114
env_args = GaiaGymArgs(task_name="gaia." + task["task_id"], task=task)
119115
self.env_args_list.append(env_args)
120116
logger.info(f"Loaded {len(self.env_args_list)} tasks from {self.split} split")

src/agentlab/benchmarks/multitool_gym.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,14 +1,13 @@
11
import logging
22
import time
33

4-
from tapeagents.core import Action, Observation, StopStep, Tape
4+
from tapeagents.core import Action, Observation, StopStep
55
from tapeagents.environment import ToolCollectionEnvironment
66
from tapeagents.tools.base import StatefulTool, Tool
77

88
from agentlab.benchmarks.abstract_env import AbstractEnv
99

1010
logger = logging.getLogger(__name__)
11-
EnvTape = Tape[None, Action | Observation]
1211

1312

1413
class MultiToolGym(AbstractEnv):

src/agentlab/experiments/loop.py

Lines changed: 18 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -23,12 +23,17 @@
2323
from browsergym.experiments.utils import count_messages_token, count_tokens
2424
from dataclasses_json import DataClassJsonMixin
2525
from PIL import Image
26-
from tapeagents.core import Step, StepMetadata, Tape
26+
from tapeagents.core import Step, StepMetadata, TapeMetadata
2727
from tapeagents.dialog_tape import AssistantStep, AssistantThought
2828
from tapeagents.io import save_json_tape, save_tape_images
2929
from tqdm import tqdm
3030

31-
from agentlab.agents.tapeagent.agent import DictObservation, TapeAgent
31+
from agentlab.agents.tapeagent.agent import (
32+
DictObservation,
33+
ExtendedMetadata,
34+
Tape,
35+
TapeAgent,
36+
)
3237

3338
logger = logging.getLogger(__name__)
3439

@@ -314,8 +319,8 @@ def run(self):
314319
logger.info("Saving experiment info.")
315320
_save_summary_info(episode_info, self.exp_dir, err_msg, stack_trace)
316321
if isinstance(agent, TapeAgent):
317-
save_json_tape(agent.final_tape, self.exp_dir, "tape.json")
318-
save_tape_images(agent.final_tape, self.exp_dir / "tape_attachments")
322+
task = getattr(env, "task", {})
323+
save_tape(self.exp_dir, episode_info, task, agent.final_tape)
319324
except Exception as e:
320325
logger.exception(f"Error while saving experiment info: {e}")
321326
try:
@@ -949,3 +954,12 @@ def as_tape(steps_info: list[StepInfo]) -> Tape:
949954
)
950955
steps.append(AssistantStep(content=step_info.action, metadata=step_metadata))
951956
return Tape(steps=steps)
957+
958+
959+
def save_tape(exp_dir: str, episode_info: list[StepInfo], task: dict, tape: Tape):
    """Record the episode outcome and task on a tape's metadata, then save it.

    Writes ``tape.json`` into ``exp_dir`` and the tape images into
    ``<exp_dir>/tape_attachments``.

    Args:
        exp_dir: Experiment directory; a ``str`` or a ``Path`` is accepted.
        episode_info: Per-step records of the episode. Their ``reward`` values
            are summed, and the last record supplies the ``truncated`` and
            ``terminated`` flags. May be empty.
        task: Task description stored in ``tape.metadata.task``.
        tape: Tape to annotate and save; its metadata is modified in place.
    """
    from pathlib import Path  # local import keeps this block self-contained

    tape.metadata.reward = sum(step.reward for step in episode_info)
    # With no recorded steps there is no final step to read the flags from;
    # leave the metadata flags untouched so the tape is still saved.
    if episode_info:
        final_step = episode_info[-1]
        tape.metadata.truncated = final_step.truncated
        tape.metadata.terminated = final_step.terminated
    tape.metadata.task = task
    save_json_tape(tape, exp_dir, "tape.json")
    # exp_dir is annotated as str; coerce so "/" works for str and Path alike.
    save_tape_images(tape, Path(exp_dir) / "tape_attachments")

0 commit comments

Comments
 (0)