Skip to content

Commit 7bb570c

Browse files
authored
Fixing gather_evidence and complete response messages (#812)
1 parent cf377fc commit 7bb570c

File tree

3 files changed: 48 additions and 32 deletions.

paperqa/agents/env.py

Lines changed: 6 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -160,6 +160,7 @@ def make_clinical_trial_status(
160160

161161

162162
def clinical_trial_status(state: "EnvironmentState") -> str:
163+
relevant_contexts = state.get_relevant_contexts()
163164
return make_clinical_trial_status(
164165
total_paper_count=len(
165166
{
@@ -172,9 +173,8 @@ def clinical_trial_status(state: "EnvironmentState") -> str:
172173
relevant_paper_count=len(
173174
{
174175
c.text.doc.dockey
175-
for c in state.session.contexts
176-
if c.score > state.RELEVANT_SCORE_CUTOFF
177-
and CLINICAL_TRIALS_BASE
176+
for c in relevant_contexts
177+
if CLINICAL_TRIALS_BASE
178178
not in getattr(c.text.doc, "other", {}).get("client_source", [])
179179
}
180180
),
@@ -189,15 +189,12 @@ def clinical_trial_status(state: "EnvironmentState") -> str:
189189
relevant_clinical_trials=len(
190190
{
191191
c.text.doc.dockey
192-
for c in state.session.contexts
193-
if c.score > state.RELEVANT_SCORE_CUTOFF
194-
and CLINICAL_TRIALS_BASE
192+
for c in relevant_contexts
193+
if CLINICAL_TRIALS_BASE
195194
in getattr(c.text.doc, "other", {}).get("client_source", [])
196195
}
197196
),
198-
evidence_count=len(
199-
[c for c in state.session.contexts if c.score > state.RELEVANT_SCORE_CUTOFF]
200-
),
197+
evidence_count=len(relevant_contexts),
201198
cost=state.session.cost,
202199
)
203200

paperqa/agents/tools.py

Lines changed: 16 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,7 @@
1717
from paperqa.docs import Docs
1818
from paperqa.settings import Settings
1919
from paperqa.sources.clinical_trials import add_clinical_trials_to_docs
20-
from paperqa.types import DocDetails, PQASession
20+
from paperqa.types import Context, DocDetails, PQASession
2121

2222
from .search import get_directory_index
2323

@@ -35,18 +35,11 @@ def make_status(
3535

3636

3737
def default_status(state: "EnvironmentState") -> str:
38+
relevant_contexts = state.get_relevant_contexts()
3839
return make_status(
3940
total_paper_count=len(state.docs.docs),
40-
relevant_paper_count=len(
41-
{
42-
c.text.doc.dockey
43-
for c in state.session.contexts
44-
if c.score > state.RELEVANT_SCORE_CUTOFF
45-
}
46-
),
47-
evidence_count=len(
48-
[c for c in state.session.contexts if c.score > state.RELEVANT_SCORE_CUTOFF]
49-
),
41+
relevant_paper_count=len({c.text.doc.dockey for c in relevant_contexts}),
42+
evidence_count=len(relevant_contexts),
5043
cost=state.session.cost,
5144
)
5245

@@ -80,6 +73,11 @@ def status(self) -> str:
8073
return self.status_fn(cast(Self, self))
8174
return default_status(self)
8275

76+
def get_relevant_contexts(self) -> list[Context]:
77+
return [
78+
c for c in self.session.contexts if c.score > self.RELEVANT_SCORE_CUTOFF
79+
]
80+
8381
def record_action(self, action: ToolRequestMessage) -> None:
8482
self.session.add_tokens(action)
8583
self.session.tool_history.append([tc.function.name for tc in action.tool_calls])
@@ -227,7 +225,8 @@ async def gather_evidence(self, question: str, state: EnvironmentState) -> str:
227225

228226
logger.info(f"{self.TOOL_FN_NAME} starting for question {question!r}.")
229227
original_question = state.session.question
230-
l1_all = l1_relevant = l0 = len(state.session.contexts)
228+
l1 = l0 = len(state.session.contexts)
229+
l1_relevant = l0_relevant = len(state.get_relevant_contexts())
231230

232231
try:
233232
# Swap out the question with the more specific question
@@ -245,14 +244,8 @@ async def gather_evidence(self, question: str, state: EnvironmentState) -> str:
245244
f"{self.TOOL_FN_NAME}_aget_evidence"
246245
),
247246
)
248-
l1_all = len(state.session.contexts)
249-
l1_relevant = len(
250-
[
251-
c
252-
for c in state.session.contexts
253-
if c.score > state.RELEVANT_SCORE_CUTOFF
254-
]
255-
)
247+
l1 = len(state.session.contexts)
248+
l1_relevant = len(state.get_relevant_contexts())
256249
finally:
257250
state.session.question = original_question
258251

@@ -284,7 +277,7 @@ async def gather_evidence(self, question: str, state: EnvironmentState) -> str:
284277
)
285278

286279
return (
287-
f"Added {l1_all - l0} pieces of evidence, {l1_relevant - l0} of which were"
280+
f"Added {l1 - l0} pieces of evidence, {l1_relevant - l0_relevant} of which were"
288281
f" relevant.{best_evidence}\n\n" + status
289282
)
290283

@@ -412,10 +405,10 @@ async def complete(
412405

413406
logger.info(
414407
f"Completing '{state.session.question}' as"
415-
f" '{'a success' if has_successful_answer else 'unsure'}'."
408+
f" '{'certain' if has_successful_answer else 'unsure'}'."
416409
)
417410
# Return answer and status to simplify postprocessing of tool response
418-
return f"{'Success' if has_successful_answer else 'Unsure'} | {state.status}"
411+
return f"{'Certain' if has_successful_answer else 'Unsure'} | {state.status}"
419412

420413

421414
class ClinicalTrialsSearch(NamedTool):

tests/test_agents.py

Lines changed: 26 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -580,6 +580,17 @@ def new_status(state: EnvironmentState) -> str:
580580
gather_evidence_initialized_callback.assert_awaited_once_with(env_state)
581581
gather_evidence_completed_callback.assert_awaited_once_with(env_state)
582582

583+
split = re.split(
584+
r"(\d+) pieces of evidence, (\d+) of which were relevant",
585+
response,
586+
maxsplit=1,
587+
)
588+
assert len(split) == 4, "Unexpected response shape"
589+
total_added_1, relevant_added_1 = int(split[1]), int(split[2])
590+
assert all(
591+
x >= 0 for x in (total_added_1, relevant_added_1)
592+
), "Expected non-negative counts"
593+
assert len(env_state.get_relevant_contexts()) == relevant_added_1
583594
# ensure 1 piece of top evidence is returned
584595
assert "\n1." in response, "gather_evidence did not return any results"
585596
assert (
@@ -591,6 +602,21 @@ def new_status(state: EnvironmentState) -> str:
591602
response = await gather_evidence_tool.gather_evidence(
592603
session.question, state=env_state
593604
)
605+
606+
split = re.split(
607+
r"(\d+) pieces of evidence, (\d+) of which were relevant",
608+
response,
609+
maxsplit=1,
610+
)
611+
assert len(split) == 4, "Unexpected response shape"
612+
total_added_2, relevant_added_2 = int(split[1]), int(split[2])
613+
assert all(
614+
x >= 0 for x in (total_added_2, relevant_added_2)
615+
), "Expected non-negative counts"
616+
assert (
617+
len(env_state.get_relevant_contexts())
618+
== relevant_added_1 + relevant_added_2
619+
)
594620
# ensure both evidences are returned
595621
assert "\n1." in response, "gather_evidence did not return any results"
596622
assert "\n2." in response, "gather_evidence should return 2 contexts"

0 commit comments

Comments
 (0)