From 8614355140c2c13a740d2ee8ee1d035aa3cd2575 Mon Sep 17 00:00:00 2001 From: Ludovico Mitchener Date: Sun, 4 May 2025 21:27:14 -0700 Subject: [PATCH] Add pqa tool --- src/fhda/config.py | 6 +++++ src/fhda/data_analysis_env.py | 45 ++++++++++++++++++++++++++++++++++- src/scripts/deploy.py | 7 ++++-- 3 files changed, 55 insertions(+), 3 deletions(-) diff --git a/src/fhda/config.py b/src/fhda/config.py index 2ab126d..a183223 100644 --- a/src/fhda/config.py +++ b/src/fhda/config.py @@ -1,5 +1,6 @@ import os from pathlib import Path +from futurehouse_client.models import Stage USE_DOCKER = bool(os.getenv("USE_DOCKER", "true").lower() == "true") NB_ENVIRONMENT_DOCKER_IMAGE = os.getenv( @@ -20,3 +21,8 @@ DATA_STORAGE_PATH = Path("/storage") VALID_FROM_TASK_KWARGS = ["run_notebook_on_edit", "exclude_tools"] + +# FutureHosue client config +ENVIRONMENT = os.getenv("ENVIRONMENT", "prod") +CROW_STAGE = getattr(Stage, ENVIRONMENT.upper(), Stage.PROD) +PLATFORM_API_KEY = os.getenv("CROW_API_KEY") diff --git a/src/fhda/data_analysis_env.py b/src/fhda/data_analysis_env.py index 1234378..09ea07f 100644 --- a/src/fhda/data_analysis_env.py +++ b/src/fhda/data_analysis_env.py @@ -12,6 +12,8 @@ ) from lmi.cost_tracker import GLOBAL_COST_TRACKER, enable_cost_tracking +from futurehouse_client.models import TaskRequest, AuthType +from futurehouse_client import FutureHouseClient from .notebook_env import NBEnvironment from .utils import NBLanguage, MultipleChoiceQuestion, nb_to_html @@ -36,7 +38,8 @@ def __init__( eval_mode: EvalAnswerMode | None = None, metadata: dict[str, Any] | None = None, # used for NBEvalExpt mcqs: list[MultipleChoiceQuestion] | None = None, - exclude_tools: list[str] | None = None, + # Exclude list_workdir and query_literature tools by default + exclude_tools: list[str] | None = ["list_workdir", "query_literature"], **kwargs, ): super().__init__(**kwargs) @@ -55,6 +58,9 @@ def __init__( async def reset(self) -> tuple[Messages, list[Tool]]: # Discard base class's init_obs and make our own with the problem statement _, tools = await super().reset() + + tools.append(Tool.from_function(self.query_literature)) + if self.exclude_tools: tools = [ tool @@ -83,6 +89,43 @@ async def reset(self) -> tuple[Messages, list[Tool]]: return init_obs, tools + # DA Specific Tools + + async def query_literature(self, query: str) -> str: + """Query the scientific literature. Produces a succinct answer citing the scientific literature. + + Args: + query: The scientific question to answer + """ + logger.info("Running PQA query") + client = FutureHouseClient( + stage=cfg.CROW_STAGE, + auth_type=AuthType.API_KEY, + api_key=cfg.PLATFORM_API_KEY, + ) + + job_data = TaskRequest( + name="job-futurehouse-paperqa2", + query=query, + ) + job_id = client.create_task(job_data) + status = "in progress" + while status in ["in progress", "queued"]: + logger.info( + "Waiting for pqa task to complete... checking again in 5 seconds" + ) + time.sleep(5) + status = client.get_task(job_id).status + + if status == "failed": + raise Exception("PaperQA platform job failed") + + job_result = client.get_task(job_id, verbose=True) + answer = job_result.environment_frame["state"]["state"]["response"]["answer"][ + "answer" + ] + return answer + async def submit_answer(self, answer: str) -> str: # type: ignore[override] """Submit an answer to the problem. diff --git a/src/scripts/deploy.py b/src/scripts/deploy.py index 54657d5..f26aeb4 100644 --- a/src/scripts/deploy.py +++ b/src/scripts/deploy.py @@ -12,13 +12,16 @@ ) from futurehouse_client.models.app import TaskQueuesConfig -HIGH = True +HIGH = False +ENVIRONMENT = "DEV" ENV_VARS = { "OPENAI_API_KEY": os.environ["OPENAI_API_KEY"], "ANTHROPIC_API_KEY": os.environ["ANTHROPIC_API_KEY"], "USE_DOCKER": "false", - "STAGE": "PROD", + "STAGE": ENVIRONMENT, + "ENVIRONMENT": ENVIRONMENT, + "API_KEY": os.environ[f"CROW_API_KEY_{ENVIRONMENT}"], } CONTAINER_CONFIG = DockerContainerConfiguration(cpu="8", memory="16Gi")