11import hashlib
22import logging
33import shutil
4+ import json
45from typing import Any , cast
56import time
67from aviary .core import (
@@ -80,8 +81,29 @@ async def submit_answer(self, answer: str) -> str: # type: ignore[override]
8081 logger .info ("Answer: %s" , answer )
8182 return answer
8283
84+ def export_frame (self ) -> Frame :
85+ return Frame (
86+ state = {
87+ "last_action" : self .state .actions [- 1 ],
88+ "answer" : self .state .answer ,
89+ "done" : self .state .done ,
90+ "total_reward" : self .state .total_reward ,
91+ "nb_state" : self .state .nb ,
92+ "nb_state_html" : nb_to_html (self .state .nb ),
93+ "nb_runtime_errors" : self .state .notebook_runtime_errors ,
94+ },
95+ info = {
96+ "eval_mode" : self .eval_mode ,
97+ "language" : self .state .language ,
98+ "problem" : self .problem ,
99+ "problem_id" : self .problem_id ,
100+ },
101+ )
102+
83103 @classmethod
84- def eval_from_task (cls , task : str , gcs_artifact_path : str ) -> "DataAnalysisEnv" :
104+ def eval_from_task (
105+ cls , task : str , gcs_artifact_path : str , environment_config : str | None = None
106+ ) -> "DataAnalysisEnv" :
85107 """
86108 Used for evaluations via crow jobs.
87109
@@ -90,7 +112,6 @@ def eval_from_task(cls, task: str, gcs_artifact_path: str) -> "DataAnalysisEnv":
90112 gcs_artifact_path: The path to the GCS artifact – required for evaluation on crow jobs
91113 """
92114 logger .info ("Using the eval_from_task method" )
93-
94115 # Create temporary directory in GCP mounted storage volume
95116 task_hash = hashlib .sha256 (task .encode ()).hexdigest ()
96117 trajectory_path = cfg .DATA_STORAGE_PATH / f"{ task_hash } -{ time .time ()} "
@@ -124,45 +145,44 @@ def eval_from_task(cls, task: str, gcs_artifact_path: str) -> "DataAnalysisEnv":
124145
125146 @classmethod
126147 def from_task (
127- cls , task : str , gcs_artifact_path : str | None = None
148+ cls ,
149+ task : str ,
150+ gcs_artifact_path : str | None = None ,
151+ environment_config : str | None = None ,
128152 ) -> "DataAnalysisEnv" :
129153 """
130154 Perform data analysis on a user query.
131155
132156 Args:
133- task: The user query structured as <data_path> | <query>
134-
135- eg "CaspuleFolder-a7812fg | How many genes are differentially expressed between the two conditions?"
157+ task: The user query
158+ gcs_artifact_path: The path to the GCS artifact – required for evaluation on crow jobs
159+ environment_config: A JSON string of environment configuration
136160 """
137161 logger .info ("User task: %s" , task )
138162 logger .info ("GCS artifact path: %s" , gcs_artifact_path )
163+ logger .info ("environment_config: %s" , environment_config )
139164 if cfg .EVAL :
140165 return cls .eval_from_task (task , gcs_artifact_path ) # type: ignore
141166
142167 if (
143- gcs_artifact_path
168+ not gcs_artifact_path
144169 ): # The files are already in the GCS bucket in a job-specific directory
145- trajectory_path = cfg .DATA_STORAGE_PATH / gcs_artifact_path
146- nb_path = trajectory_path / NBEnvironment .NOTEBOOK_NAME
147- query = task
148- task_hash = gcs_artifact_path
170+ raise NotImplementedError (
171+ "Running crow jobs without gcs_artifact_path is not supported"
172+ )
173+ trajectory_path = cfg .DATA_STORAGE_PATH / gcs_artifact_path
174+ nb_path = trajectory_path / NBEnvironment .NOTEBOOK_NAME
175+ query = task
176+ task_hash = gcs_artifact_path
177+ if environment_config :
178+ kwargs = {
179+ k : v
180+ for k , v in json .loads (environment_config ).items ()
181+ if k in cfg .VALID_FROM_TASK_KWARGS
182+ }
149183 else :
150- # Extract data path and query from task
151- data_path , query = task .split ("|" )
152- # Hash the task to get a unique identifier
153- task_hash = hashlib .sha256 (task .encode ()).hexdigest ()
154- # Create temporary directory in GCP mounted storage volume
155- trajectory_path = cfg .DATA_STORAGE_PATH / f"{ task_hash } -{ time .time ()} "
156- trajectory_path .mkdir (parents = True , exist_ok = True )
157- nb_path = trajectory_path / NBEnvironment .NOTEBOOK_NAME
158- # Copy task data to trajectory path
159- for item in (cfg .DATA_STORAGE_PATH / data_path ).iterdir ():
160- if item .is_file ():
161- shutil .copy2 (item , trajectory_path )
162- elif item .is_dir ():
163- shutil .copytree (
164- item , trajectory_path / item .name , dirs_exist_ok = True
165- )
184+ kwargs = {}
185+ logger .info ("Filtered kwargs: %s" , kwargs )
166186
167187 # Augment incoming task with CoT instructions
168188 augmented_task = f"""\
@@ -215,23 +235,5 @@ def from_task(
215235 language = language ,
216236 system_prompt = prompts .CAPSULE_SYSTEM_PROMPT_QUERY ,
217237 use_tmp_work_dir = False ,
218- )
219-
220- def export_frame (self ) -> Frame :
221- return Frame (
222- state = {
223- "last_action" : self .state .actions [- 1 ],
224- "answer" : self .state .answer ,
225- "done" : self .state .done ,
226- "total_reward" : self .state .total_reward ,
227- "nb_state" : self .state .nb ,
228- "nb_state_html" : nb_to_html (self .state .nb ),
229- "nb_runtime_errors" : self .state .notebook_runtime_errors ,
230- },
231- info = {
232- "eval_mode" : self .eval_mode ,
233- "language" : self .state .language ,
234- "problem" : self .problem ,
235- "problem_id" : self .problem_id ,
236- },
238+ ** kwargs ,
237239 )
0 commit comments