Commit 608850e

Support grouping contexts by question in the final context (#1032)
1 parent 2ca6b8f commit 608850e

File tree

4 files changed: +120 −16 lines


README.md

Lines changed: 1 addition & 0 deletions
@@ -885,6 +885,7 @@ will return much faster than the first query and we'll be certain the authors ma
 | `answer.max_concurrent_requests` | `4` | Max concurrent requests to LLMs. |
 | `answer.answer_filter_extra_background` | `False` | Whether to cite background info from model. |
 | `answer.get_evidence_if_no_contexts` | `True` | Allow lazy evidence gathering. |
+| `answer.group_contexts_by_question` | `False` | Groups the final contexts by the underlying `gather_evidence` question in the final context prompt. |
 | `answer.evidence_relevance_score_cutoff` | `1` | Cutoff evidence relevance score to include in the answer context (inclusive) |
 | `parsing.chunk_size` | `5000` | Characters per chunk (0 for no chunking). |
 | `parsing.page_size_limit` | `1,280,000` | Character limit per page. |
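For example, a minimal sketch of turning the new flag on (this assumes the `Settings` and `ask` entry points documented elsewhere in the README; the question string is a placeholder):

from paperqa import Settings, ask

# Enable question-grouped contexts in the final context prompt;
# the flag defaults to False.
settings = Settings(answer={"group_contexts_by_question": True})

answer = ask("Is XAI usable in chemistry?", settings=settings)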

src/paperqa/docs.py

Lines changed: 43 additions & 15 deletions
@@ -7,6 +7,7 @@
 import tempfile
 import urllib.request
 import warnings
+from collections import defaultdict
 from collections.abc import Callable, Sequence
 from datetime import datetime
 from io import BytesIO
@@ -29,7 +30,7 @@
 from paperqa.prompts import CANNOT_ANSWER_PHRASE
 from paperqa.readers import read_doc
 from paperqa.settings import MaybeSettings, get_settings
-from paperqa.types import Doc, DocDetails, DocKey, PQASession, Text
+from paperqa.types import Context, Doc, DocDetails, DocKey, PQASession, Text
 from paperqa.utils import (
     citation_to_docname,
     get_loop,
@@ -785,26 +786,53 @@ async def aquery(  # noqa: PLR0912
         # Only keep "\nFrom {citation}" if we are showing detailed citations
         context_inner_prompt = context_inner_prompt.replace("\nFrom {citation}", "")

-        inner_context_strs = [
-            context_inner_prompt.format(
-                name=c.id,
-                text=c.context,
-                citation=c.text.doc.formatted_citation,
-                **(c.model_extra or {}),
-            )
-            for c in filtered_contexts
-        ]
+        context_str_body = ""
+        if answer_config.group_contexts_by_question:
+            contexts_by_question: dict[str, list[Context]] = defaultdict(list)
+            for c in filtered_contexts:
+                # Fallback to the main session question if not available.
+                # question attribute is optional, so if a user
+                # sets contexts externally, it may not have a question.
+                question = getattr(c, "question", session.question)
+                contexts_by_question[question].append(c)
+
+            context_sections = []
+            for question, contexts_in_group in contexts_by_question.items():
+                inner_strs = [
+                    context_inner_prompt.format(
+                        name=c.id,
+                        text=c.context,
+                        citation=c.text.doc.formatted_citation,
+                        **(c.model_extra or {}),
+                    )
+                    for c in contexts_in_group
+                ]
+                # Create a section with a question heading
+                section_header = f'Contexts related to the question: "{question}"'
+                section = f"{section_header}\n\n" + "\n\n".join(inner_strs)
+                context_sections.append(section)
+            context_str_body = "\n\n----\n\n".join(context_sections)
+        else:
+            inner_context_strs = [
+                context_inner_prompt.format(
+                    name=c.id,
+                    text=c.context,
+                    citation=c.text.doc.formatted_citation,
+                    **(c.model_extra or {}),
+                )
+                for c in filtered_contexts
+            ]
+            context_str_body = "\n\n".join(inner_context_strs)
+
         if pre_str:
-            inner_context_strs += (
-                [f"Extra background information: {pre_str}"] if pre_str else []
-            )
+            context_str_body += f"\n\nExtra background information: {pre_str}"

         context_str = prompt_config.context_outer.format(
-            context_str="\n\n".join(inner_context_strs),
+            context_str=context_str_body,
             valid_keys=", ".join([c.id for c in filtered_contexts]),
         )

-        if len(context_str) < 10:  # noqa: PLR2004
+        if len(context_str_body.strip()) < 10:  # noqa: PLR2004
             answer_text = (
                 f"{CANNOT_ANSWER_PHRASE} this question due to insufficient information."
             )
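To make the grouped prompt shape concrete, here is a self-contained sketch of the same string assembly over plain tuples; the `Context` objects, prompt templates, and `pre_str` handling are stripped out, so only the grouping and joining logic from the hunk above is reproduced:

from collections import defaultdict

# (question, formatted inner context) stand-ins for filtered_contexts.
contexts = [
    ("Is XAI usable in chemistry?", "Explanation about XAI and molecules."),
    ("Is XAI usable in chemistry?", "Details on how drug discovery leverages AI."),
    ("What is organic chemistry?", "General facts about organic chemistry."),
]

contexts_by_question: dict[str, list[str]] = defaultdict(list)
for question, inner in contexts:
    contexts_by_question[question].append(inner)

sections = [
    f'Contexts related to the question: "{question}"\n\n' + "\n\n".join(inners)
    for question, inners in contexts_by_question.items()
]
# Sections are separated by the same "----" rule the diff inserts.
context_str_body = "\n\n----\n\n".join(sections)
print(context_str_body)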

src/paperqa/settings.py

Lines changed: 4 additions & 0 deletions
@@ -120,6 +120,10 @@ class AnswerSettings(BaseModel):
             " called before evidence was gathered."
         ),
     )
+    group_contexts_by_question: bool = Field(
+        default=False,
+        description="Whether to group contexts by question when generating answers.",
+    )

     @model_validator(mode="after")
     def _deprecated_field(self) -> Self:
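As a quick sanity check of the new field (a sketch; both forms below should be equivalent, since `Settings` nests an `AnswerSettings` under its `answer` attribute):

from paperqa.settings import AnswerSettings, Settings

# The field defaults to False and can be set directly on AnswerSettings.
assert AnswerSettings().group_contexts_by_question is False
assert AnswerSettings(group_contexts_by_question=True).group_contexts_by_question

# Nested-dict form, as used in the test below.
settings = Settings(answer={"group_contexts_by_question": True})
assert settings.answer.group_contexts_by_question is True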

tests/test_paperqa.py

Lines changed: 72 additions & 1 deletion
@@ -52,7 +52,7 @@
 from paperqa.prompts import CANNOT_ANSWER_PHRASE
 from paperqa.prompts import qa_prompt as default_qa_prompt
 from paperqa.readers import PDFParserFn, read_doc
-from paperqa.types import ChunkMetadata
+from paperqa.types import ChunkMetadata, Context
 from paperqa.utils import (
     clean_possessives,
     encode_id,
@@ -566,6 +566,77 @@ async def test_query(docs_fixture) -> None:
     await docs_fixture.aquery("Is XAI usable in chemistry?", settings=settings)


+@pytest.mark.asyncio
+async def test_aquery_groups_contexts_by_question(docs_fixture) -> None:
+
+    session = PQASession(question="What is the relationship between chemistry and AI?")
+
+    doc = Doc(docname="test_doc", citation="Test Doc, 2025", dockey="key1")
+    text1 = Text(text="XAI is useful for molecules.", name="t1", doc=doc)
+    text2 = Text(text="Drug discovery uses AI.", name="t2", doc=doc)
+    text3 = Text(text="Organic chemistry is a field.", name="t3", doc=doc)
+
+    session.contexts = [
+        Context(
+            text=text1,
+            context="Explanation about XAI and molecules.",
+            score=6,
+            question="Is XAI usable in chemistry?",
+        ),
+        Context(
+            text=text2,
+            context="Details on how drug discovery leverages AI.",
+            score=5,
+            question="Is XAI usable in chemistry?",
+        ),
+        Context(
+            text=text3,
+            context="General facts about organic chemistry.",
+            score=5,
+            question="What is organic chemistry?",
+        ),
+    ]
+
+    settings = Settings(
+        prompts={"answer_iteration_prompt": None},
+        answer={"group_contexts_by_question": True},
+    )
+
+    result = await docs_fixture.aquery(session, settings=settings)
+
+    final_context_str = result.context
+
+    assert (
+        'Contexts related to the question: "Is XAI usable in chemistry?"'
+        in final_context_str
+    )
+
+    assert (
+        'Contexts related to the question: "What is organic chemistry?"'
+        in final_context_str
+    )
+
+    assert "Explanation about XAI and molecules." in final_context_str
+    assert "Details on how drug discovery leverages AI." in final_context_str
+    assert "General facts about organic chemistry." in final_context_str
+
+    assert "\n\n----\n\n" in final_context_str
+    q1_header_pos = final_context_str.find(
+        'Contexts related to the question: "Is XAI usable in chemistry?"'
+    )
+    q2_header_pos = final_context_str.find(
+        'Contexts related to the question: "What is organic chemistry?"'
+    )
+    context1_pos = final_context_str.find("Explanation about XAI and molecules.")
+    context3_pos = final_context_str.find("General facts about organic chemistry.")
+
+    assert (
+        0 == q1_header_pos < context1_pos
+    ), "Expected q1 header to be first, and the context to follow."
+    assert q1_header_pos < q2_header_pos
+    assert q2_header_pos < context3_pos
+
+
 @pytest.mark.asyncio
 async def test_query_with_iteration(docs_fixture) -> None:
     # we store these results to check that the prompts are OK
