17
17
from paperqa .docs import Docs
18
18
from paperqa .settings import Settings
19
19
from paperqa .sources .clinical_trials import add_clinical_trials_to_docs
20
- from paperqa .types import DocDetails , PQASession
20
+ from paperqa .types import Context , DocDetails , PQASession
21
21
22
22
from .search import get_directory_index
23
23
@@ -35,18 +35,11 @@ def make_status(
35
35
36
36
37
37
def default_status (state : "EnvironmentState" ) -> str :
38
+ relevant_contexts = state .get_relevant_contexts ()
38
39
return make_status (
39
40
total_paper_count = len (state .docs .docs ),
40
- relevant_paper_count = len (
41
- {
42
- c .text .doc .dockey
43
- for c in state .session .contexts
44
- if c .score > state .RELEVANT_SCORE_CUTOFF
45
- }
46
- ),
47
- evidence_count = len (
48
- [c for c in state .session .contexts if c .score > state .RELEVANT_SCORE_CUTOFF ]
49
- ),
41
+ relevant_paper_count = len ({c .text .doc .dockey for c in relevant_contexts }),
42
+ evidence_count = len (relevant_contexts ),
50
43
cost = state .session .cost ,
51
44
)
52
45
@@ -80,6 +73,11 @@ def status(self) -> str:
80
73
return self .status_fn (cast (Self , self ))
81
74
return default_status (self )
82
75
76
+ def get_relevant_contexts (self ) -> list [Context ]:
77
+ return [
78
+ c for c in self .session .contexts if c .score > self .RELEVANT_SCORE_CUTOFF
79
+ ]
80
+
83
81
def record_action (self , action : ToolRequestMessage ) -> None :
84
82
self .session .add_tokens (action )
85
83
self .session .tool_history .append ([tc .function .name for tc in action .tool_calls ])
@@ -227,7 +225,8 @@ async def gather_evidence(self, question: str, state: EnvironmentState) -> str:
227
225
228
226
logger .info (f"{ self .TOOL_FN_NAME } starting for question { question !r} ." )
229
227
original_question = state .session .question
230
- l1_all = l1_relevant = l0 = len (state .session .contexts )
228
+ l1 = l0 = len (state .session .contexts )
229
+ l1_relevant = l0_relevant = len (state .get_relevant_contexts ())
231
230
232
231
try :
233
232
# Swap out the question with the more specific question
@@ -245,14 +244,8 @@ async def gather_evidence(self, question: str, state: EnvironmentState) -> str:
245
244
f"{ self .TOOL_FN_NAME } _aget_evidence"
246
245
),
247
246
)
248
- l1_all = len (state .session .contexts )
249
- l1_relevant = len (
250
- [
251
- c
252
- for c in state .session .contexts
253
- if c .score > state .RELEVANT_SCORE_CUTOFF
254
- ]
255
- )
247
+ l1 = len (state .session .contexts )
248
+ l1_relevant = len (state .get_relevant_contexts ())
256
249
finally :
257
250
state .session .question = original_question
258
251
@@ -284,7 +277,7 @@ async def gather_evidence(self, question: str, state: EnvironmentState) -> str:
284
277
)
285
278
286
279
return (
287
- f"Added { l1_all - l0 } pieces of evidence, { l1_relevant - l0 } of which were"
280
+ f"Added { l1 - l0 } pieces of evidence, { l1_relevant - l0_relevant } of which were"
288
281
f" relevant.{ best_evidence } \n \n " + status
289
282
)
290
283
@@ -412,10 +405,10 @@ async def complete(
412
405
413
406
logger .info (
414
407
f"Completing '{ state .session .question } ' as"
415
- f" '{ 'a success ' if has_successful_answer else 'unsure' } '."
408
+ f" '{ 'certain ' if has_successful_answer else 'unsure' } '."
416
409
)
417
410
# Return answer and status to simplify postprocessing of tool response
418
- return f"{ 'Success ' if has_successful_answer else 'Unsure' } | { state .status } "
411
+ return f"{ 'Certain ' if has_successful_answer else 'Unsure' } | { state .status } "
419
412
420
413
421
414
class ClinicalTrialsSearch (NamedTool ):
0 commit comments