Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/fhda/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,5 @@
DATA_STORAGE_PATH = Path("storage")
else:
DATA_STORAGE_PATH = Path("/storage")

# Feature flag: evaluation mode is enabled when the EVAL environment variable
# is "true" (case-insensitive); defaults to disabled when unset.
# The `==` comparison already yields a bool, so the original bool(...) wrapper
# was redundant and has been dropped.
EVAL = os.getenv("EVAL", "false").lower() == "true"
110 changes: 43 additions & 67 deletions src/fhda/data_analysis_env.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import hashlib
import json
import logging
import shutil
from typing import Any, cast
Expand All @@ -10,7 +9,6 @@
Message,
Messages,
Tool,
eval_answer,
)

from .notebook_env import NBEnvironment
Expand All @@ -33,7 +31,7 @@ def __init__(
answer: str | int | float | None = None, # noqa: PYI041
system_prompt: str | None = None,
correct_reward: float = 1.0,
eval_mode: EvalAnswerMode,
eval_mode: EvalAnswerMode | None = None,
metadata: dict[str, Any] | None = None, # used for NBEvalExpt
mcqs: list[MultipleChoiceQuestion] | None = None,
**kwargs,
Expand Down Expand Up @@ -66,7 +64,7 @@ async def reset(self) -> tuple[Messages, list[Tool]]:

return init_obs, tools

async def submit_answer(self, answer: str | float | dict[str, Any] | None) -> str: # type: ignore[override]
async def submit_answer(self, answer: str) -> str: # type: ignore[override]
"""Submit an answer to the problem.

Note that this tool may only be called once and ends the episode.
Expand All @@ -79,75 +77,50 @@ async def submit_answer(self, answer: str | float | dict[str, Any] | None) -> st
self.state.done = True
logger.info("Submitting answer and closing environment")
await self.close()
correct = False
logger.info("Answer: %s", answer)
return answer

if self.eval_mode is None:
return CORRECT_MSG

if isinstance(self.answer, int):
try:
answer = int(answer) # type: ignore[arg-type]
except ValueError:
pass
else:
correct = answer == self.answer
@classmethod
def eval_from_task(cls, task: str, gcs_artifact_path: str) -> "DataAnalysisEnv":
"""
Used for evaluations via crow jobs.

elif isinstance(self.answer, float):
try:
answer = float(answer) # type: ignore[arg-type]
except ValueError:
pass
else:
correct = abs(answer - self.answer) < 1e-4 * self.answer
Args:
task: The user query structured as <data_path> | <query>
gcs_artifact_path: The path to the GCS artifact – required for evaluation on crow jobs
"""
logger.info("Using the eval_from_task method")

# Create temporary directory in GCP mounted storage volume
task_hash = hashlib.sha256(task.encode()).hexdigest()
trajectory_path = cfg.DATA_STORAGE_PATH / f"{task_hash}-{time.time()}"
trajectory_path.mkdir(parents=True, exist_ok=True)
logger.info("Trajectory path: %s", trajectory_path)
nb_path = trajectory_path / NBEnvironment.NOTEBOOK_NAME
# Copy task data to trajectory path
for item in (cfg.DATA_STORAGE_PATH / gcs_artifact_path).iterdir():
if item.is_file():
shutil.copy2(item, trajectory_path)
elif item.is_dir():
shutil.copytree(item, trajectory_path / item.name, dirs_exist_ok=True)

elif isinstance(self.answer, str):
correct = bool(
await eval_answer(
proposed=str(answer),
correct=str(self.answer),
question=self.problem,
eval_mode=self.eval_mode,
)
language = NBLanguage.PYTHON # In future, this should be a hyperparameter
if trajectory_path.exists():
logger.info(
"Files in directory: %s", [f.name for f in trajectory_path.iterdir()]
)
elif isinstance(self.answer, dict): # This is for mcqs and open questions
# Check if answer is a json string
if isinstance(answer, str): # type: ignore[unreachable]
# Process json into dictionary
try:
processed_answer = json.loads(answer)
except json.JSONDecodeError:
return INCORRECT_MSG
else:
processed_answer = answer if isinstance(answer, dict) else {}

# Loop through each question and answer
for question_id, agent_answer in processed_answer.items():
try:
ideal_answer = self.answer[question_id]
question = next(
q
for q in self.mcqs
if q.question_id.lower() == question_id.lower()
)
correct = bool(
await eval_answer(
proposed=str(agent_answer),
correct=str(ideal_answer),
question=question,
eval_mode=self.eval_mode,
)
)
self.question_rewards[question_id] = correct
except KeyError:
self.question_rewards[question_id] = 0
average_reward = sum(self.question_rewards.values()) / len(self.mcqs)
correct = round(average_reward) == 1.0

if correct:
self.state.total_reward += self.correct_reward
return CORRECT_MSG
return INCORRECT_MSG
return cls(
problem_id=f"data-analysis-task-{task_hash}",
problem=task,
# Using exact just because I won't ultimately be using env evaluation
eval_mode=EvalAnswerMode.EXACT,
nb_path=nb_path,
work_dir=trajectory_path,
language=language,
system_prompt=prompts.CAPSULE_SYSTEM_PROMPT_OPEN,
use_tmp_work_dir=False,
)

@classmethod
def from_task(
Expand All @@ -163,6 +136,8 @@ def from_task(
"""
logger.info("User task: %s", task)
logger.info("GCS artifact path: %s", gcs_artifact_path)
if cfg.EVAL:
return cls.eval_from_task(task, gcs_artifact_path) # type: ignore

if (
gcs_artifact_path
Expand Down Expand Up @@ -251,6 +226,7 @@ def export_frame(self) -> Frame:
"total_reward": self.state.total_reward,
"nb_state": self.state.nb,
"nb_state_html": nb_to_html(self.state.nb),
"nb_runtime_errors": self.state.notebook_runtime_errors,
},
info={
"eval_mode": self.eval_mode,
Expand Down
12 changes: 7 additions & 5 deletions src/fhda/notebook_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def __init__(
# Add initial cell with rpy2 extension load
nbformat.v4.new_code_cell(source="%load_ext rpy2.ipython")
self.nb.metadata.kernelspec = self.language.make_kernelspec()
self.notebook_runtime_errors: list[str] = []

def save_nb(self):
"""Saves the notebook to disk."""
Expand Down Expand Up @@ -248,11 +249,10 @@ def list_workdir(self) -> str:

The contents is represented as a nested JSON dictionary.
"""
logger.info("Listing working directory: %s", self.state.work_dir)
return json.dumps(self._list_dir(self.state.work_dir), indent=2)

# allowing int so that agent doesn't try to force to float
def submit_answer(self, answer: str | float | int) -> str: # noqa: PYI041
def submit_answer(self, answer: str) -> str: # noqa: PYI041
"""Submit an answer to the problem.

Note that this tool may only be called once and ends the episode.
Expand Down Expand Up @@ -329,9 +329,11 @@ async def _run_notebook_local(self) -> str:
"""Run notebook using local kernel."""
client = self.state.kernel_manager.client()
client.start_channels()
working_dir_files = list(self.state.work_dir.glob("**/*"))
logger.info(f"Files in working directory: {working_dir_files}")
await utils.nbformat_run_notebook(cells=self.state.cells, client=client)
error_messages = await utils.nbformat_run_notebook(
cells=self.state.cells, client=client
)
if error_messages:
self.state.notebook_runtime_errors.extend(error_messages)
self.state.save_nb()
logger.debug("Saved notebook to disk")
self.state.reload_nb()
Expand Down
13 changes: 13 additions & 0 deletions src/fhda/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,19 @@
- The first cell has already been loaded with %load_ext rpy2.ipython so you can use %%R cells from the second cell onwards
"""

# Prompt fragment with general notebook-authoring guidelines for the agent when
# the notebook is R-based: all cells default to R and all analysis should be
# written in R cells. NOTE(review): presumably selected when the R language /
# USE_R configuration is active — confirm against the prompt-assembly code.
GENERAL_NOTEBOOK_GUIDELINES_R = """
General Guidelines:
- Write small to medium-sized cells for easier debugging.
- Edit existing cells by their index number when fixing bugs, rather than creating new ones.
- Check dataframe shapes before printing. Use head() for large dataframes.
- Ensure each cell executes successfully before moving to the next.
- Assume you already have the packages you need installed and only install new ones if you receive errors.
- If you need to install packages, use mamba or conda.
IMPORTANT: Use R cells for all analysis.
- All cells are by default R cells.
"""


AVOID_IMAGES = """
AVOID USING PLOTS/IMAGES. USE TABLES AND PRINT OUTPUTS INSTEAD AS MUCH AS POSSIBLE.
"""
Expand Down
13 changes: 11 additions & 2 deletions src/fhda/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def encode_image_to_base64(image: str) -> str:

async def nbformat_run_notebook(
cells: Iterable[nbformat.NotebookNode], client: "AsyncKernelClient"
) -> None:
) -> list[str]:
"""Execute notebook cells using a kernel client and collect outputs.

Args:
Expand All @@ -158,7 +158,11 @@ async def nbformat_run_notebook(

Raises:
ValueError: If there is an error executing a cell

Returns:
List of error messages from cells that raised an error
"""
error_messages = []
try:
logger.debug("Beginning cell execution")
for idx, cell in enumerate(cells):
Expand Down Expand Up @@ -221,8 +225,11 @@ async def nbformat_run_notebook(
f"Value: {content.get('evalue', 'No error message')}\n"
f"Traceback: {content.get('traceback', [])}"
)
error_messages.append(
f"Cell {idx}: {content.get('evalue', '')}"
)
logger.error(error_msg)
raise ValueError(error_msg)
# raise ValueError(error_msg)
elif (
msg_type == "status"
and content["execution_state"] == "idle"
Expand All @@ -233,6 +240,8 @@ async def nbformat_run_notebook(
logger.debug("Stopping kernel channels")
client.stop_channels()

return error_messages


async def exec_cmd(
container: DockerContainer, exec_command: list[str], timeout: float | None = 300
Expand Down
2 changes: 1 addition & 1 deletion src/scripts/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from pydantic import BaseModel, ConfigDict
from pydantic_core import PydanticUndefined

from .logging import configure_logs
from .expt_logging import configure_logs

logger = logging.getLogger(__name__)

Expand Down
11 changes: 10 additions & 1 deletion src/scripts/deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,17 @@
FramePath,
AuthType,
)
from crow_client.models.app import TaskQueuesConfig

# Deployment toggle: when True the evaluation variant ("bixbench-crow") is
# deployed instead of the standard "data-analysis-crow" (see the deployment
# name and ENV_VARS["EVAL"] below, which both branch on this flag).
EVAL = True

# Environment variables injected into the deployed container (passed to
# CrowDeploymentConfig.environment_variables below). Reading os.environ with
# [] means a missing API key raises KeyError at import time — deliberate
# fail-fast before deploying.
ENV_VARS = {
"OPENAI_API_KEY": os.environ["OPENAI_API_KEY"],
"ANTHROPIC_API_KEY": os.environ["ANTHROPIC_API_KEY"],
"USE_R": "false",
"USE_DOCKER": "false",
"STAGE": "DEV",
# Serialize the module-level EVAL flag as "true"/"false" — presumably parsed
# back into a bool by the service's os.getenv-based config; verify against
# src/fhda/config.py.
"EVAL": "true" if EVAL else "false",
}

CONTAINER_CONFIG = DockerContainerConfiguration(cpu="2", memory="4Gi")
Expand All @@ -29,13 +33,18 @@
CrowDeploymentConfig(
requirements_path=Path("pyproject.toml"),
path=Path("src"),
name="data-analysis-crow",
name="bixbench-crow" if EVAL else "data-analysis-crow",
environment="src.fhda.data_analysis_env.DataAnalysisEnv",
environment_variables=ENV_VARS,
agent="ldp.agent.ReActAgent",
container_config=CONTAINER_CONFIG,
force=True,
frame_paths=frame_paths,
timeout=1200,
task_queues_config=TaskQueuesConfig(
name="bixbench-crow" if EVAL else "data-analysis-crow",
max_running_jobs=300,
),
),
]

Expand Down
Loading