Skip to content

Commit 3fd383d

Browse files
committed
fix
1 parent 7999bb0 commit 3fd383d

File tree

1 file changed

+10 -8 lines changed

1 file changed

+10 -8 lines changed

src/agentlab/benchmarks/gaia.py

Lines changed: 10 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,7 @@
1010

1111
import datasets
1212
from pydantic import Field
13-
from tapeagents.core import Action, Observation, StopStep, Thought
13+
from tapeagents.core import Action, Observation, Step, StopStep, Thought
1414
from tapeagents.environment import ContainerExecutor, StatefulTool, Tool
1515
from tapeagents.steps import ImageObservation
1616
from tapeagents.tools.browser import Browser
@@ -40,7 +40,7 @@ def reset(self, seed=None) -> tuple[list[Observation], dict]:
4040
"""
4141
super().reset()
4242
question = GaiaQuestion.from_task(self.task)
43-
steps = [question]
43+
steps: list[Observation] = [question]
4444
if image_obs := with_image(question):
4545
steps.append(image_obs)
4646
return steps, {}
@@ -120,10 +120,12 @@ def model_post_init(self, __context: Any) -> None:
120120
"gaia-benchmark/GAIA", "2023_all", trust_remote_code=True
121121
) # type: ignore
122122
self.env_args_list = []
123-
for i, task in enumerate(self.dataset[self.split]):
123+
number = 0
124+
for task in self.dataset[self.split]:
124125
if self.level != "all" and task["Level"] != self.level:
125126
continue
126-
task["number"] = i
127+
number += 1
128+
task["number"] = number
127129
env_args = GaiaGymArgs(task_name="gaia." + task["task_id"], task=task)
128130
self.env_args_list.append(env_args)
129131
logger.info(f"Loaded {len(self.env_args_list)} tasks from {self.split} split")
@@ -141,7 +143,7 @@ class ExtractedFacts(Thought):
141143

142144

143145
class GaiaQuestion(Observation):
144-
kind: Literal["question"] = "question"
146+
kind: Literal["question"] = "question" # type: ignore
145147
content: str
146148
filename: str | None = None
147149

@@ -178,7 +180,7 @@ class GaiaAnswer(StopStep):
178180
If unable to determine the final answer, output an empty string.
179181
"""
180182

181-
kind: Literal["gaia_answer_action"] = "gaia_answer_action"
183+
kind: Literal["gaia_answer_action"] = "gaia_answer_action" # type: ignore
182184
success: bool = Field(description="True if the task was successful, False otherwise")
183185
overview: str = Field(
184186
description="List of steps performed to answer the question. If the task was not successful, includes the reason for failure"
@@ -199,8 +201,8 @@ def step_error(step_dict: dict, last_action: str | None) -> str:
199201
error = "browser"
200202
elif kind == "llm_output_parsing_failure_action":
201203
error = "parsing"
202-
elif kind == "action_failure":
203-
error = last_action if last_action else "unknown_action_execution_failure"
204+
elif kind == "action_execution_failure":
205+
error = last_action if last_action else "action_failure"
204206
elif kind == "code_execution_result" and step_dict.get("result", {}).get("exit_code"):
205207
error = "code"
206208
return error

0 commit comments

Comments
 (0)