Skip to content

Commit c58e207

Browse files
authored
Remove from task eval and use env config variables instead (#6)
1 parent 72e51b3 commit c58e207

File tree

3 files changed

+44
-98
lines changed

3 files changed

+44
-98
lines changed

src/fhda/config.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,4 @@
2020
else:
2121
DATA_STORAGE_PATH = Path("/storage")
2222

23-
EVAL = bool(os.getenv("EVAL", "false").lower() == "true")
24-
2523
VALID_FROM_TASK_KWARGS = ["run_notebook_on_edit"]

src/fhda/data_analysis_env.py

Lines changed: 41 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import hashlib
22
import logging
33
import shutil
4-
import json
54
from typing import Any, cast
65
import time
76
from aviary.core import (
@@ -100,55 +99,12 @@ def export_frame(self) -> Frame:
10099
},
101100
)
102101

103-
@classmethod
104-
def eval_from_task(
105-
cls, task: str, gcs_artifact_path: str, environment_config: str | None = None
106-
) -> "DataAnalysisEnv":
107-
"""
108-
Used for evaluations via crow jobs.
109-
110-
Args:
111-
task: The user query structured as <data_path> | <query>
112-
gcs_artifact_path: The path to the GCS artifact – required for evaluation on crow jobs
113-
"""
114-
logger.info("Using the eval_from_task method")
115-
# Create temporary directory in GCP mounted storage volume
116-
task_hash = hashlib.sha256(task.encode()).hexdigest()
117-
trajectory_path = cfg.DATA_STORAGE_PATH / f"{task_hash}-{time.time()}"
118-
trajectory_path.mkdir(parents=True, exist_ok=True)
119-
logger.info("Trajectory path: %s", trajectory_path)
120-
nb_path = trajectory_path / NBEnvironment.NOTEBOOK_NAME
121-
# Copy task data to trajectory path
122-
for item in (cfg.DATA_STORAGE_PATH / gcs_artifact_path).iterdir():
123-
if item.is_file():
124-
shutil.copy2(item, trajectory_path)
125-
elif item.is_dir():
126-
shutil.copytree(item, trajectory_path / item.name, dirs_exist_ok=True)
127-
128-
language = NBLanguage.PYTHON # In future, this should be a hyperparameter
129-
if trajectory_path.exists():
130-
logger.info(
131-
"Files in directory: %s", [f.name for f in trajectory_path.iterdir()]
132-
)
133-
134-
return cls(
135-
problem_id=f"data-analysis-task-{task_hash}",
136-
problem=task,
137-
# Using exact just because I won't ultimately be using env evaluation
138-
eval_mode=EvalAnswerMode.EXACT,
139-
nb_path=nb_path,
140-
work_dir=trajectory_path,
141-
language=language,
142-
system_prompt=prompts.CAPSULE_SYSTEM_PROMPT_OPEN,
143-
use_tmp_work_dir=False,
144-
)
145-
146102
@classmethod
147103
def from_task(
148104
cls,
149105
task: str,
150106
gcs_artifact_path: str | None = None,
151-
environment_config: str | None = None,
107+
environment_config: dict[str, Any] | None = None,
152108
) -> "DataAnalysisEnv":
153109
"""
154110
Perform data analysis on a user query.
@@ -161,74 +117,67 @@ def from_task(
161117
logger.info("User task: %s", task)
162118
logger.info("GCS artifact path: %s", gcs_artifact_path)
163119
logger.info("environment_config: %s", environment_config)
164-
if cfg.EVAL:
165-
return cls.eval_from_task(task, gcs_artifact_path) # type: ignore
166120

167121
if (
168122
not gcs_artifact_path
169-
): # The files are already in the GCS bucket in a job-specific directory
123+
): # Platform jobs should always be associated with data from a GCS bucket
170124
raise NotImplementedError(
171125
"Running crow jobs without gcs_artifact_path is not supported"
172126
)
173-
trajectory_path = cfg.DATA_STORAGE_PATH / gcs_artifact_path
174-
nb_path = trajectory_path / NBEnvironment.NOTEBOOK_NAME
175-
query = task
176-
task_hash = gcs_artifact_path
127+
177128
if environment_config:
178129
kwargs = {
179130
k: v
180-
for k, v in json.loads(environment_config).items()
131+
for k, v in environment_config.items()
181132
if k in cfg.VALID_FROM_TASK_KWARGS
182133
}
183134
else:
184135
kwargs = {}
185136
logger.info("Filtered kwargs: %s", kwargs)
186-
187-
# Augment incoming task with CoT instructions
188-
augmented_task = f"""\
189-
Here is the user query to address:
190-
191-
<query>
192-
{query}
193-
</query>
194-
195-
{prompts.CHAIN_OF_THOUGHT_AGNOSTIC}
196-
{prompts.GENERAL_NOTEBOOK_GUIDELINES}"""
137+
task_hash = hashlib.sha256(task.encode()).hexdigest()
138+
if kwargs.get("eval", False):
139+
# Create a temporary directory in GCP mounted storage volume
140+
trajectory_path = cfg.DATA_STORAGE_PATH / f"{task_hash}-{time.time()}"
141+
trajectory_path.mkdir(parents=True, exist_ok=True)
142+
for item in (cfg.DATA_STORAGE_PATH / gcs_artifact_path).iterdir():
143+
if item.is_file():
144+
shutil.copy2(item, trajectory_path)
145+
elif item.is_dir():
146+
shutil.copytree(
147+
item, trajectory_path / item.name, dirs_exist_ok=True
148+
)
149+
else:
150+
# Use the GCP folder created when uploading the data via the platform
151+
trajectory_path = cfg.DATA_STORAGE_PATH / gcs_artifact_path
152+
# Augment incoming user query with CoT instructions
153+
task = (
154+
f"Here is the user query to address:\n"
155+
f"<query>\n"
156+
f"{task}\n"
157+
f"</query>\n"
158+
f"{prompts.CHAIN_OF_THOUGHT_AGNOSTIC}\n"
159+
f"{prompts.GENERAL_NOTEBOOK_GUIDELINES}"
160+
)
161+
logger.info("Trajectory path: %s", trajectory_path)
162+
nb_path = trajectory_path / NBEnvironment.NOTEBOOK_NAME
197163

198164
language = NBLanguage.PYTHON # In future, this should be a hyperparameter
199165
if language == NBLanguage.R:
200-
augmented_task += f"\n{prompts.R_OUTPUT_RECOMMENDATION_PROMPT}"
201-
202-
# Log all parameters being passed to constructor
203-
logger.info(
204-
"Creating DataAnalysisEnv with parameters: "
205-
"problem_id=data-analysis-task-%s, "
206-
"problem=%s, "
207-
"eval_mode=%s, "
208-
"nb_path=%s, "
209-
"work_dir=%s, "
210-
"language=%s, "
211-
"system_prompt=%s, "
212-
"use_tmp_work_dir=%s, "
213-
"gcs_artifact_path=%s",
214-
task_hash,
215-
augmented_task,
216-
EvalAnswerMode.LLM,
217-
nb_path,
218-
trajectory_path,
219-
language,
220-
prompts.CAPSULE_SYSTEM_PROMPT_QUERY,
221-
False,
222-
gcs_artifact_path,
223-
)
166+
task += f"\n{prompts.R_OUTPUT_RECOMMENDATION_PROMPT}"
167+
224168
if trajectory_path.exists():
225-
logger.info(
226-
"Files in directory: %s", [f.name for f in trajectory_path.iterdir()]
227-
)
169+
files = list(trajectory_path.iterdir())
170+
logger.info("Files in directory: %s", [f.name for f in files])
171+
if not files:
172+
raise ValueError(
173+
f"No files found in trajectory path: {trajectory_path}"
174+
)
175+
else:
176+
raise ValueError(f"Trajectory path does not exist: {trajectory_path}")
228177

229178
return cls(
230179
problem_id=f"data-analysis-task-{task_hash}",
231-
problem=augmented_task,
180+
problem=task,
232181
eval_mode=EvalAnswerMode.LLM,
233182
nb_path=nb_path,
234183
work_dir=trajectory_path,

src/scripts/deploy.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,7 @@
1818
"ANTHROPIC_API_KEY": os.environ["ANTHROPIC_API_KEY"],
1919
"USE_R": "false",
2020
"USE_DOCKER": "false",
21-
"STAGE": "DEV",
22-
"EVAL": "true" if EVAL else "false",
21+
"STAGE": "PROD",
2322
}
2423

2524
CONTAINER_CONFIG = DockerContainerConfiguration(cpu="2", memory="4Gi")
@@ -33,7 +32,7 @@
3332
CrowDeploymentConfig(
3433
requirements_path=Path("pyproject.toml"),
3534
path=Path("src"),
36-
name="bixbench-crow2" if EVAL else "data-analysis-crow",
35+
name="data-analysis-crow",
3736
environment="src.fhda.data_analysis_env.DataAnalysisEnv",
3837
environment_variables=ENV_VARS,
3938
agent="ldp.agent.ReActAgent",
@@ -42,7 +41,7 @@
4241
frame_paths=frame_paths,
4342
timeout=3600,
4443
task_queues_config=TaskQueuesConfig(
45-
name="bixbench-crow2" if EVAL else "data-analysis-crow",
44+
name="data-analysis-crow",
4645
max_running_jobs=300,
4746
),
4847
),

0 commit comments

Comments (0)