1212)
1313
1414from lmi .cost_tracker import GLOBAL_COST_TRACKER , enable_cost_tracking
15+ from futurehouse_client .models import TaskRequest , AuthType
16+ from futurehouse_client import FutureHouseClient
1517
1618from .notebook_env import NBEnvironment
1719from .utils import NBLanguage , MultipleChoiceQuestion , nb_to_html
@@ -36,7 +38,8 @@ def __init__(
3638 eval_mode : EvalAnswerMode | None = None ,
3739 metadata : dict [str , Any ] | None = None , # used for NBEvalExpt
3840 mcqs : list [MultipleChoiceQuestion ] | None = None ,
39- exclude_tools : list [str ] | None = None ,
41+ # Exclude list_workdir and query_literature tools by default
42+ exclude_tools : list [str ] | None = ["list_workdir" , "query_literature" ],
4043 ** kwargs ,
4144 ):
4245 super ().__init__ (** kwargs )
@@ -55,6 +58,9 @@ def __init__(
5558 async def reset (self ) -> tuple [Messages , list [Tool ]]:
5659 # Discard base class's init_obs and make our own with the problem statement
5760 _ , tools = await super ().reset ()
61+
62+ tools .append (Tool .from_function (self .query_literature ))
63+
5864 if self .exclude_tools :
5965 tools = [
6066 tool
@@ -83,6 +89,43 @@ async def reset(self) -> tuple[Messages, list[Tool]]:
8389
8490 return init_obs , tools
8591
92+ # DA Specific Tools
93+
94+ async def query_literature (self , query : str ) -> str :
95+ """Query the scientific literature. Produces a succinct answer citing the scientific literature.
96+
97+ Args:
98+ query: The scientific question to answer
99+ """
100+ logger .info ("Running PQA query" )
101+ client = FutureHouseClient (
102+ stage = cfg .CROW_STAGE ,
103+ auth_type = AuthType .API_KEY ,
104+ api_key = cfg .PLATFORM_API_KEY ,
105+ )
106+
107+ job_data = TaskRequest (
108+ name = "job-futurehouse-paperqa2" ,
109+ query = query ,
110+ )
111+ job_id = client .create_task (job_data )
112+ status = "in progress"
113+ while status in ["in progress" , "queued" ]:
114+ logger .info (
115+ "Waiting for pqa task to complete... checking again in 5 seconds"
116+ )
117+ time .sleep (5 )
118+ status = client .get_task (job_id ).status
119+
120+ if status == "failed" :
121+ raise Exception ("PaperQA platform job failed" )
122+
123+ job_result = client .get_task (job_id , verbose = True )
124+ answer = job_result .environment_frame ["state" ]["state" ]["response" ]["answer" ][
125+ "answer"
126+ ]
127+ return answer
128+
86129 async def submit_answer (self , answer : str ) -> str : # type: ignore[override]
87130 """Submit an answer to the problem.
88131
0 commit comments