55import datasets
66from pydantic import Field
77from tapeagents .core import Observation , StopStep , Thought
8- from tapeagents .environment import ContainerExecutor
8+ from tapeagents .environment import ContainerExecutor , StatefulTool , Tool
99from tapeagents .steps import ImageObservation
1010from tapeagents .tools .browser import Browser
1111from tapeagents .tools .code_executor import CodeExecutor
1616from agentlab .benchmarks .multitool_gym import MultiToolGym
1717
1818
19- class GaiaBenchmark (AbstractBenchmark ):
20- exp_dir : str
21- name : str = "gaia"
22- split : Literal ["test" , "validation" ]
23-
24- def model_post_init (self , __context : Any ) -> None :
25- self .env_args_list = []
26- dataset = datasets .load_dataset ("gaia-benchmark/GAIA" , "2023_all" )[self .split ]
27- for task in dataset :
28- task_dir = os .path .join (self .name , task ["task_id" ])
29- env_args = GaiaGymArgs (task = task , exp_dir = task_dir )
30- self .env_args_list .append (env_args )
31-
32-
3319class GaiaGym (MultiToolGym ):
3420 task : dict
3521 exp_dir : str
3622
23+ def __init__ (self , tools : list [Tool | StatefulTool ], task : dict , exp_dir : str ):
24+ super ().__init__ (tools = tools )
25+ self .task = task
26+ self .exp_dir = exp_dir
27+
3728 def reset (self ) -> tuple [list [Observation ], dict ]:
3829 super ().reset ()
30+ print ("task:" , self .task )
3931 question = GaiaQuestion .from_task (self .task )
4032 steps = [question ]
4133 if image_obs := with_image (question ):
34+ print ("image_obs:" , image_obs )
4235 steps .append (image_obs )
4336 return steps
4437
@@ -52,9 +45,9 @@ def make_env(self) -> GaiaGym:
5245 self .init_code_sandbox ()
5346 tools = [
5447 WebSearch (),
55- VideoReader (self .exp_dir ),
56- Browser (self .exp_dir , viewport_chars = self .viewport_chars ),
57- CodeExecutor (self .exp_dir ),
48+ VideoReader (exp_path = self .exp_dir ),
49+ Browser (exp_path = self .exp_dir , viewport_chars = self .viewport_chars ),
50+ CodeExecutor (exp_path = self .exp_dir ),
5851 ]
5952 env = GaiaGym (tools = tools , task = self .task , exp_dir = self .exp_dir )
6053 return env
@@ -72,6 +65,21 @@ def init_code_sandbox(self) -> None:
7265 )
7366
7467
68+ class GaiaBenchmark (AbstractBenchmark ):
69+ exp_dir : str
70+ name : str = "gaia"
71+ split : Literal ["test" , "validation" ]
72+ env_args_list : list [GaiaGymArgs ] = None
73+
74+ def model_post_init (self , __context : Any ) -> None :
75+ self .env_args_list = []
76+ dataset = datasets .load_dataset ("gaia-benchmark/GAIA" , "2023_all" )[self .split ]
77+ for task in dataset :
78+ task_dir = os .path .join (self .name , task ["task_id" ])
79+ env_args = GaiaGymArgs (task = task , exp_dir = task_dir )
80+ self .env_args_list .append (env_args )
81+
82+
7583class ExtractedFacts (Thought ):
7684 """
7785 Thought that contains the list of facts extracted from the document
0 commit comments