11import hashlib
22import logging
33import shutil
4- import json
54from typing import Any , cast
65import time
76from aviary .core import (
@@ -100,55 +99,12 @@ def export_frame(self) -> Frame:
10099 },
101100 )
102101
103- @classmethod
104- def eval_from_task (
105- cls , task : str , gcs_artifact_path : str , environment_config : str | None = None
106- ) -> "DataAnalysisEnv" :
107- """
108- Used for evaluations via crow jobs.
109-
110- Args:
111- task: The user query structured as <data_path> | <query>
112- gcs_artifact_path: The path to the GCS artifact – required for evaluation on crow jobs
113- """
114- logger .info ("Using the eval_from_task method" )
115- # Create temporary directory in GCP mounted storage volume
116- task_hash = hashlib .sha256 (task .encode ()).hexdigest ()
117- trajectory_path = cfg .DATA_STORAGE_PATH / f"{ task_hash } -{ time .time ()} "
118- trajectory_path .mkdir (parents = True , exist_ok = True )
119- logger .info ("Trajectory path: %s" , trajectory_path )
120- nb_path = trajectory_path / NBEnvironment .NOTEBOOK_NAME
121- # Copy task data to trajectory path
122- for item in (cfg .DATA_STORAGE_PATH / gcs_artifact_path ).iterdir ():
123- if item .is_file ():
124- shutil .copy2 (item , trajectory_path )
125- elif item .is_dir ():
126- shutil .copytree (item , trajectory_path / item .name , dirs_exist_ok = True )
127-
128- language = NBLanguage .PYTHON # In future, this should be a hyperparameter
129- if trajectory_path .exists ():
130- logger .info (
131- "Files in directory: %s" , [f .name for f in trajectory_path .iterdir ()]
132- )
133-
134- return cls (
135- problem_id = f"data-analysis-task-{ task_hash } " ,
136- problem = task ,
137- # Using exact just because I won't ultimately be using env evaluation
138- eval_mode = EvalAnswerMode .EXACT ,
139- nb_path = nb_path ,
140- work_dir = trajectory_path ,
141- language = language ,
142- system_prompt = prompts .CAPSULE_SYSTEM_PROMPT_OPEN ,
143- use_tmp_work_dir = False ,
144- )
145-
146102 @classmethod
147103 def from_task (
148104 cls ,
149105 task : str ,
150106 gcs_artifact_path : str | None = None ,
151- environment_config : str | None = None ,
107+ environment_config : dict [ str , Any ] | None = None ,
152108 ) -> "DataAnalysisEnv" :
153109 """
154110 Perform data analysis on a user query.
@@ -161,74 +117,67 @@ def from_task(
161117 logger .info ("User task: %s" , task )
162118 logger .info ("GCS artifact path: %s" , gcs_artifact_path )
163119 logger .info ("environment_config: %s" , environment_config )
164- if cfg .EVAL :
165- return cls .eval_from_task (task , gcs_artifact_path ) # type: ignore
166120
167121 if (
168122 not gcs_artifact_path
169- ): # The files are already in the GCS bucket in a job-specific directory
123+ ): # Platform jobs should always be associated with data from a GCS bucket
170124 raise NotImplementedError (
171125 "Running crow jobs without gcs_artifact_path is not supported"
172126 )
173- trajectory_path = cfg .DATA_STORAGE_PATH / gcs_artifact_path
174- nb_path = trajectory_path / NBEnvironment .NOTEBOOK_NAME
175- query = task
176- task_hash = gcs_artifact_path
127+
177128 if environment_config :
178129 kwargs = {
179130 k : v
180- for k , v in json . loads ( environment_config ) .items ()
131+ for k , v in environment_config .items ()
181132 if k in cfg .VALID_FROM_TASK_KWARGS
182133 }
183134 else :
184135 kwargs = {}
185136 logger .info ("Filtered kwargs: %s" , kwargs )
186-
187- # Augment incoming task with CoT instructions
188- augmented_task = f"""\
189- Here is the user query to address:
190-
191- <query>
192- { query }
193- </query>
194-
195- { prompts .CHAIN_OF_THOUGHT_AGNOSTIC }
196- { prompts .GENERAL_NOTEBOOK_GUIDELINES } """
137+ task_hash = hashlib .sha256 (task .encode ()).hexdigest ()
138+ if kwargs .get ("eval" , False ):
139+ # Create a temporary directory in GCP mounted storage volume
140+ trajectory_path = cfg .DATA_STORAGE_PATH / f"{ task_hash } -{ time .time ()} "
141+ trajectory_path .mkdir (parents = True , exist_ok = True )
142+ for item in (cfg .DATA_STORAGE_PATH / gcs_artifact_path ).iterdir ():
143+ if item .is_file ():
144+ shutil .copy2 (item , trajectory_path )
145+ elif item .is_dir ():
146+ shutil .copytree (
147+ item , trajectory_path / item .name , dirs_exist_ok = True
148+ )
149+ else :
150+ # Use the GCP folder created when uploading the data via the platform
151+ trajectory_path = cfg .DATA_STORAGE_PATH / gcs_artifact_path
152+ # Augment incoming user query with CoT instructions
153+ task = (
154+ f"Here is the user query to address:\n "
155+ f"<query>\n "
156+ f"{ task } \n "
157+ f"</query>\n "
158+ f"{ prompts .CHAIN_OF_THOUGHT_AGNOSTIC } \n "
159+ f"{ prompts .GENERAL_NOTEBOOK_GUIDELINES } "
160+ )
161+ logger .info ("Trajectory path: %s" , trajectory_path )
162+ nb_path = trajectory_path / NBEnvironment .NOTEBOOK_NAME
197163
198164 language = NBLanguage .PYTHON # In future, this should be a hyperparameter
199165 if language == NBLanguage .R :
200- augmented_task += f"\n { prompts .R_OUTPUT_RECOMMENDATION_PROMPT } "
201-
202- # Log all parameters being passed to constructor
203- logger .info (
204- "Creating DataAnalysisEnv with parameters: "
205- "problem_id=data-analysis-task-%s, "
206- "problem=%s, "
207- "eval_mode=%s, "
208- "nb_path=%s, "
209- "work_dir=%s, "
210- "language=%s, "
211- "system_prompt=%s, "
212- "use_tmp_work_dir=%s, "
213- "gcs_artifact_path=%s" ,
214- task_hash ,
215- augmented_task ,
216- EvalAnswerMode .LLM ,
217- nb_path ,
218- trajectory_path ,
219- language ,
220- prompts .CAPSULE_SYSTEM_PROMPT_QUERY ,
221- False ,
222- gcs_artifact_path ,
223- )
166+ task += f"\n { prompts .R_OUTPUT_RECOMMENDATION_PROMPT } "
167+
224168 if trajectory_path .exists ():
225- logger .info (
226- "Files in directory: %s" , [f .name for f in trajectory_path .iterdir ()]
227- )
169+ files = list (trajectory_path .iterdir ())
170+ logger .info ("Files in directory: %s" , [f .name for f in files ])
171+ if not files :
172+ raise ValueError (
173+ f"No files found in trajectory path: { trajectory_path } "
174+ )
175+ else :
176+ raise ValueError (f"Trajectory path does not exist: { trajectory_path } " )
228177
229178 return cls (
230179 problem_id = f"data-analysis-task-{ task_hash } " ,
231- problem = augmented_task ,
180+ problem = task ,
232181 eval_mode = EvalAnswerMode .LLM ,
233182 nb_path = nb_path ,
234183 work_dir = trajectory_path ,
0 commit comments