Commit ffe0332

revise gaia task.py, add init
1 parent 2d6fecf commit ffe0332

2 files changed (+94, -64 lines)
Lines changed: 3 additions & 0 deletions

@@ -0,0 +1,3 @@
+from .task import GAIATask
+
+__all__ = ["GAIATask"]
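
For orientation, a minimal usage sketch of the export this new package __init__ provides. This is not code from the commit: the import path and the constructor keywords (data_dir, level, dataset_type, max_rounds) are assumptions inferred from the task.py diff below.

    # Hypothetical sketch, not part of this commit. The module path is an
    # assumption; point it at wherever this GAIA task package actually lives.
    from src.tasks.gaia import GAIATask  # assumed location of the new __init__

    # Keyword names mirror the attributes GAIATask.__init__ stores (see diff below).
    task = GAIATask(
        data_dir="data/gaia",       # assumed local data directory
        level=1,                    # assumed GAIA difficulty level
        dataset_type="validation",  # picks the validation vs. test split
        max_rounds=10,              # cap on agent/environment rounds per sample
    )
    print(len(task.get_indices()))  # number of sample indices the task exposes
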
Lines changed: 91 additions & 64 deletions

@@ -1,5 +1,4 @@
-import json
-import os
+import time
 from typing import List, Dict, Any
 from src.server.task import Task, Session
 from src.typings import (
@@ -9,10 +8,7 @@
     AgentOutputStatus
 )
 from src.typings.output import TaskOutput
-
-# Import your existing components
 from .agentenv_gaia.environment import GaiaEnvServer
-from .agentenv_gaia.load_data import load_gaia_data
 
 class GAIATask(Task):
     def __init__(self,
@@ -28,94 +24,125 @@ def __init__(self,
         self.data_dir = data_dir
         self.max_rounds = max_rounds
 
-        # Initialize GAIA environment server
+        # Initialize GAIA environment server (this handles everything!)
         self.gaia_server = GaiaEnvServer()
 
-        # Load dataset
-        self.dataset = load_gaia_data(
-            data_dir=data_dir,
-            level=level,
-            dataset=dataset_type
-        )
+        # Get dataset size for indices (use existing preloaded data if available)
+        try:
+            if dataset_type == "validation" and hasattr(self.gaia_server, 'validation_data') and self.gaia_server.validation_data is not None:
+                self.dataset_size = len(self.gaia_server.validation_data)
+            elif dataset_type == "test" and hasattr(self.gaia_server, 'test_data') and self.gaia_server.test_data is not None:
+                self.dataset_size = len(self.gaia_server.test_data)
+            else:
+                # Fallback: create a temp environment to determine size
+                temp_env_id = self.gaia_server.create(id=0, dataset_type=dataset_type)
+                self.dataset_size = 165 if dataset_type == "validation" else 300 # GAIA dataset sizes
+                # Clean up temp environment
+                if temp_env_id in self.gaia_server.env_instances:
+                    del self.gaia_server.env_instances[temp_env_id]
+                    del self.gaia_server.env_locks[temp_env_id]
+        except Exception as e:
+            print(f"Warning: Could not determine dataset size: {e}")
+            self.dataset_size = 165 if dataset_type == "validation" else 300
 
     def get_indices(self) -> List[SampleIndex]:
-        return list(range(len(self.dataset)))
+        return list(range(self.dataset_size))
 
     async def start_sample(self, index: SampleIndex, session: Session) -> TaskSampleExecutionResult:
-        # Create GAIA environment instance
-        env_id = self.gaia_server.create(
-            id=index,
-            dataset_type=self.dataset_type
-        )
+        """
+        Execute a single GAIA sample - minimal wrapper around existing environment
+        """
+        start_time = time.time()
 
         try:
-            # Get initial observation/question
-            obs = self.gaia_server.observation(env_id)
+            env_id = self.gaia_server.create(id=index, dataset_type=self.dataset_type)
 
-            # Present the task to agent
-            session.inject({"role": "user", "content": obs})
+            # Get initial observation
+            initial_obs = self.gaia_server.observation(env_id)
+            session.inject({"role": "user", "content": initial_obs})
 
-            # Multi-turn interaction
+            # Multi-turn interaction loop
+            final_reward = 0.0
             for round_num in range(self.max_rounds):
-                # Get agent response
                 response = await session.action()
 
-                if response.status != AgentOutputStatus.NORMAL:
+                if response.status == AgentOutputStatus.AGENT_CONTEXT_LIMIT:
+                    final_status = SampleStatus.AGENT_CONTEXT_LIMIT
+                    break
+                elif response.status != AgentOutputStatus.NORMAL:
+                    final_status = SampleStatus.AGENT_VALIDATION_FAILED
                     break
 
-                # Execute action in GAIA environment
-                step_result = self.gaia_server.step(env_id, response.content)
+                observation, reward, done, info = self.gaia_server.step(env_id, response.content or "")
+                final_reward = reward
+                if observation:
+                    session.inject({"role": "user", "content": observation})
 
-                # Check if task is complete
-                if step_result.get("done", False):
+                # Check if done
+                if done:
+                    final_status = SampleStatus.COMPLETED
                     break
-
-                # Provide feedback to agent
-                if "observation" in step_result:
-                    session.inject({
-                        "role": "user",
-                        "content": step_result["observation"]
-                    })
+            else:
+                final_status = SampleStatus.TASK_LIMIT_REACHED
 
-            # Get final result
-            final_result = self.gaia_server.observation(env_id)
-
-            # Evaluate answer
-            correct_answer = self.dataset.iloc[index]["true_answer"]
-            agent_answer = self._extract_final_answer(final_result)
-
-            score = 1.0 if self._evaluate_answer(agent_answer, correct_answer) else 0.0
+            env_data = self.gaia_server.env_instances[env_id]
+            dataset_item = env_data["dataset"]
 
             return TaskSampleExecutionResult(
-                status=SampleStatus.COMPLETED,
+                status=final_status,
                 result={
-                    "score": score,
-                    "agent_answer": agent_answer,
-                    "correct_answer": correct_answer,
-                    "question": self.dataset.iloc[index]["question"]
+                    "score": final_reward,
+                    "question": dataset_item.get("question", ""),
+                    "true_answer": dataset_item.get("true_answer", ""),
+                    "rounds_used": round_num + 1 if 'round_num' in locals() else 0,
+                    "execution_time": time.time() - start_time,
+                    "level": self.level,
+                    "dataset_type": self.dataset_type,
+                    "steps_taken": info.get("steps_taken", 0) if 'info' in locals() else 0,
+                    "env_state": env_data["state"] if final_status == SampleStatus.COMPLETED else {}
                 }
             )
 
+        except Exception as e:
+            return TaskSampleExecutionResult(
+                status=SampleStatus.TASK_ERROR,
+                result={
+                    "error": f"Task error: {str(e)}",
+                    "score": 0.0,
+                    "execution_time": time.time() - start_time
+                }
+            )
+
         finally:
             # Clean up environment
-            if env_id:
-                # Add cleanup method to GaiaEnvServer if needed
+            try:
+                if 'env_id' in locals() and env_id in self.gaia_server.env_instances:
+                    del self.gaia_server.env_instances[env_id]
+                    del self.gaia_server.env_locks[env_id]
+            except:
                 pass
 
     def calculate_overall(self, results: List[TaskOutput]) -> Dict[str, Any]:
-        scores = [r.result.get("score", 0.0) for r in results if r.result]
+        """
+        Calculate aggregate metrics - simple and minimal
+        """
+        valid_results = [r for r in results if r.result]
+
+        if not valid_results:
+            return {
+                "accuracy": 0.0,
+                "total_samples": len(results),
+                "error_rate": 1.0
+            }
+
+        scores = [r.result.get("score", 0.0) for r in valid_results]
+        error_count = sum(1 for r in valid_results if "error" in r.result)
+
         return {
             "accuracy": sum(scores) / len(scores) if scores else 0.0,
             "total_samples": len(results),
-            "correct_answers": sum(scores)
-        }
-
-    def _extract_final_answer(self, result):
-        # Extract final answer from GAIA environment result
-        # Implement based on your result format
-        pass
-
-    def _evaluate_answer(self, agent_answer, correct_answer):
-        # Implement GAIA's answer evaluation logic
-        # Handle exact match with normalization
-        pass
+            "valid_samples": len(valid_results),
+            "error_rate": error_count / len(valid_results) if valid_results else 1.0,
+            "level": self.level,
+            "dataset_type": self.dataset_type
+        }
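
One hedged suggestion prompted by the cleanup path above: the finally block deletes entries from gaia_server.env_instances and gaia_server.env_locks directly and wraps that in a bare except. The sketch below shows how that bookkeeping could be pulled into the server itself; it is a hypothetical refactoring idea, not code from this commit, and the release method name is invented.

    # Hypothetical sketch, not part of this commit: a mixin GaiaEnvServer could
    # adopt so callers never reach into env_instances/env_locks themselves.
    class EnvCleanupMixin:
        env_instances: dict
        env_locks: dict

        def release(self, env_id) -> None:
            """Drop an environment instance and its lock, if present."""
            # dict.pop with a default never raises, so no bare except is needed.
            self.env_instances.pop(env_id, None)
            self.env_locks.pop(env_id, None)

With a helper like this, the finally block in start_sample would shrink to a single self.gaia_server.release(env_id) call.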
