Skip to content

Commit 7999bb0

Browse files
committed
separate gaia-related renderings from general tape view
1 parent 4896c67 commit 7999bb0

File tree

2 files changed

+32
-21
lines changed

2 files changed

+32
-21
lines changed

src/agentlab/analyze/tapes.py

Lines changed: 13 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from tapeagents.tape_browser import TapeBrowser
1212

1313
from agentlab.agents.tapeagent.agent import ExtendedMetadata, Tape
14+
from agentlab.benchmarks.gaia import step_error
1415

1516
logger = logging.getLogger(__name__)
1617
fmt = "%(asctime)s - %(levelname)s - %(name)s:%(lineno)d - %(funcName)s() - %(message)s"
@@ -83,7 +84,7 @@ def get_tape_files(self) -> list[str]:
8384
logger.info(f"Found {len(exps)} experiments in {self.tapes_folder}")
8485
return sorted(exps)
8586

86-
def get_steps(self, tape) -> list:
87+
def get_steps(self, tape: dict) -> list:
8788
return tape["steps"]
8889

8990
def load_llm_calls(self):
@@ -102,9 +103,10 @@ def get_tape_name(self, i: int, tape: Tape) -> str:
102103
mark = "⚠ "
103104
if tape.metadata.task.get("file_name"):
104105
mark += "📁 "
105-
n = f"{tape.metadata.task.get('Level', '')}.{tape.metadata.task.get('number','')}"
106-
name = tape[0].content["content"][:32] + "..."
107-
return f"{n} {mark}{name}"
106+
number = tape.metadata.task.get("number", "")
107+
n = f"{tape.metadata.task.get('Level', '')}.{number} " if number else ""
108+
name = tape.steps[0].content["content"][:32] + "..."
109+
return f"{n}{mark}{name}"
108110

109111
def get_exp_label(self, filename: str, tapes: list[Tape]) -> str:
110112
acc, n_solved = self.calculate_accuracy(tapes)
@@ -142,20 +144,8 @@ def get_exp_label(self, filename: str, tapes: list[Tape]) -> str:
142144
if kind.endswith("action"):
143145
actions[kind] += 1
144146
last_action = kind
145-
if kind == "search_results_observation" and not len(step_dict.get("serp")):
146-
errors["search_empty"] += 1
147-
if kind == "page_observation" and step_dict.get("error"):
148-
errors["browser"] += 1
149-
elif kind == "llm_output_parsing_failure_action":
150-
errors["parsing"] += 1
151-
elif kind == "action_execution_failure":
152-
if last_action:
153-
errors[f"{last_action}"] += 1
154-
else:
155-
errors["unknown_action_execution_failure"] += 1
156-
elif kind == "code_execution_result":
157-
if step_dict.get("result", {}).get("exit_code"):
158-
errors["code_execution"] += 1
147+
if error := self.get_step_error(step_dict, last_action):
148+
errors[error] += 1
159149
timers, timer_counts = self.aggregate_timer_times(tapes)
160150
html = f"<h2>Solved {acc:.2f}%, {n_solved} out of {len(tapes)}</h2>"
161151
if "all" in filename:
@@ -177,10 +167,13 @@ def get_exp_label(self, filename: str, tapes: list[Tape]) -> str:
177167
html += f"<h2>Timings</h2>{timers_str}"
178168
return html
179169

170+
def get_step_error(self, step_dict: dict, last_action: str | None) -> str:
    """Return the error label for one step by delegating to ``step_error``.

    Kept as a method so the per-step classification can be overridden
    without touching the code that tallies the labels.
    """
    error_label = step_error(step_dict, last_action)
    return error_label
172+
180173
def calculate_accuracy(self, tapes: list[Tape]) -> tuple[float, int]:
181174
solved = [tape.metadata.reward for tape in tapes]
182175
accuracy = 100 * (sum(solved) / len(solved) if solved else 0.0)
183-
return accuracy, sum(solved)
176+
return accuracy, int(sum(solved))
184177

185178
def aggregate_timer_times(self, tapes: list[Tape]):
186179
timer_sums = defaultdict(float)
@@ -198,7 +191,7 @@ def aggregate_timer_times(self, tapes: list[Tape]):
198191
timer_counts[action_kind] += 1
199192
return dict(timer_sums), dict(timer_counts)
200193

201-
def load_tapes(self, exp_dir: str) -> list[dict]:
194+
def load_tapes(self, exp_dir: str) -> list[Tape]:
202195
tapes: list[Tape] = []
203196
fpath = Path(self.tapes_folder) / exp_dir
204197
for json_file in fpath.rglob("tape.json"):

src/agentlab/benchmarks/gaia.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import shutil
55
import string
66
from dataclasses import dataclass
7+
from math import exp
78
from pathlib import Path
89
from typing import Any, Literal
910

@@ -78,11 +79,12 @@ def __init__(
7879
def make_env(self, exp_dir: str | Path, action_mapping=None) -> GaiaGym:
7980
exp_dir = str(exp_dir)
8081
logger.info(f"Init gaia env with directory {exp_dir}")
82+
os.environ["TAPEAGENTS_SQLITE_DB"] = os.path.join(exp_dir, "tapedata.sqlite")
8183
self.init_code_sandbox(exp_dir)
8284
tools = [
8385
WebSearch(),
8486
VideoReader(exp_path=exp_dir),
85-
Browser(exp_path=exp_dir, viewport_chars=self.viewport_chars),
87+
Browser(exp_path=exp_dir, viewport_chars=self.viewport_chars, navigation_only=True),
8688
CodeExecutor(exp_path=exp_dir, reuse_computer_container=True),
8789
]
8890
env = GaiaGym(tools=tools, task=self.task, exp_dir=exp_dir)
@@ -188,6 +190,22 @@ class GaiaAnswer(StopStep):
188190
long_answer: str = Field(description="Detailed final answer not restricted by format rules")
189191

190192

193+
def step_error(step_dict: dict, last_action: str | None) -> str:
194+
kind = step_dict.get("kind", "unknown")
195+
error = ""
196+
if kind == "search_results_observation" and not len(step_dict.get("serp", [])):
197+
error = "search_empty"
198+
elif kind == "page_observation" and step_dict.get("error"):
199+
error = "browser"
200+
elif kind == "llm_output_parsing_failure_action":
201+
error = "parsing"
202+
elif kind == "action_failure":
203+
error = last_action if last_action else "unknown_action_execution_failure"
204+
elif kind == "code_execution_result" and step_dict.get("result", {}).get("exit_code"):
205+
error = "code"
206+
return error
207+
208+
191209
def normalize_number_str(number_str: str) -> float:
192210
# we replace these common units and commas to allow
193211
# conversion to float

0 commit comments

Comments
 (0)