From f53f8b88155b61b28696aaec079052f6441bdd32 Mon Sep 17 00:00:00 2001 From: Ludovico Mitchener Date: Fri, 11 Apr 2025 12:49:18 -0700 Subject: [PATCH 1/2] Enable cost tracking & update batch jobs --- pyproject.toml | 11 +- src/fhda/Dockerfile.custom_deployment | 184 ++++++++++++++++++++++++++ src/fhda/data_analysis_env.py | 12 +- src/fhda/prompts.py | 81 ++++++++++++ src/scripts/deploy.py | 25 +++- src/scripts/platform_eval.py | 22 +-- src/scripts/platform_run_jobs.py | 88 ++++++++++-- 7 files changed, 388 insertions(+), 35 deletions(-) create mode 100644 src/fhda/Dockerfile.custom_deployment diff --git a/pyproject.toml b/pyproject.toml index b7bcd34..63ca0bc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ authors = [ dependencies = [ "aiodocker==0.24.0", "fhaviary[server]==0.18.1", - "fh-llm-client==0.0.11", + "fhlmi==0.26.0", "ldp==0.23.0", "pandas==2.2.3", "numpy==2.2.3", @@ -22,11 +22,12 @@ dependencies = [ "google-auth==2.38.0", "google-cloud-storage==3.0.0", "google-cloud-secret-manager==2.23.0", - "crow-client>=0.3.13", + "crow-client>=0.3.14", "jupyter==1.1.1", "nbconvert==7.16.6", "notebook==7.3.2", - "nbformat==5.10.4" + "nbformat==5.10.4", + "pydeseq2==0.5.0" ] description = "Data analysis crow" name = "fhda" @@ -49,7 +50,7 @@ dev = [ run_expt = 'scripts.configurable:_run_expt' [tool.setuptools] -package-dir = {"" = "src"} +package-dir = {"" = "fhda"} [tool.setuptools.packages.find] -where = ["src"] +where = ["fhda"] diff --git a/src/fhda/Dockerfile.custom_deployment b/src/fhda/Dockerfile.custom_deployment new file mode 100644 index 0000000..d793559 --- /dev/null +++ b/src/fhda/Dockerfile.custom_deployment @@ -0,0 +1,184 @@ +# syntax=docker/dockerfile:1.4 +FROM python:3.12-slim AS base + +WORKDIR /app +ENV PYTHONUNBUFFERED=1 +ENV DEBIAN_FRONTEND=noninteractive + +RUN --mount=type=cache,target=/var/cache/apt \ + apt-get update -qq && \ + apt-get install -yq --no-install-recommends \ + git \ + openssh-client \ + wget \ + gpg \ + 
software-properties-common \ + build-essential && \ + apt-get clean && rm -rf /var/lib/apt/lists/* + +RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh && \ + chmod +x ~/miniconda.sh && \ + bash ~/miniconda.sh -b -p /app/miniconda && \ + rm ~/miniconda.sh && \ + /app/miniconda/bin/conda init bash + +# Set environment variables to point to conda environment +ENV VIRTUAL_ENV="/app/miniconda/bin" +ENV PATH="/app/miniconda/bin:$PATH" +ENV PYTHONPATH="/app/miniconda/lib/python3.12/site-packages:${PYTHONPATH:-}" + +# Install uv & mamba +RUN pip3 install --no-cache-dir uv==0.5.21 +RUN conda install -c conda-forge mamba -y + + +# Install R, the Jupyter kernels and the analysis packages into the base conda environment at /app/miniconda (no separate crow_env is created) +RUN mamba install -c conda-forge -y \ + r-base=4.3.3 \ + r-recommended=4.3 \ + r-irkernel=1.3.2 \ + r-factominer=2.11 \ + r-rcolorbrewer=1.1_3 \ + r-devtools=2.4.5 \ + r-broom=1.0.7 \ + r-data.table=1.15.4 \ + r-enrichr=3.2 \ + r-factoextra=1.0.7 \ + r-ggnewscale=0.5.0 \ + r-ggrepel=0.9.6 \ + r-ggpubr=0.6.0 \ + r-ggvenn=0.1.10 \ + r-janitor=2.2.1 \ + r-multcomp=1.4_26 \ + r-matrix=1.6_5 \ + r-pheatmap=1.0.12 \ + r-tidyverse=2.0.0 \ + r-readxl=1.4.3 \ + r-reshape=0.8.9 \ + r-rstatix=0.7.2 \ + r-viridis=0.6.5 \ + udocker=1.3.17 \ + imbalanced-learn=0.13.0 \ + ipykernel=6.29.5 \ + sqlite=3.47.2 \ + anndata=0.11.1 \ + biopython=1.84 \ + datasets \ + ete3=3.1.3 \ + keras=3.7.0 \ + jupyter=1.0.0 \ + matplotlib=3.10.0 \ + matplotlib-venn=1.1.1 \ + nbconvert=7.16.4 \ + numpy=2.0.2 \ + optuna=4.1.0 \ + openpyxl=3.1.5 \ + pandas=2.2.3 \ + plotly=5.24.1 \ + rpy2=3.5.11 \ + scipy=1.14.1 \ + scanpy=1.10.4 \ + seaborn=0.13.2 \ + scikit-learn=1.6.0 \ + statsmodels=0.14.4 \ + umap-learn=0.5.7 + +RUN python -m ipykernel install --user --name python3 --display-name "Python 3 (ipykernel)" +RUN R -e 'IRkernel::installspec(name = "R", displayname = "R (4.3.3)")' + +RUN mamba install -c conda-forge -c bioconda -y \ + biokit=0.5.0 \ + gseapy=1.1.4 \ + blast=2.16.0 \ +
clipkit=2.3.0 \ + fastqc=0.12.1 \ + iqtree=2.3.6 \ + mafft=7.526 \ + metaeuk=7.bba0d80 \ + mygene=3.2.2 \ + perl=5.32.1 \ + phykit=2.0.1 \ + pydeseq2=0.4.12 \ + spades=4.0.0 \ + trim-galore=0.6.10 \ + bioconductor-enhancedvolcano=1.20.0 \ + bioconductor-deseq2=1.42.0 \ + bioconductor-clusterprofiler=4.10.0 \ + bioconductor-org.hs.eg.db=3.18.0 \ + bioconductor-genomicranges=1.54.1 \ + bioconductor-summarizedexperiment=1.32.0 \ + bioconductor-apeglm=1.24.0 + +ENV UV_COMPILE_BYTECODE=1 +ENV UV_LINK_MODE=copy + +FROM base AS builder + +ARG MODULE_NAME +ARG USE_INTERNAL_DEPS +ARG USE_GIT_CROW_CLIENT + + +RUN mkdir -p ~/.ssh && \ + chmod 700 ~/.ssh && \ + ssh-keyscan github.com >> ~/.ssh/known_hosts && \ + printf "Host github.com\n IdentityFile /root/.ssh/pqa_id_ed25519\n IdentityFile /root/.ssh/aviary_id_ed25519\nHost gitlab.company.com\n IdentityFile /root/.ssh/pqa_id_ed25519\n" > ~/.ssh/config + +RUN --mount=type=cache,target=/var/cache/apt \ + apt-get update -qq && \ + apt-get install -yq --no-install-recommends \ + build-essential && \ + apt-get clean && rm -rf /var/lib/apt/lists/* + +ENV VIRTUAL_ENV="/app/miniconda/bin" +ENV PATH="/app/miniconda/bin:$PATH" + +COPY ./${MODULE_NAME} /app/${MODULE_NAME} + +RUN mkdir -p /app/scripts +COPY ./scripts/run_crow_job.py /app/scripts/ + +RUN --mount=type=cache,target=/root/.cache/uv \ + --mount=type=ssh \ + --mount=type=secret,id=ssh_key,target=/root/.ssh/aviary_id_ed25519.tmp \ + if [ "$USE_INTERNAL_DEPS" = "true" ]; then \ + cp /root/.ssh/aviary_id_ed25519.tmp /root/.ssh/aviary_id_ed25519 && \ + chmod 400 /root/.ssh/aviary_id_ed25519 && \ + git clone git@github.com:Future-House/aviary-internal.git /app/aviary_internal && \ + cd /app/aviary_internal/aviary_internal && \ + uv pip install --system -e .; \ + else \ + echo 'Skipping aviary_internal install'; \ + fi && \ + if [ "$USE_GIT_CROW_CLIENT" = "true" ]; then \ + git clone git@github.com:Future-House/crow-ecosystem.git /app/crow-ecosystem && \ + cd 
/app/crow-ecosystem/packages/crow-client && \ + uv pip install --system -e .; \ + else \ + uv pip install --system crow-client; \ + fi + +WORKDIR /app/${MODULE_NAME} +RUN --mount=type=ssh \ + --mount=type=secret,id=pqa_ssh_key,target=/root/.ssh/pqa_id_ed25519.tmp \ + cp /root/.ssh/pqa_id_ed25519.tmp /root/.ssh/pqa_id_ed25519 && \ + chmod 400 /root/.ssh/pqa_id_ed25519 && \ + if [ -f "pyproject.toml" ]; then \ + uv pip install --system -e .; \ + elif [ -f "requirements.txt" ]; then \ + uv pip install --system -r requirements.txt; \ + else \ + echo "No pyproject.toml or requirements.txt found" && exit 1; \ + fi + +RUN find /app -type l -delete && \ + rm -rf /app/.git + +FROM base AS runtime + +COPY --from=builder /app/ /app/ + +ENV VIRTUAL_ENV="/app/miniconda/bin" +ENV PATH="/app/miniconda/bin:$PATH" +ENV PYTHONPATH="/app/miniconda/lib/python3.12/site-packages:${PYTHONPATH:-}" +CMD ["python", "scripts/run_crow_job.py"] diff --git a/src/fhda/data_analysis_env.py b/src/fhda/data_analysis_env.py index 66ad98f..70ecca0 100644 --- a/src/fhda/data_analysis_env.py +++ b/src/fhda/data_analysis_env.py @@ -11,6 +11,8 @@ Tool, ) +from llmclient import GLOBAL_COST_TRACKER, enable_cost_tracking + from .notebook_env import NBEnvironment from .utils import NBLanguage, MultipleChoiceQuestion, nb_to_html from . 
import prompts @@ -83,7 +85,7 @@ async def submit_answer(self, answer: str) -> str: # type: ignore[override] def export_frame(self) -> Frame: return Frame( state={ - "last_action": self.state.actions[-1], + "last_action": self.state.actions[-1] if self.state.actions else None, "answer": self.state.answer, "done": self.state.done, "total_reward": self.state.total_reward, @@ -96,6 +98,7 @@ def export_frame(self) -> Frame: "language": self.state.language, "problem": self.problem, "problem_id": self.problem_id, + "cost": GLOBAL_COST_TRACKER.lifetime_cost_usd, }, ) @@ -117,7 +120,8 @@ def from_task( logger.info("User task: %s", task) logger.info("GCS artifact path: %s", gcs_artifact_path) logger.info("environment_config: %s", environment_config) - + # Track cost of running the environment + enable_cost_tracking() if ( not gcs_artifact_path ): # Platform jobs should always be associated with data from a GCS bucket @@ -136,6 +140,7 @@ def from_task( logger.info("Filtered kwargs: %s", kwargs) task_hash = hashlib.sha256(task.encode()).hexdigest() if kwargs.get("eval", False): + logger.info("Eval mode is True") # Create a temporary directory in GCP mounted storage volume trajectory_path = cfg.DATA_STORAGE_PATH / f"{task_hash}-{time.time()}" trajectory_path.mkdir(parents=True, exist_ok=True) @@ -147,6 +152,7 @@ def from_task( item, trajectory_path / item.name, dirs_exist_ok=True ) else: + logger.info("Eval mode is False") # Use the GCP folder created when uploading the data via the platform trajectory_path = cfg.DATA_STORAGE_PATH / gcs_artifact_path # Augment incoming user query with CoT instructions @@ -160,7 +166,7 @@ def from_task( ) logger.info("Trajectory path: %s", trajectory_path) nb_path = trajectory_path / NBEnvironment.NOTEBOOK_NAME - + logger.info("NB path: %s", nb_path) language = NBLanguage.PYTHON # In future, this should be a hyperparameter if language == NBLanguage.R: task += f"\n{prompts.R_OUTPUT_RECOMMENDATION_PROMPT}" diff --git a/src/fhda/prompts.py 
b/src/fhda/prompts.py index b480e53..3c5fe3f 100644 --- a/src/fhda/prompts.py +++ b/src/fhda/prompts.py @@ -61,6 +61,18 @@ - The first cell has already been loaded with %load_ext rpy2.ipython so you can use %%R cells from the second cell onwards """ +# General notebook guidelines +GENERAL_NOTEBOOK_GUIDELINES_PYTHON = """ +General Guidelines: +- Write small to medium-sized cells for easier debugging. +- Edit existing cells by their index number when fixing bugs, rather than creating new ones. +- Check dataframe shapes before printing. Use head() for large dataframes. +- Ensure each cell executes successfully before moving to the next. +- Assume you already have the packages you need installed and only install new ones if you receive errors. +- If you need to install packages, use pip. +- All cells are by default Python cells. Use python for all analysis. +""" + GENERAL_NOTEBOOK_GUIDELINES_R = """ General Guidelines: - Write small to medium-sized cells for easier debugging. @@ -172,10 +184,79 @@ """ +CHAIN_OF_THOUGHT_AGNOSTIC_PYTHON = """ +Follow these steps to create your notebook, using chain-of-thought reasoning at each stage: + +1. List Directory Contents: + +- Consider how to use the list_workdir tool to recursively list the directory contents. +- Think about how to organize and present this information clearly in the notebook. +- List potential challenges in interpreting the directory structure. +- Consider how the directory structure might inform your approach to the analysis. + +Place the output of the list_workdir tool inside tags. + +2. Load Data and Perform Descriptive Statistics: + +- Identify which data files are most relevant to resolving the task. List these files. +- Plan how to load these files efficiently in Python. +- List the specific descriptive statistics you plan to use (e.g., describe(), info(), head()). +- Consider potential issues like missing data or unexpected formats. How will you handle each?
+- Plan how to present this information clearly in the notebook. +- Write down key statistics you expect to see and how you'll interpret them. +- Consider potential data quality issues and how you'll address them. + +Execute your plan to load data and perform descriptive statistics. + +3. Develop Analysis Plan: + +- Break down each task into testable components. List these components. +- For each component, list appropriate statistical tests or visualizations. +- Consider alternative approaches for each component and justify your choices. +- Identify potential confounding factors and how to address them. +- Plan the sequence of your analysis steps, explaining the rationale for each. +- Consider how this analysis plan will be documented in the notebook. +- List potential statistical assumptions for your chosen methods and how you'll test them. +- Think about how your analysis plan addresses your original task. + +Write out your analysis plan as comments in the notebook. + +4. Execute Analysis Plan: + +- For each step in your analysis plan, list the Python functions and libraries you'll use. +- Think about how to structure your code for readability and efficiency. +- Plan how to document your code with clear comments. +- Consider how to present results clearly, using tables or visualizations where appropriate. +- Ensure that all outputs are clearly labeled and explained in the context of the task. +- Plan how you'll interpret each result in relation to the original task. +- Consider potential unexpected results and how you'll handle them. + +Execute your analysis plan, creating new cells as needed. + +5. Conclude and Submit Answer: + +- Reflect on how your results relate to the original task. +- Consider any limitations or uncertainties in your analysis. +- Plan a concise summary of your findings. +- Think about how to phrase your conclusion as clear statements. +- Ensure that the notebook contains all necessary information for another model to derive these answers. 
+- Consider any additional insights or patterns you've noticed during the analysis. +- Think about potential follow-up questions or areas for further investigation. + +""" + SUBMIT_ANSWER_HYPOTHESIS = """ [Use the submit_answer tool to submit your final answer as a single string either "True" or "False"] Remember, the final notebook should contain all necessary artifacts (plots, tables, print outputs) to solve the task provided. """ +SUBMIT_ANSWER_SINGLE = """ +[Use the submit_answer tool to submit your final answer as a single string] +Example output: +``` +submit_answer("CD94") or submit_answer("-1.23") +``` +Remember, the final notebook should contain all necessary artifacts (plots, tables, print outputs) to solve the task provided. +""" SUBMIT_ANSWER_OPEN = """ [Use the submit_answer tool to submit your final answer as a jsondictionary with keys as the question number and values as a short answer] Example output: diff --git a/src/scripts/deploy.py b/src/scripts/deploy.py index ab9e784..2d11241 100644 --- a/src/scripts/deploy.py +++ b/src/scripts/deploy.py @@ -11,7 +11,7 @@ ) from crow_client.models.app import TaskQueuesConfig -EVAL = False +HIGH = True ENV_VARS = { "OPENAI_API_KEY": os.environ["OPENAI_API_KEY"], @@ -24,7 +24,8 @@ CONTAINER_CONFIG = DockerContainerConfiguration(cpu="2", memory="4Gi") frame_paths = [ - FramePath(path="state.answer", type="text"), + FramePath(path="info.cost", type="text"), + FramePath(path="state.answer", type="markdown"), FramePath(path="state.nb_state_html", type="notebook"), ] @@ -32,7 +33,7 @@ CrowDeploymentConfig( requirements_path=Path("pyproject.toml"), path=Path("src"), - name="data-analysis-crow", + name="data-analysis-crow-high" if HIGH else "data-analysis-crow", environment="src.fhda.data_analysis_env.DataAnalysisEnv", environment_variables=ENV_VARS, agent="ldp.agent.ReActAgent", @@ -47,6 +48,15 @@ ), ] + +def rename_dockerfile(path: Path, new_name: str): + if path.exists(): + path.rename(path.parent / new_name) + 
print(f"Renamed {path} to {new_name}") + else: + print(f"Warning: {path} does not exist") + + if __name__ == "__main__": client = CrowClient( # stage=Stage.from_string(os.environ.get("CROW_ENV", ENV_VARS["STAGE"])), @@ -55,9 +65,18 @@ auth_type=AuthType.API_KEY, api_key=os.environ[f"CROW_API_KEY_{ENV_VARS['STAGE']}"], ) + + if not HIGH: + dockerfile_path = Path("src/fhda/Dockerfile.custom_deployment") + rename_dockerfile(dockerfile_path, "Dockerfile_skip.custom_deployment") + for crow in CROWS_TO_DEPLOY: try: client.create_crow(crow) print(f"Deploying {crow.name}: {client.get_build_status()}") except Exception as e: print(f"Error deploying {crow.name}: {e}") + + if not HIGH: + dockerfile_path = Path("src/fhda/Dockerfile_skip.custom_deployment") + rename_dockerfile(dockerfile_path, "Dockerfile.custom_deployment") diff --git a/src/scripts/platform_eval.py b/src/scripts/platform_eval.py index eaa351a..b23c878 100644 --- a/src/scripts/platform_eval.py +++ b/src/scripts/platform_eval.py @@ -8,16 +8,15 @@ import logging from pathlib import Path from crow_client import CrowClient -from crow_client.models import ( - AuthType, - Stage, -) +from crow_client.models import AuthType, Stage, JobResponse from aviary.utils import MultipleChoiceQuestion, eval_answer, EvalAnswerMode # Configure logging logger = logging.getLogger(__name__) +ENV = "PROD" + def setup_logging(log_level: int = logging.INFO) -> None: """Configure logging""" @@ -32,7 +31,7 @@ def setup_logging(log_level: int = logging.INFO) -> None: def create_client( api_key: Optional[str] = None, - stage: Stage = Stage.DEV, + stage: Stage = getattr(Stage, ENV), organization: str = "FutureHouse", ) -> CrowClient: """Create and return a CrowClient instance.""" @@ -40,7 +39,7 @@ def create_client( stage=stage, organization=organization, auth_type=AuthType.API_KEY, - api_key=api_key or os.environ["CROW_API_KEY"], + api_key=api_key or os.environ[f"CROW_API_KEY_{ENV}"], ) @@ -78,8 +77,10 @@ async def fetch_jobs_batch( List of 
fetched jobs """ - async def get_job_async(job_id: str) -> Dict[str, Any]: - return await asyncio.to_thread(client.get_job, job_id) + async def get_job_async(job_id: str) -> JobResponse: + return await asyncio.to_thread( + client.get_job, job_id, False, True + ) # False for history, True for verbose results = [] @@ -97,7 +98,7 @@ async def get_job_async(job_id: str) -> Dict[str, Any]: if i + batch_size < len(job_ids): await asyncio.sleep(0.5) - + results = [i.model_dump() for i in results] return results @@ -167,6 +168,7 @@ def prepare_dataframe(df: pd.DataFrame) -> pd.DataFrame: Returns: Processed DataFrame ready for evaluation """ + print(df.head()) df["answer"] = df["environment_frame"].apply(fetch_answer) df["question_keys"] = df["questions"].apply(lambda x: [i["question_id"] for i in x]) exploded = df.explode("question_keys") @@ -248,7 +250,7 @@ async def main( output_path: Union[str, Path], job_request_batch_size: int = 10, api_key: Optional[str] = None, - stage: Stage = Stage.DEV, + stage: Stage = getattr(Stage, ENV), organization: str = "FutureHouse", log_level: int = logging.INFO, ) -> Tuple[pd.DataFrame, Dict[str, Union[int, float]]]: diff --git a/src/scripts/platform_run_jobs.py b/src/scripts/platform_run_jobs.py index 2da3796..6074b3e 100644 --- a/src/scripts/platform_run_jobs.py +++ b/src/scripts/platform_run_jobs.py @@ -2,7 +2,6 @@ import json import logging import os -import uuid from typing import Any import ast import time @@ -17,19 +16,28 @@ logger = logging.getLogger(__name__) -JOB_NAME = "job-futurehouse-data-analysis-crow-dev" -CROW_STAGE = Stage.DEV -API_KEY = os.environ.get("CROW_API_KEY") -RUN_UUID = str(uuid.uuid4()) -GCS_ARTIFACT_PATH = "bixbench_data/" -HF_REPO = "futurehouse/bixbench" +ENV = "PROD" +JOB_NAME = "job-futurehouse-data-analysis-crow" +CROW_STAGE = getattr(Stage, "LOCAL") # TODO: Change to ENV +API_KEY = os.environ.get(f"CROW_API_KEY_{ENV}") +DATASET_NAME = "bb50k" +if DATASET_NAME == "bixbench": + GCS_ARTIFACT_PATH = 
"bixbench_data/" + HF_REPO = "futurehouse/bixbench" + SUBMIT_ANSWER_PROMPT = prompts.SUBMIT_ANSWER_OPEN +elif DATASET_NAME == "bb50k": + BB50K_PATH = "local/bb50k/ngs_analysis_rna_seq_dge_dataset_0_qa_metadata_questions_20250404_210834.json" + GCS_ARTIFACT_PATH = "bb50k/" + SUBMIT_ANSWER_PROMPT = prompts.SUBMIT_ANSWER_SINGLE +else: + raise ValueError(f"Dataset {DATASET_NAME} not supported") MODEL = "claude-3-7-sonnet-latest" TEMPERATURE = 1 NUM_RETRIES = 3 MAX_STEPS = 50 AVOID_IMAGES = True -NUM_ITERATIONS = 5 -RUN_NAME = "baseline-3.7-single-cell-run2" +NUM_ITERATIONS = 2 +RUN_NAME = "bb50k_v1" RESULTS_FILE = f"local/bixbench_runs/{RUN_NAME}-{time.strftime('%Y%m%d-%H%M%S')}.json" RUNTIME_PARAMS = { "model": MODEL, @@ -39,6 +47,7 @@ "avoid_images": AVOID_IMAGES, "run_name": RUN_NAME, } +MINI_MODE = False MINUTES = 60 SLEEP_TIME = 0.5 * MINUTES @@ -59,9 +68,9 @@ async def prepare_job(capsule: dict[str, Any]) -> JobRequest: {formatted_question} - {prompts.CHAIN_OF_THOUGHT_AGNOSTIC} - {prompts.SUBMIT_ANSWER_OPEN} - {prompts.GENERAL_NOTEBOOK_GUIDELINES}""" + {prompts.CHAIN_OF_THOUGHT_AGNOSTIC_PYTHON} + {SUBMIT_ANSWER_PROMPT} + {prompts.GENERAL_NOTEBOOK_GUIDELINES_PYTHON}""" if AVOID_IMAGES: task += prompts.AVOID_IMAGES @@ -83,7 +92,10 @@ async def prepare_job(capsule: dict[str, Any]) -> JobRequest: name=JOB_NAME, query=task, runtime_config=RuntimeConfig( - agent=agent, max_steps=MAX_STEPS, upload_id=capsule["data_folder"] + agent=agent, + max_steps=MAX_STEPS, + upload_id=capsule["data_folder"], + environment_config={"run_notebook_on_edit": False, "eval": True}, ), ) return job_data @@ -126,6 +138,45 @@ async def load_bixbench_data( return processed_dataset +async def load_bb50k_data( + open_question: bool = True, +) -> list[dict[str, Any]]: + """Load the bb50k question metadata from the local JSON file.""" + data = json.load( + open( + "local/bb50k/ngs_analysis_rna_seq_dge_dataset_0_qa_metadata_questions_20250404_210834.json" + ) + ) + data = data["questions"] + processed_data = [] + for i in data: +
processed_data.append( + { + "data_folder": GCS_ARTIFACT_PATH + "dataset0", + "short_id": i["qa_id"], + "categories": i["generator_class"], + "uuid": i["qa_id"], + "domain": i["domain"], + "workflow": i["workflow"], + "dataset": i["dataset"], + "source_node": i["source_node"], + "node_execution_order": i["node_execution_order"], + "answer_type": i["answer_type"], + "template": i["template"], + "questions": [ + MultipleChoiceQuestion( + question=i["question"], + options=[], + ideal_answer=str(i["answer_value"]), + shuffle_seed=MultipleChoiceQuestion.SEED_USING_QUESTION, + prompt_without_options=open_question, + ) + ], + } + ) + return processed_data + + async def submit_jobs( data: list[dict[str, Any]], ) -> list[dict[str, Any]]: @@ -189,7 +240,16 @@ async def save_results(jobs: list[dict[str, Any]], output_file: str): async def main(): - data = await load_bixbench_data() + if DATASET_NAME == "bixbench": + data = await load_bixbench_data() + elif DATASET_NAME == "bb50k": + data = await load_bb50k_data() + else: + raise ValueError(f"Dataset {DATASET_NAME} not supported") + + if MINI_MODE: + data = data[:5] + jobs = await submit_jobs(data) await save_results(jobs, RESULTS_FILE) From 833b4a7565e070bc474d0f21ccdb52cd114276fb Mon Sep 17 00:00:00 2001 From: Ludovico Mitchener Date: Fri, 11 Apr 2025 12:51:49 -0700 Subject: [PATCH 2/2] Update pyproject --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 63ca0bc..d931b2a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,7 +50,7 @@ dev = [ run_expt = 'scripts.configurable:_run_expt' [tool.setuptools] -package-dir = {"" = "fhda"} +package-dir = {"" = "src"} [tool.setuptools.packages.find] -where = ["fhda"] +where = ["src"]