1- import json
2- import os
1+ import time
32from typing import List , Dict , Any
43from src .server .task import Task , Session
54from src .typings import (
98 AgentOutputStatus
109)
1110from src .typings .output import TaskOutput
12-
13- # Import your existing components
1411from .agentenv_gaia .environment import GaiaEnvServer
15- from .agentenv_gaia .load_data import load_gaia_data
1612
1713class GAIATask (Task ):
1814 def __init__ (self ,
@@ -28,94 +24,125 @@ def __init__(self,
2824 self .data_dir = data_dir
2925 self .max_rounds = max_rounds
3026
31- # Initialize GAIA environment server
27+ # Initialize GAIA environment server (this handles everything!)
3228 self .gaia_server = GaiaEnvServer ()
3329
34- # Load dataset
35- self .dataset = load_gaia_data (
36- data_dir = data_dir ,
37- level = level ,
38- dataset = dataset_type
39- )
30+ # Get dataset size for indices (use existing preloaded data if available)
31+ try :
32+ if dataset_type == "validation" and hasattr (self .gaia_server , 'validation_data' ) and self .gaia_server .validation_data is not None :
33+ self .dataset_size = len (self .gaia_server .validation_data )
34+ elif dataset_type == "test" and hasattr (self .gaia_server , 'test_data' ) and self .gaia_server .test_data is not None :
35+ self .dataset_size = len (self .gaia_server .test_data )
36+ else :
37+ # Fallback: create a temp environment to determine size
38+ temp_env_id = self .gaia_server .create (id = 0 , dataset_type = dataset_type )
39+ self .dataset_size = 165 if dataset_type == "validation" else 300 # GAIA dataset sizes
40+ # Clean up temp environment
41+ if temp_env_id in self .gaia_server .env_instances :
42+ del self .gaia_server .env_instances [temp_env_id ]
43+ del self .gaia_server .env_locks [temp_env_id ]
44+ except Exception as e :
45+ print (f"Warning: Could not determine dataset size: { e } " )
46+ self .dataset_size = 165 if dataset_type == "validation" else 300
4047
def get_indices(self) -> List[SampleIndex]:
    """Return the sample indices for this task.

    The indices are the consecutive integers ``0 .. self.dataset_size - 1``;
    ``dataset_size`` is set once in ``__init__``.
    """
    return list(range(self.dataset_size))
4350
async def start_sample(self, index: SampleIndex, session: Session) -> TaskSampleExecutionResult:
    """Execute a single GAIA sample as a multi-turn agent/environment exchange.

    An environment is created for ``index``, its initial observation is
    injected into ``session`` as a user message, and the agent is then asked
    for up to ``self.max_rounds`` actions. Each action is passed to
    ``self.gaia_server.step``, which is unpacked as
    ``(observation, reward, done, info)``; a non-empty observation is fed
    back to the agent.

    Args:
        index: Sample index handed to ``self.gaia_server.create``.
        session: Agent session used to inject messages and request actions.

    Returns:
        A ``TaskSampleExecutionResult`` whose status is:

        * ``COMPLETED`` when the environment reports ``done``;
        * ``AGENT_CONTEXT_LIMIT`` / ``AGENT_VALIDATION_FAILED`` when the
          agent response has the corresponding non-normal status;
        * ``TASK_LIMIT_REACHED`` when the round budget is exhausted;
        * ``TASK_ERROR`` when any exception is raised along the way (the
          result then carries ``error``, ``score`` and ``execution_time``).
    """
    start_time = time.time()
    server = self.gaia_server
    # Stays None when create() itself fails, so cleanup knows there is
    # nothing to release.
    env_id = None

    try:
        env_id = server.create(id=index, dataset_type=self.dataset_type)

        # Get initial observation
        initial_obs = server.observation(env_id)
        session.inject({"role": "user", "content": initial_obs})

        # Explicit defaults replace the former `'name' in locals()` probes:
        # they are the values reported when the loop never reaches step().
        final_reward = 0.0
        rounds_used = 0
        info = {}

        # Multi-turn interaction loop
        for round_num in range(self.max_rounds):
            rounds_used = round_num + 1
            response = await session.action()

            if response.status == AgentOutputStatus.AGENT_CONTEXT_LIMIT:
                final_status = SampleStatus.AGENT_CONTEXT_LIMIT
                break
            if response.status != AgentOutputStatus.NORMAL:
                final_status = SampleStatus.AGENT_VALIDATION_FAILED
                break

            observation, reward, done, info = server.step(env_id, response.content or "")
            final_reward = reward
            if observation:
                session.inject({"role": "user", "content": observation})

            if done:
                final_status = SampleStatus.COMPLETED
                break
        else:
            # Loop ended without a break: the round budget ran out.
            final_status = SampleStatus.TASK_LIMIT_REACHED

        env_data = server.env_instances[env_id]
        dataset_item = env_data["dataset"]

        return TaskSampleExecutionResult(
            status=final_status,
            result={
                "score": final_reward,
                "question": dataset_item.get("question", ""),
                "true_answer": dataset_item.get("true_answer", ""),
                "rounds_used": rounds_used,
                "execution_time": time.time() - start_time,
                "level": self.level,
                "dataset_type": self.dataset_type,
                "steps_taken": info.get("steps_taken", 0),
                # Environment state is only reported for completed samples.
                "env_state": env_data["state"] if final_status == SampleStatus.COMPLETED else {},
            },
        )

    except Exception as e:
        return TaskSampleExecutionResult(
            status=SampleStatus.TASK_ERROR,
            result={
                "error": f"Task error: {str(e)}",
                "score": 0.0,
                "execution_time": time.time() - start_time,
            },
        )

    finally:
        # Clean up environment. Cleanup is best-effort and must never mask
        # the result being returned, but failures are reported instead of
        # being silently swallowed by a bare `except`.
        if env_id is not None:
            try:
                # pop() tolerates a missing entry, so the lock is released
                # even if the instance entry is already gone.
                server.env_instances.pop(env_id, None)
                server.env_locks.pop(env_id, None)
            except Exception as cleanup_error:
                print(f"Warning: Could not clean up environment {env_id}: {cleanup_error}")
104124
def calculate_overall(self, results: List[TaskOutput]) -> Dict[str, Any]:
    """Calculate aggregate metrics over all sample outputs.

    Only outputs with a truthy ``result`` count as valid. ``accuracy`` is
    the mean of their ``"score"`` values (a missing score counts as 0.0),
    and ``error_rate`` is the fraction of valid results that contain an
    ``"error"`` key.

    Args:
        results: Task outputs, each exposing a ``result`` mapping (or a
            falsy value when the sample produced nothing).

    Returns:
        A dict with the keys ``accuracy``, ``total_samples``,
        ``valid_samples``, ``error_rate``, ``level`` and ``dataset_type``.
        When there are no valid results, ``accuracy`` is 0.0 and
        ``error_rate`` is 1.0; the same keys are returned in both cases.
    """
    valid_results = [r for r in results if r.result]

    # Start from the "nothing usable" defaults so both return paths share
    # one schema.
    summary: Dict[str, Any] = {
        "accuracy": 0.0,
        "total_samples": len(results),
        "valid_samples": len(valid_results),
        "error_rate": 1.0,
        "level": self.level,
        "dataset_type": self.dataset_type,
    }
    if not valid_results:
        return summary

    scores = [r.result.get("score", 0.0) for r in valid_results]
    error_count = sum(1 for r in valid_results if "error" in r.result)

    # valid_results is non-empty here, so both divisions are safe.
    summary["accuracy"] = sum(scores) / len(scores)
    summary["error_rate"] = error_count / len(valid_results)
    return summary
0 commit comments