Skip to content

Commit 097ccd2

Browse files
authored
Fixing gen_answer failover leaving raw_answer blank (#1077)
1 parent 35e85a8 commit 097ccd2

File tree

4 files changed

+19
-18
lines changed

4 files changed

+19
-18
lines changed

src/paperqa/agents/tools.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -309,9 +309,6 @@ async def gen_answer(self, state: EnvironmentState) -> str:
309309
Args:
310310
state: Current state.
311311
"""
312-
if not state.docs.docs:
313-
raise EmptyDocsError("Not generating an answer due to having no papers.")
314-
315312
logger.info(f"Generating answer for '{state.session.question}'.")
316313

317314
if f"{self.TOOL_FN_NAME}_initialized" in self.settings.agent.callbacks:

src/paperqa/docs.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
NumpyVectorStore,
2727
VectorStore,
2828
)
29-
from paperqa.prompts import CANNOT_ANSWER_PHRASE
29+
from paperqa.prompts import CANNOT_ANSWER_PHRASE, EMPTY_CONTEXTS
3030
from paperqa.readers import read_doc
3131
from paperqa.settings import MaybeSettings, get_settings
3232
from paperqa.types import Doc, DocDetails, DocKey, PQASession, Text
@@ -742,7 +742,7 @@ async def aquery(
742742
contexts = session.contexts
743743
if answer_config.get_evidence_if_no_contexts and not contexts:
744744
session = await self.aget_evidence(
745-
query=session,
745+
session,
746746
callbacks=callbacks,
747747
settings=settings,
748748
embedding_model=embedding_model,
@@ -774,9 +774,10 @@ async def aquery(
774774
pre_str=pre_str,
775775
)
776776

777-
if len(context_str.strip()) < 10: # noqa: PLR2004
777+
if len(context_str.strip()) <= EMPTY_CONTEXTS:
778778
answer_text = (
779-
f"{CANNOT_ANSWER_PHRASE} this question due to insufficient information."
779+
f"{CANNOT_ANSWER_PHRASE} this question due to"
780+
f" {'having no papers' if not self.docs else 'insufficient information.'}."
780781
)
781782
answer_reasoning = None
782783
else:

src/paperqa/prompts.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,5 +146,6 @@
146146
)
147147

148148
CONTEXT_OUTER_PROMPT = "{context_str}\n\nValid Keys: {valid_keys}"
149+
EMPTY_CONTEXTS = len(CONTEXT_OUTER_PROMPT.format(context_str="", valid_keys="").strip())
149150
CONTEXT_INNER_PROMPT_NOT_DETAILED = "{name}: {text}"
150151
CONTEXT_INNER_PROMPT = f"{CONTEXT_INNER_PROMPT_NOT_DETAILED}\nFrom {{citation}}"

tests/test_agents.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from functools import wraps
1313
from pathlib import Path
1414
from typing import cast
15-
from unittest.mock import AsyncMock, MagicMock, patch
15+
from unittest.mock import AsyncMock, patch
1616
from uuid import uuid4
1717

1818
import ldp.agent
@@ -21,6 +21,7 @@
2121
Environment,
2222
Tool,
2323
ToolRequestMessage,
24+
ToolResponseMessage,
2425
ToolsAdapter,
2526
ToolSelector,
2627
)
@@ -470,26 +471,27 @@ async def test_timeout(agent_test_settings: Settings, agent_type: str | type) ->
470471
agent_test_settings.agent.timeout = 0.05 # Give time for Environment.reset()
471472
agent_test_settings.llm = "gpt-4o-mini"
472473
agent_test_settings.agent.tool_names = {"gen_answer", "complete"}
473-
docs = Docs()
474+
orig_exec_tool_calls = PaperQAEnvironment.exec_tool_calls
475+
tool_responses: list[list[ToolResponseMessage]] = []
474476

475-
async def custom_aget_evidence(*_, **kwargs) -> PQASession: # noqa: RUF029
476-
return kwargs["query"]
477+
async def spy_exec_tool_calls(*args, **kwargs) -> list[ToolResponseMessage]:
478+
responses = await orig_exec_tool_calls(*args, **kwargs)
479+
tool_responses.append(responses)
480+
return responses
477481

478-
with (
479-
patch.object(docs, "docs", {"stub_key": MagicMock(spec_set=Doc)}),
480-
patch.multiple(
481-
Docs, clear_docs=MagicMock(), aget_evidence=custom_aget_evidence
482-
),
483-
):
482+
with patch.object(PaperQAEnvironment, "exec_tool_calls", spy_exec_tool_calls):
484483
response = await agent_query(
485484
query="Are COVID-19 vaccines effective?",
486485
settings=agent_test_settings,
487-
docs=docs,
488486
agent_type=agent_type,
489487
)
490488
# Ensure that GenerateAnswerTool was called in truncation's failover
491489
assert response.status == AgentStatus.TRUNCATED, "Agent did not timeout"
492490
assert CANNOT_ANSWER_PHRASE in response.session.answer
491+
(last_response,) = tool_responses[-1]
492+
assert (
493+
"no papers" in last_response.content
494+
), "Expecting agent to been shown specifics on the failure"
493495

494496

495497
@pytest.mark.flaky(reruns=5, only_rerun=["AssertionError"])

0 commit comments

Comments
 (0)