Skip to content

Commit 1a29946

Browse files
authored
Enable platform evals (#3)
1 parent 08c7efb commit 1a29946

File tree

8 files changed

+272
-76
lines changed

8 files changed

+272
-76
lines changed

src/fhda/config.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -19,3 +19,5 @@
1919
DATA_STORAGE_PATH = Path("storage")
2020
else:
2121
DATA_STORAGE_PATH = Path("/storage")
22+
23+
EVAL = bool(os.getenv("EVAL", "false").lower() == "true")

src/fhda/data_analysis_env.py

Lines changed: 43 additions & 67 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,4 @@
11
import hashlib
2-
import json
32
import logging
43
import shutil
54
from typing import Any, cast
@@ -10,7 +9,6 @@
109
Message,
1110
Messages,
1211
Tool,
13-
eval_answer,
1412
)
1513

1614
from .notebook_env import NBEnvironment
@@ -33,7 +31,7 @@ def __init__(
3331
answer: str | int | float | None = None, # noqa: PYI041
3432
system_prompt: str | None = None,
3533
correct_reward: float = 1.0,
36-
eval_mode: EvalAnswerMode,
34+
eval_mode: EvalAnswerMode | None = None,
3735
metadata: dict[str, Any] | None = None, # used for NBEvalExpt
3836
mcqs: list[MultipleChoiceQuestion] | None = None,
3937
**kwargs,
@@ -66,7 +64,7 @@ async def reset(self) -> tuple[Messages, list[Tool]]:
6664

6765
return init_obs, tools
6866

69-
async def submit_answer(self, answer: str | float | dict[str, Any] | None) -> str: # type: ignore[override]
67+
async def submit_answer(self, answer: str) -> str: # type: ignore[override]
7068
"""Submit an answer to the problem.
7169
7270
Note that this tool may only be called once and ends the episode.
@@ -79,75 +77,50 @@ async def submit_answer(self, answer: str | float | dict[str, Any] | None) -> st
7977
self.state.done = True
8078
logger.info("Submitting answer and closing environment")
8179
await self.close()
82-
correct = False
8380
logger.info("Answer: %s", answer)
81+
return answer
8482

85-
if self.eval_mode is None:
86-
return CORRECT_MSG
87-
88-
if isinstance(self.answer, int):
89-
try:
90-
answer = int(answer) # type: ignore[arg-type]
91-
except ValueError:
92-
pass
93-
else:
94-
correct = answer == self.answer
83+
@classmethod
84+
def eval_from_task(cls, task: str, gcs_artifact_path: str) -> "DataAnalysisEnv":
85+
"""
86+
Used for evaluations via crow jobs.
9587
96-
elif isinstance(self.answer, float):
97-
try:
98-
answer = float(answer) # type: ignore[arg-type]
99-
except ValueError:
100-
pass
101-
else:
102-
correct = abs(answer - self.answer) < 1e-4 * self.answer
88+
Args:
89+
task: The user query structured as <data_path> | <query>
90+
gcs_artifact_path: The path to the GCS artifact – required for evaluation on crow jobs
91+
"""
92+
logger.info("Using the eval_from_task method")
93+
94+
# Create temporary directory in GCP mounted storage volume
95+
task_hash = hashlib.sha256(task.encode()).hexdigest()
96+
trajectory_path = cfg.DATA_STORAGE_PATH / f"{task_hash}-{time.time()}"
97+
trajectory_path.mkdir(parents=True, exist_ok=True)
98+
logger.info("Trajectory path: %s", trajectory_path)
99+
nb_path = trajectory_path / NBEnvironment.NOTEBOOK_NAME
100+
# Copy task data to trajectory path
101+
for item in (cfg.DATA_STORAGE_PATH / gcs_artifact_path).iterdir():
102+
if item.is_file():
103+
shutil.copy2(item, trajectory_path)
104+
elif item.is_dir():
105+
shutil.copytree(item, trajectory_path / item.name, dirs_exist_ok=True)
103106

104-
elif isinstance(self.answer, str):
105-
correct = bool(
106-
await eval_answer(
107-
proposed=str(answer),
108-
correct=str(self.answer),
109-
question=self.problem,
110-
eval_mode=self.eval_mode,
111-
)
107+
language = NBLanguage.PYTHON # In future, this should be a hyperparameter
108+
if trajectory_path.exists():
109+
logger.info(
110+
"Files in directory: %s", [f.name for f in trajectory_path.iterdir()]
112111
)
113-
elif isinstance(self.answer, dict): # This is for mcqs and open questions
114-
# Check if answer is a json string
115-
if isinstance(answer, str): # type: ignore[unreachable]
116-
# Process json into dictionary
117-
try:
118-
processed_answer = json.loads(answer)
119-
except json.JSONDecodeError:
120-
return INCORRECT_MSG
121-
else:
122-
processed_answer = answer if isinstance(answer, dict) else {}
123112

124-
# Loop through each question and answer
125-
for question_id, agent_answer in processed_answer.items():
126-
try:
127-
ideal_answer = self.answer[question_id]
128-
question = next(
129-
q
130-
for q in self.mcqs
131-
if q.question_id.lower() == question_id.lower()
132-
)
133-
correct = bool(
134-
await eval_answer(
135-
proposed=str(agent_answer),
136-
correct=str(ideal_answer),
137-
question=question,
138-
eval_mode=self.eval_mode,
139-
)
140-
)
141-
self.question_rewards[question_id] = correct
142-
except KeyError:
143-
self.question_rewards[question_id] = 0
144-
average_reward = sum(self.question_rewards.values()) / len(self.mcqs)
145-
correct = round(average_reward) == 1.0
146-
147-
if correct:
148-
self.state.total_reward += self.correct_reward
149-
return CORRECT_MSG
150-
return INCORRECT_MSG
113+
return cls(
114+
problem_id=f"data-analysis-task-{task_hash}",
115+
problem=task,
116+
# Using exact just because I won't ultimately be using env evaluation
117+
eval_mode=EvalAnswerMode.EXACT,
118+
nb_path=nb_path,
119+
work_dir=trajectory_path,
120+
language=language,
121+
system_prompt=prompts.CAPSULE_SYSTEM_PROMPT_OPEN,
122+
use_tmp_work_dir=False,
123+
)
151124

152125
@classmethod
153126
def from_task(
@@ -163,6 +136,8 @@ def from_task(
163136
"""
164137
logger.info("User task: %s", task)
165138
logger.info("GCS artifact path: %s", gcs_artifact_path)
139+
if cfg.EVAL:
140+
return cls.eval_from_task(task, gcs_artifact_path) # type: ignore
166141

167142
if (
168143
gcs_artifact_path
@@ -251,6 +226,7 @@ def export_frame(self) -> Frame:
251226
"total_reward": self.state.total_reward,
252227
"nb_state": self.state.nb,
253228
"nb_state_html": nb_to_html(self.state.nb),
229+
"nb_runtime_errors": self.state.notebook_runtime_errors,
254230
},
255231
info={
256232
"eval_mode": self.eval_mode,

src/fhda/notebook_env.py

Lines changed: 7 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -53,6 +53,7 @@ def __init__(
5353
# Add initial cell with rpy2 extension load
5454
nbformat.v4.new_code_cell(source="%load_ext rpy2.ipython")
5555
self.nb.metadata.kernelspec = self.language.make_kernelspec()
56+
self.notebook_runtime_errors: list[str] = []
5657

5758
def save_nb(self):
5859
"""Saves the notebook to disk."""
@@ -248,11 +249,10 @@ def list_workdir(self) -> str:
248249
249250
The contents is represented as a nested JSON dictionary.
250251
"""
251-
logger.info("Listing working directory: %s", self.state.work_dir)
252252
return json.dumps(self._list_dir(self.state.work_dir), indent=2)
253253

254254
# allowing int so that agent doesn't try to force to float
255-
def submit_answer(self, answer: str | float | int) -> str: # noqa: PYI041
255+
def submit_answer(self, answer: str) -> str: # noqa: PYI041
256256
"""Submit an answer to the problem.
257257
258258
Note that this tool may only be called once and ends the episode.
@@ -329,9 +329,11 @@ async def _run_notebook_local(self) -> str:
329329
"""Run notebook using local kernel."""
330330
client = self.state.kernel_manager.client()
331331
client.start_channels()
332-
working_dir_files = list(self.state.work_dir.glob("**/*"))
333-
logger.info(f"Files in working directory: {working_dir_files}")
334-
await utils.nbformat_run_notebook(cells=self.state.cells, client=client)
332+
error_messages = await utils.nbformat_run_notebook(
333+
cells=self.state.cells, client=client
334+
)
335+
if error_messages:
336+
self.state.notebook_runtime_errors.extend(error_messages)
335337
self.state.save_nb()
336338
logger.debug("Saved notebook to disk")
337339
self.state.reload_nb()

src/fhda/prompts.py

Lines changed: 13 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -61,6 +61,19 @@
6161
- The first cell has already been loaded with %load_ext rpy2.ipython so you can use %%R cells from the second cell onwards
6262
"""
6363

64+
GENERAL_NOTEBOOK_GUIDELINES_R = """
65+
General Guidelines:
66+
- Write small to medium-sized cells for easier debugging.
67+
- Edit existing cells by their index number when fixing bugs, rather than creating new ones.
68+
- Check dataframe shapes before printing. Use head() for large dataframes.
69+
- Ensure each cell executes successfully before moving to the next.
70+
- Assume you already have the packages you need installed and only install new ones if you receive errors.
71+
- If you need to install packages, use mamba or conda.
72+
IMPORTANT: Use R cells for all analysis.
73+
- All cells are by default R cells.
74+
"""
75+
76+
6477
AVOID_IMAGES = """
6578
AVOID USING PLOTS/IMAGES. USE TABLES AND PRINT OUTPUTS INSTEAD AS MUCH AS POSSIBLE.
6679
"""

src/fhda/utils.py

Lines changed: 11 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -149,7 +149,7 @@ def encode_image_to_base64(image: str) -> str:
149149

150150
async def nbformat_run_notebook(
151151
cells: Iterable[nbformat.NotebookNode], client: "AsyncKernelClient"
152-
) -> None:
152+
) -> list[str]:
153153
"""Execute notebook cells using a kernel client and collect outputs.
154154
155155
Args:
@@ -158,7 +158,11 @@ async def nbformat_run_notebook(
158158
159159
Raises:
160160
ValueError: If there is an error executing a cell
161+
162+
Returns:
163+
List of error messages from cells that raised an error
161164
"""
165+
error_messages = []
162166
try:
163167
logger.debug("Beginning cell execution")
164168
for idx, cell in enumerate(cells):
@@ -221,8 +225,11 @@ async def nbformat_run_notebook(
221225
f"Value: {content.get('evalue', 'No error message')}\n"
222226
f"Traceback: {content.get('traceback', [])}"
223227
)
228+
error_messages.append(
229+
f"Cell {idx}: {content.get('evalue', '')}"
230+
)
224231
logger.error(error_msg)
225-
raise ValueError(error_msg)
232+
# raise ValueError(error_msg)
226233
elif (
227234
msg_type == "status"
228235
and content["execution_state"] == "idle"
@@ -233,6 +240,8 @@ async def nbformat_run_notebook(
233240
logger.debug("Stopping kernel channels")
234241
client.stop_channels()
235242

243+
return error_messages
244+
236245

237246
async def exec_cmd(
238247
container: DockerContainer, exec_command: list[str], timeout: float | None = 300

src/scripts/config.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,7 @@
1616
from pydantic import BaseModel, ConfigDict
1717
from pydantic_core import PydanticUndefined
1818

19-
from .logging import configure_logs
19+
from .expt_logging import configure_logs
2020

2121
logger = logging.getLogger(__name__)
2222

src/scripts/deploy.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,17 @@
99
FramePath,
1010
AuthType,
1111
)
12+
from crow_client.models.app import TaskQueuesConfig
13+
14+
EVAL = True
1215

1316
ENV_VARS = {
1417
"OPENAI_API_KEY": os.environ["OPENAI_API_KEY"],
1518
"ANTHROPIC_API_KEY": os.environ["ANTHROPIC_API_KEY"],
1619
"USE_R": "false",
1720
"USE_DOCKER": "false",
1821
"STAGE": "DEV",
22+
"EVAL": "true" if EVAL else "false",
1923
}
2024

2125
CONTAINER_CONFIG = DockerContainerConfiguration(cpu="2", memory="4Gi")
@@ -29,13 +33,18 @@
2933
CrowDeploymentConfig(
3034
requirements_path=Path("pyproject.toml"),
3135
path=Path("src"),
32-
name="data-analysis-crow",
36+
name="bixbench-crow" if EVAL else "data-analysis-crow",
3337
environment="src.fhda.data_analysis_env.DataAnalysisEnv",
3438
environment_variables=ENV_VARS,
3539
agent="ldp.agent.ReActAgent",
3640
container_config=CONTAINER_CONFIG,
3741
force=True,
3842
frame_paths=frame_paths,
43+
timeout=1200,
44+
task_queues_config=TaskQueuesConfig(
45+
name="bixbench-crow" if EVAL else "data-analysis-crow",
46+
max_running_jobs=300,
47+
),
3948
),
4049
]
4150

0 commit comments

Comments (0)