1- import fcntl
21import logging
32import os
43import re
54import shutil
65import string
76from dataclasses import dataclass
87from pathlib import Path
9- from typing import Any , Literal
8+ from typing import Any , Literal , Self
109
1110import datasets
11+ import hydra
12+ import podman
13+ from omegaconf import DictConfig
1214from pdf2image import convert_from_path
13- from pydantic import Field
15+ from pydantic import ConfigDict , Field
1416from tapeagents .core import Action , Observation , StopStep , Thought
1517from tapeagents .environment import ContainerExecutor , StatefulTool , Tool
1618from tapeagents .steps import ImageObservation
17- from tapeagents .tools .browser import Browser
18- from tapeagents .tools .code_executor import CodeExecutor
19- from tapeagents .tools .media_reader import VideoReader
2019from tapeagents .tools .simple_browser import SimpleTextBrowser
21- from tapeagents .tools .web_search import WebSearch
2220
2321from agentlab .benchmarks .abstract_env import AbstractBenchmark , AbstractEnvArgs
2422from agentlab .benchmarks .multitool_gym import MultiToolGym
2523
2624logger = logging .getLogger (__name__ )
2725
26+ CONTAINER_NAME = "gaia_code_shared"
27+
2828
2929class GaiaGym (MultiToolGym ):
3030 task : dict
@@ -61,30 +61,33 @@ def calculate_reward(self, action: Action) -> float:
6161
@dataclass
class GaiaGymArgs(AbstractEnvArgs):
    """Arguments needed to build a ``GaiaGym`` environment for a single task.

    Attributes are assigned by the hand-written ``__init__`` below, which is
    kept in place of a dataclass-generated one.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)
    task: dict[str, Any]
    task_seed: int
    task_name: str
    env_config: DictConfig

    def __init__(
        self,
        task_name: str,
        task: dict[str, Any],
        env_config: DictConfig,
        task_seed: int = 0,
    ):
        """Store the task description and the environment configuration.

        Args:
            task_name: Name identifying the task.
            task: Task record as a plain dictionary.
            env_config: Environment configuration; ``make_env`` reads its
                ``tools`` entry to build the tool list.
            task_seed: Seed associated with the task (defaults to 0).
        """
        self.task_name = task_name
        self.task = task
        self.task_seed = task_seed
        self.env_config = env_config

    def make_env(self, exp_dir: str | Path, action_mapping=None) -> GaiaGym:
        """Create the environment for this task inside ``exp_dir``.

        Points ``TAPEAGENTS_SQLITE_DB`` at a database file in ``exp_dir``,
        initializes the code sandbox, and instantiates the configured tools.

        Args:
            exp_dir: Experiment directory for this task.
            action_mapping: Accepted for interface compatibility; not used here.

        Returns:
            A ``GaiaGym`` built from the instantiated tools and this task.
        """
        exp_dir = str(exp_dir)
        logger.info(f"Init gaia env with directory {exp_dir}")
        os.environ["TAPEAGENTS_SQLITE_DB"] = os.path.join(exp_dir, "tapedata.sqlite")
        init_code_sandbox(exp_dir)
        # NOTE: this writes into self.env_config in place, so every holder of
        # the same config object sees the most recently assigned exp_path.
        for tool_config in self.env_config.tools:
            if hasattr(tool_config, "exp_path"):
                tool_config.exp_path = exp_dir
        tools = hydra.utils.instantiate(self.env_config.tools)
        return GaiaGym(tools=tools, task=self.task, exp_dir=exp_dir)
9093
@@ -94,27 +97,43 @@ def init_code_sandbox(exp_dir: str) -> None:
9497 root_exp_dir = Path (exp_dir ).parent
9598 code_path = os .path .join (root_exp_dir , "shared_code" )
9699 os .makedirs (code_path , exist_ok = True )
97-
98- container_name = "gaia_code_shared"
99- os .environ ["COMPUTER_CONTAINER_NAME" ] = container_name
100+ os .environ ["COMPUTER_CONTAINER_NAME" ] = CONTAINER_NAME
100101
101102 # symlink task code to the shared code directory
102103 task_code_path = os .path .join (exp_dir , "code" )
103104 if not os .path .exists (task_code_path ):
104105 os .symlink (code_path , task_code_path )
105106
106107 try :
107- ContainerExecutor (container_name = container_name , work_dir = code_path , no_deps = True )
108+ ContainerExecutor (container_name = CONTAINER_NAME , work_dir = code_path , no_deps = True )
108109 except Exception as e :
109110 logger .warning (f"Failed to initialize container executor: { e } " )
110111
111112
def stop_old_sandbox() -> None:
    """Stop the container named ``CONTAINER_NAME``, on a best-effort basis.

    Any failure (including the container not being found) is logged as a
    warning and otherwise ignored, so callers never see an exception.
    """
    try:
        # Use the client as a context manager so its connection is closed
        # once the stop request has been issued.
        with podman.from_env() as client:
            client.containers.get(CONTAINER_NAME).stop()
    except Exception as e:
        logger.warning(f"Failed to stop old container {CONTAINER_NAME}: {e}")
118+
119+
112120class GaiaBenchmark (AbstractBenchmark ):
121+ model_config = ConfigDict (arbitrary_types_allowed = True )
113122 name : str = "gaia"
114123 split : Literal ["test" , "validation" ]
115124 level : Literal ["1" , "2" , "3" , "all" ] = "all"
116125 env_args_list : list [GaiaGymArgs ] = None # type: ignore
117126 dataset : dict = None # type: ignore
127+ env_config : DictConfig = None # type: ignore
128+
129+ @classmethod
130+ def from_config (cls , config : DictConfig , dataset : dict = None ) -> Self :
131+ return cls (
132+ split = config .split ,
133+ level = config .level ,
134+ env_config = config .environment ,
135+ dataset = dataset ,
136+ )
118137
119138 def model_post_init (self , __context : Any ) -> None :
120139 if not self .dataset :
@@ -130,7 +149,8 @@ def model_post_init(self, __context: Any) -> None:
130149 continue
131150 number += 1
132151 task ["number" ] = number
133- env_args = GaiaGymArgs (task_name = "gaia." + task ["task_id" ], task = task )
152+ name = f"gaia.{ task ['task_id' ]} "
153+ env_args = GaiaGymArgs (task_name = name , task = task , env_config = self .env_config )
134154 self .env_args_list .append (env_args )
135155 logger .info (f"Loaded { len (self .env_args_list )} tasks from { self .split } split" )
136156
0 commit comments