Skip to content

Commit 72e51b3

Browse files
authored
Platform eval (#5)
1 parent 63acd9a commit 72e51b3

File tree

8 files changed

+624
-213
lines changed

8 files changed

+624
-213
lines changed

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ dependencies = [
2222
"google-auth==2.38.0",
2323
"google-cloud-storage==3.0.0",
2424
"google-cloud-secret-manager==2.23.0",
25-
"crow-client==0.3.6",
25+
"crow-client>=0.3.13",
2626
"jupyter==1.1.1",
2727
"nbconvert==7.16.6",
2828
"notebook==7.3.2",
@@ -52,4 +52,4 @@ run_expt = 'scripts.configurable:_run_expt'
5252
package-dir = {"" = "src"}
5353

5454
[tool.setuptools.packages.find]
55-
where = ["src"]
55+
where = ["src"]

src/fhda/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,5 @@
2121
DATA_STORAGE_PATH = Path("/storage")
2222

2323
EVAL = bool(os.getenv("EVAL", "false").lower() == "true")
24+
25+
VALID_FROM_TASK_KWARGS = ["run_notebook_on_edit"]

src/fhda/data_analysis_env.py

Lines changed: 48 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import hashlib
22
import logging
33
import shutil
4+
import json
45
from typing import Any, cast
56
import time
67
from aviary.core import (
@@ -80,8 +81,29 @@ async def submit_answer(self, answer: str) -> str: # type: ignore[override]
8081
logger.info("Answer: %s", answer)
8182
return answer
8283

84+
def export_frame(self) -> Frame:
85+
return Frame(
86+
state={
87+
"last_action": self.state.actions[-1],
88+
"answer": self.state.answer,
89+
"done": self.state.done,
90+
"total_reward": self.state.total_reward,
91+
"nb_state": self.state.nb,
92+
"nb_state_html": nb_to_html(self.state.nb),
93+
"nb_runtime_errors": self.state.notebook_runtime_errors,
94+
},
95+
info={
96+
"eval_mode": self.eval_mode,
97+
"language": self.state.language,
98+
"problem": self.problem,
99+
"problem_id": self.problem_id,
100+
},
101+
)
102+
83103
@classmethod
84-
def eval_from_task(cls, task: str, gcs_artifact_path: str) -> "DataAnalysisEnv":
104+
def eval_from_task(
105+
cls, task: str, gcs_artifact_path: str, environment_config: str | None = None
106+
) -> "DataAnalysisEnv":
85107
"""
86108
Used for evaluations via crow jobs.
87109
@@ -90,7 +112,6 @@ def eval_from_task(cls, task: str, gcs_artifact_path: str) -> "DataAnalysisEnv":
90112
gcs_artifact_path: The path to the GCS artifact – required for evaluation on crow jobs
91113
"""
92114
logger.info("Using the eval_from_task method")
93-
94115
# Create temporary directory in GCP mounted storage volume
95116
task_hash = hashlib.sha256(task.encode()).hexdigest()
96117
trajectory_path = cfg.DATA_STORAGE_PATH / f"{task_hash}-{time.time()}"
@@ -124,45 +145,44 @@ def eval_from_task(cls, task: str, gcs_artifact_path: str) -> "DataAnalysisEnv":
124145

125146
@classmethod
126147
def from_task(
127-
cls, task: str, gcs_artifact_path: str | None = None
148+
cls,
149+
task: str,
150+
gcs_artifact_path: str | None = None,
151+
environment_config: str | None = None,
128152
) -> "DataAnalysisEnv":
129153
"""
130154
Perform data analysis on a user query.
131155
132156
Args:
133-
task: The user query structured as <data_path> | <query>
134-
135-
eg "CaspuleFolder-a7812fg | How many genes are differentially expressed between the two conditions?"
157+
task: The user query
158+
gcs_artifact_path: The path to the GCS artifact – required for evaluation on crow jobs
159+
environment_config: A JSON string of environment configuration
136160
"""
137161
logger.info("User task: %s", task)
138162
logger.info("GCS artifact path: %s", gcs_artifact_path)
163+
logger.info("environment_config: %s", environment_config)
139164
if cfg.EVAL:
140165
return cls.eval_from_task(task, gcs_artifact_path) # type: ignore
141166

142167
if (
143-
gcs_artifact_path
168+
not gcs_artifact_path
144169
): # A gcs_artifact_path is required: the files live in a job-specific directory of the GCS bucket
145-
trajectory_path = cfg.DATA_STORAGE_PATH / gcs_artifact_path
146-
nb_path = trajectory_path / NBEnvironment.NOTEBOOK_NAME
147-
query = task
148-
task_hash = gcs_artifact_path
170+
raise NotImplementedError(
171+
"Running crow jobs without gcs_artifact_path is not supported"
172+
)
173+
trajectory_path = cfg.DATA_STORAGE_PATH / gcs_artifact_path
174+
nb_path = trajectory_path / NBEnvironment.NOTEBOOK_NAME
175+
query = task
176+
task_hash = gcs_artifact_path
177+
if environment_config:
178+
kwargs = {
179+
k: v
180+
for k, v in json.loads(environment_config).items()
181+
if k in cfg.VALID_FROM_TASK_KWARGS
182+
}
149183
else:
150-
# Extract data path and query from task
151-
data_path, query = task.split("|")
152-
# Hash the task to get a unique identifier
153-
task_hash = hashlib.sha256(task.encode()).hexdigest()
154-
# Create temporary directory in GCP mounted storage volume
155-
trajectory_path = cfg.DATA_STORAGE_PATH / f"{task_hash}-{time.time()}"
156-
trajectory_path.mkdir(parents=True, exist_ok=True)
157-
nb_path = trajectory_path / NBEnvironment.NOTEBOOK_NAME
158-
# Copy task data to trajectory path
159-
for item in (cfg.DATA_STORAGE_PATH / data_path).iterdir():
160-
if item.is_file():
161-
shutil.copy2(item, trajectory_path)
162-
elif item.is_dir():
163-
shutil.copytree(
164-
item, trajectory_path / item.name, dirs_exist_ok=True
165-
)
184+
kwargs = {}
185+
logger.info("Filtered kwargs: %s", kwargs)
166186

167187
# Augment incoming task with CoT instructions
168188
augmented_task = f"""\
@@ -215,23 +235,5 @@ def from_task(
215235
language=language,
216236
system_prompt=prompts.CAPSULE_SYSTEM_PROMPT_QUERY,
217237
use_tmp_work_dir=False,
218-
)
219-
220-
def export_frame(self) -> Frame:
221-
return Frame(
222-
state={
223-
"last_action": self.state.actions[-1],
224-
"answer": self.state.answer,
225-
"done": self.state.done,
226-
"total_reward": self.state.total_reward,
227-
"nb_state": self.state.nb,
228-
"nb_state_html": nb_to_html(self.state.nb),
229-
"nb_runtime_errors": self.state.notebook_runtime_errors,
230-
},
231-
info={
232-
"eval_mode": self.eval_mode,
233-
"language": self.state.language,
234-
"problem": self.problem,
235-
"problem_id": self.problem_id,
236-
},
238+
**kwargs,
237239
)

src/fhda/notebook_env.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ def __init__(
125125
use_tmp_work_dir: bool = True,
126126
language: utils.NBLanguage = utils.NBLanguage.PYTHON,
127127
allow_download_from_gcs: bool = False,
128+
run_notebook_on_edit: bool = False,
128129
):
129130
"""Initialize a notebook environment.
130131
@@ -139,6 +140,8 @@ def __init__(
139140
allow_download_from_gcs: If True, the environment will expose a tool to download
140141
directories from the aviary-storage GCS bucket. Should only be enabled if the
141142
task requires data on GCS. Disabled by default.
143+
run_notebook_on_edit: If True, the whole notebook will be rerun
144+
after each edit. If False (the default), only the cell that was edited will be rerun.
142145
"""
143146
self.work_dir = Path(work_dir)
144147
self.nb_path = Path(nb_path) if nb_path else self.work_dir / self.NOTEBOOK_NAME
@@ -147,6 +150,7 @@ def __init__(
147150
self.language = language
148151
self.allow_download_from_gcs = allow_download_from_gcs
149152
self.use_docker = cfg.USE_DOCKER
153+
self.run_notebook_on_edit = run_notebook_on_edit
150154

151155
async def reset(self) -> tuple[Messages, list[Tool]]:
152156
nb_path, work_dir = self._set_work_dir()
@@ -218,7 +222,7 @@ async def edit_cell(self, contents: str, idx: int | None = None) -> str:
218222
219223
ONLY CODE CELLS ARE SUPPORTED. Do not attempt to write Markdown or raw text,
220224
though you are permitted (and encouraged) to write comments in the code cells.
221-
The notebook will be automatically rerun if a successful edit is made.
225+
The cell will be automatically rerun if a successful edit is made.
222226
223227
Args:
224228
contents: Cell contents to insert. We assume the cell is a code block.
@@ -242,7 +246,12 @@ async def edit_cell(self, contents: str, idx: int | None = None) -> str:
242246
return f"Edited cell #{idx}."
243247
finally:
244248
self.state.save_nb()
245-
await self.run_notebook()
249+
if self.run_notebook_on_edit:
250+
args = {}
251+
else:
252+
idx = len(self.state.cells) - 1 if idx is None else idx
253+
args = {"cell_idx": idx}
254+
await self.run_notebook(**args)
246255

247256
def list_workdir(self) -> str:
248257
"""Recursively lists the contents of the working directory.
@@ -283,12 +292,14 @@ def _list_dir(self, path: Path) -> TListDir:
283292
cast(list, index["files"]).append(item.name)
284293
return index
285294

286-
async def run_notebook(self) -> str:
295+
async def run_notebook(self, cell_idx: int | None = None) -> str:
287296
"""Run the entire notebook sequentially."""
288297
logger.debug("Starting notebook execution")
289298
if self.use_docker:
299+
if cell_idx is not None:
300+
raise ValueError("Cell index not supported for Docker")
290301
return await self._run_notebook_docker()
291-
return await self._run_notebook_local()
302+
return await self._run_notebook_local(cell_idx=cell_idx)
292303

293304
async def _run_notebook_docker(self) -> str:
294305
"""Run notebook using Docker container."""
@@ -325,12 +336,12 @@ async def _run_notebook_docker(self) -> str:
325336
self.state.reload_nb()
326337
return "Executed all cells."
327338

328-
async def _run_notebook_local(self) -> str:
339+
async def _run_notebook_local(self, cell_idx: int | None = None) -> str:
329340
"""Run notebook using local kernel."""
330341
client = self.state.kernel_manager.client()
331342
client.start_channels()
332343
error_messages = await utils.nbformat_run_notebook(
333-
cells=self.state.cells, client=client
344+
cells=self.state.cells, client=client, cell_idx=cell_idx
334345
)
335346
if error_messages:
336347
self.state.notebook_runtime_errors.extend(error_messages)

src/fhda/utils.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,9 @@ def encode_image_to_base64(image: str) -> str:
148148

149149

150150
async def nbformat_run_notebook(
151-
cells: Iterable[nbformat.NotebookNode], client: "AsyncKernelClient"
151+
cells: Iterable[nbformat.NotebookNode],
152+
client: "AsyncKernelClient",
153+
cell_idx: int | None = None,
152154
) -> list[str]:
153155
"""Execute notebook cells using a kernel client and collect outputs.
154156
@@ -163,9 +165,13 @@ async def nbformat_run_notebook(
163165
List of error messages from cells that raised an error
164166
"""
165167
error_messages = []
168+
logger.debug(f"Running notebook with cell_idx: {cell_idx}")
166169
try:
167170
logger.debug("Beginning cell execution")
168171
for idx, cell in enumerate(cells):
172+
if cell_idx is not None and idx != cell_idx:
173+
logger.debug(f"Skipping cell {idx} because cell_idx is {cell_idx}")
174+
continue
169175
if cell.cell_type == "code":
170176
logger.debug(f"Executing code cell {idx}")
171177
cell.outputs = [] # Initialize empty outputs list

src/scripts/deploy.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
)
1212
from crow_client.models.app import TaskQueuesConfig
1313

14-
EVAL = True
14+
EVAL = False
1515

1616
ENV_VARS = {
1717
"OPENAI_API_KEY": os.environ["OPENAI_API_KEY"],
@@ -33,27 +33,28 @@
3333
CrowDeploymentConfig(
3434
requirements_path=Path("pyproject.toml"),
3535
path=Path("src"),
36-
name="bixbench-crow" if EVAL else "data-analysis-crow",
36+
name="bixbench-crow2" if EVAL else "data-analysis-crow",
3737
environment="src.fhda.data_analysis_env.DataAnalysisEnv",
3838
environment_variables=ENV_VARS,
3939
agent="ldp.agent.ReActAgent",
4040
container_config=CONTAINER_CONFIG,
4141
force=True,
4242
frame_paths=frame_paths,
43-
timeout=1200,
43+
timeout=3600,
4444
task_queues_config=TaskQueuesConfig(
45-
name="bixbench-crow" if EVAL else "data-analysis-crow",
45+
name="bixbench-crow2" if EVAL else "data-analysis-crow",
4646
max_running_jobs=300,
4747
),
4848
),
4949
]
5050

5151
if __name__ == "__main__":
5252
client = CrowClient(
53-
stage=Stage.from_string(os.environ.get("CROW_ENV", "DEV")),
53+
# stage=Stage.from_string(os.environ.get("CROW_ENV", ENV_VARS["STAGE"])),
54+
stage=Stage.from_string(os.environ.get("CROW_ENV", "LOCAL")),
5455
organization="FutureHouse",
5556
auth_type=AuthType.API_KEY,
56-
api_key=os.environ["CROW_API_KEY"],
57+
api_key=os.environ[f"CROW_API_KEY_{ENV_VARS['STAGE']}"],
5758
)
5859
for crow in CROWS_TO_DEPLOY:
5960
try:

0 commit comments

Comments
 (0)