11import hashlib
2- import json
32import logging
43import shutil
54from typing import Any , cast
109 Message ,
1110 Messages ,
1211 Tool ,
13- eval_answer ,
1412)
1513
1614from .notebook_env import NBEnvironment
@@ -33,7 +31,7 @@ def __init__(
3331 answer : str | int | float | None = None , # noqa: PYI041
3432 system_prompt : str | None = None ,
3533 correct_reward : float = 1.0 ,
36- eval_mode : EvalAnswerMode ,
34+ eval_mode : EvalAnswerMode | None = None ,
3735 metadata : dict [str , Any ] | None = None , # used for NBEvalExpt
3836 mcqs : list [MultipleChoiceQuestion ] | None = None ,
3937 ** kwargs ,
@@ -66,7 +64,7 @@ async def reset(self) -> tuple[Messages, list[Tool]]:
6664
6765 return init_obs , tools
6866
69- async def submit_answer (self , answer : str | float | dict [ str , Any ] | None ) -> str : # type: ignore[override]
67+ async def submit_answer (self , answer : str ) -> str : # type: ignore[override]
7068 """Submit an answer to the problem.
7169
7270 Note that this tool may only be called once and ends the episode.
@@ -79,75 +77,50 @@ async def submit_answer(self, answer: str | float | dict[str, Any] | None) -> st
7977 self .state .done = True
8078 logger .info ("Submitting answer and closing environment" )
8179 await self .close ()
82- correct = False
8380 logger .info ("Answer: %s" , answer )
81+ return answer
8482
85- if self .eval_mode is None :
86- return CORRECT_MSG
87-
88- if isinstance (self .answer , int ):
89- try :
90- answer = int (answer ) # type: ignore[arg-type]
91- except ValueError :
92- pass
93- else :
94- correct = answer == self .answer
83+ @classmethod
84+ def eval_from_task (cls , task : str , gcs_artifact_path : str ) -> "DataAnalysisEnv" :
85+ """
86+ Used for evaluations via crow jobs.
9587
96- elif isinstance (self .answer , float ):
97- try :
98- answer = float (answer ) # type: ignore[arg-type]
99- except ValueError :
100- pass
101- else :
102- correct = abs (answer - self .answer ) < 1e-4 * self .answer
88+ Args:
89+ task: The user query structured as <data_path> | <query>
90+ gcs_artifact_path: The path to the GCS artifact – required for evaluation on crow jobs
91+ """
92+ logger .info ("Using the eval_from_task method" )
93+
94+ # Create temporary directory in GCP mounted storage volume
95+ task_hash = hashlib .sha256 (task .encode ()).hexdigest ()
96+ trajectory_path = cfg .DATA_STORAGE_PATH / f"{ task_hash } -{ time .time ()} "
97+ trajectory_path .mkdir (parents = True , exist_ok = True )
98+ logger .info ("Trajectory path: %s" , trajectory_path )
99+ nb_path = trajectory_path / NBEnvironment .NOTEBOOK_NAME
100+ # Copy task data to trajectory path
101+ for item in (cfg .DATA_STORAGE_PATH / gcs_artifact_path ).iterdir ():
102+ if item .is_file ():
103+ shutil .copy2 (item , trajectory_path )
104+ elif item .is_dir ():
105+ shutil .copytree (item , trajectory_path / item .name , dirs_exist_ok = True )
103106
104- elif isinstance (self .answer , str ):
105- correct = bool (
106- await eval_answer (
107- proposed = str (answer ),
108- correct = str (self .answer ),
109- question = self .problem ,
110- eval_mode = self .eval_mode ,
111- )
107+ language = NBLanguage .PYTHON # In future, this should be a hyperparameter
108+ if trajectory_path .exists ():
109+ logger .info (
110+ "Files in directory: %s" , [f .name for f in trajectory_path .iterdir ()]
112111 )
113- elif isinstance (self .answer , dict ): # This is for mcqs and open questions
114- # Check if answer is a json string
115- if isinstance (answer , str ): # type: ignore[unreachable]
116- # Process json into dictionary
117- try :
118- processed_answer = json .loads (answer )
119- except json .JSONDecodeError :
120- return INCORRECT_MSG
121- else :
122- processed_answer = answer if isinstance (answer , dict ) else {}
123112
124- # Loop through each question and answer
125- for question_id , agent_answer in processed_answer .items ():
126- try :
127- ideal_answer = self .answer [question_id ]
128- question = next (
129- q
130- for q in self .mcqs
131- if q .question_id .lower () == question_id .lower ()
132- )
133- correct = bool (
134- await eval_answer (
135- proposed = str (agent_answer ),
136- correct = str (ideal_answer ),
137- question = question ,
138- eval_mode = self .eval_mode ,
139- )
140- )
141- self .question_rewards [question_id ] = correct
142- except KeyError :
143- self .question_rewards [question_id ] = 0
144- average_reward = sum (self .question_rewards .values ()) / len (self .mcqs )
145- correct = round (average_reward ) == 1.0
146-
147- if correct :
148- self .state .total_reward += self .correct_reward
149- return CORRECT_MSG
150- return INCORRECT_MSG
113+ return cls (
114+ problem_id = f"data-analysis-task-{ task_hash } " ,
115+ problem = task ,
116+ # Using exact just because I won't ultimately be using env evaluation
117+ eval_mode = EvalAnswerMode .EXACT ,
118+ nb_path = nb_path ,
119+ work_dir = trajectory_path ,
120+ language = language ,
121+ system_prompt = prompts .CAPSULE_SYSTEM_PROMPT_OPEN ,
122+ use_tmp_work_dir = False ,
123+ )
151124
152125 @classmethod
153126 def from_task (
@@ -163,6 +136,8 @@ def from_task(
163136 """
164137 logger .info ("User task: %s" , task )
165138 logger .info ("GCS artifact path: %s" , gcs_artifact_path )
139+ if cfg .EVAL :
140+ return cls .eval_from_task (task , gcs_artifact_path ) # type: ignore
166141
167142 if (
168143 gcs_artifact_path
@@ -251,6 +226,7 @@ def export_frame(self) -> Frame:
251226 "total_reward" : self .state .total_reward ,
252227 "nb_state" : self .state .nb ,
253228 "nb_state_html" : nb_to_html (self .state .nb ),
229+ "nb_runtime_errors" : self .state .notebook_runtime_errors ,
254230 },
255231 info = {
256232 "eval_mode" : self .eval_mode ,
0 commit comments