Skip to content

Commit cbe5d9d

Browse files
mskarlin and Copilot authored
Move context str generation into configurable function (#1039)
Co-authored-by: Copilot <[email protected]>
1 parent a56d10e commit cbe5d9d

File tree

4 files changed

+151
-72
lines changed

4 files changed

+151
-72
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -874,6 +874,7 @@ will return much faster than the first query and we'll be certain the authors ma
874874
| `batch_size` | `1` | Batch size for calling LLMs. |
875875
| `texts_index_mmr_lambda` | `1.0` | Lambda for MMR in text index. |
876876
| `verbosity` | `0` | Integer verbosity level for logging (0-3). 3 = all LLM/Embeddings calls logged. |
877+
| `custom_context_serializer` | `None` | Custom async function (see typing for signature) to override the default answer context serialization. |
877878
| `answer.evidence_k` | `10` | Number of evidence pieces to retrieve. |
878879
| `answer.evidence_detailed_citations` | `True` | Include detailed citations in summaries. |
879880
| `answer.evidence_retrieval` | `True` | Use retrieval vs processing all docs. |

src/paperqa/docs.py

Lines changed: 7 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import tempfile
88
import urllib.request
99
import warnings
10-
from collections import defaultdict
1110
from collections.abc import Callable, Sequence
1211
from datetime import datetime
1312
from io import BytesIO
@@ -30,7 +29,7 @@
3029
from paperqa.prompts import CANNOT_ANSWER_PHRASE
3130
from paperqa.readers import read_doc
3231
from paperqa.settings import MaybeSettings, get_settings
33-
from paperqa.types import Context, Doc, DocDetails, DocKey, PQASession, Text
32+
from paperqa.types import Doc, DocDetails, DocKey, PQASession, Text
3433
from paperqa.utils import (
3534
citation_to_docname,
3635
get_loop,
@@ -710,7 +709,7 @@ def query(
710709
)
711710
)
712711

713-
async def aquery( # noqa: PLR0912
712+
async def aquery(
714713
self,
715714
query: PQASession | str,
716715
settings: MaybeSettings = None,
@@ -765,75 +764,13 @@ async def aquery( # noqa: PLR0912
765764
session.add_tokens(pre)
766765
pre_str = pre.text
767766

768-
# sort by first score, then name
769-
filtered_contexts = sorted(
770-
contexts,
771-
key=lambda x: (-x.score, x.text.name),
772-
)[: answer_config.answer_max_sources]
773-
# remove any contexts with a score below the cutoff
774-
filtered_contexts = [
775-
c
776-
for c in filtered_contexts
777-
if c.score >= answer_config.evidence_relevance_score_cutoff
778-
]
779-
780-
# shim deprecated flag
781-
# TODO: remove in v6
782-
context_inner_prompt = prompt_config.context_inner
783-
if (
784-
not answer_config.evidence_detailed_citations
785-
and "\nFrom {citation}" in context_inner_prompt
786-
):
787-
# Only keep "\nFrom {citation}" if we are showing detailed citations
788-
context_inner_prompt = context_inner_prompt.replace("\nFrom {citation}", "")
789-
790-
context_str_body = ""
791-
if answer_config.group_contexts_by_question:
792-
contexts_by_question: dict[str, list[Context]] = defaultdict(list)
793-
for c in filtered_contexts:
794-
# Fallback to the main session question if not available.
795-
# question attribute is optional, so if a user
796-
# sets contexts externally, it may not have a question.
797-
question = getattr(c, "question", session.question)
798-
contexts_by_question[question].append(c)
799-
800-
context_sections = []
801-
for question, contexts_in_group in contexts_by_question.items():
802-
inner_strs = [
803-
context_inner_prompt.format(
804-
name=c.id,
805-
text=c.context,
806-
citation=c.text.doc.formatted_citation,
807-
**(c.model_extra or {}),
808-
)
809-
for c in contexts_in_group
810-
]
811-
# Create a section with a question heading
812-
section_header = f'Contexts related to the question: "{question}"'
813-
section = f"{section_header}\n\n" + "\n\n".join(inner_strs)
814-
context_sections.append(section)
815-
context_str_body = "\n\n----\n\n".join(context_sections)
816-
else:
817-
inner_context_strs = [
818-
context_inner_prompt.format(
819-
name=c.id,
820-
text=c.context,
821-
citation=c.text.doc.formatted_citation,
822-
**(c.model_extra or {}),
823-
)
824-
for c in filtered_contexts
825-
]
826-
context_str_body = "\n\n".join(inner_context_strs)
827-
828-
if pre_str:
829-
context_str_body += f"\n\nExtra background information: {pre_str}"
830-
831-
context_str = prompt_config.context_outer.format(
832-
context_str=context_str_body,
833-
valid_keys=", ".join([c.id for c in filtered_contexts]),
767+
context_str = await query_settings.context_serializer(
768+
contexts=contexts,
769+
question=session.question,
770+
pre_str=pre_str,
834771
)
835772

836-
if len(context_str_body.strip()) < 10: # noqa: PLR2004
773+
if len(context_str.strip()) < 10: # noqa: PLR2004
837774
answer_text = (
838775
f"{CANNOT_ANSWER_PHRASE} this question due to insufficient information."
839776
)

src/paperqa/settings.py

Lines changed: 116 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,20 @@
33
import os
44
import pathlib
55
import warnings
6+
from collections import defaultdict
67
from collections.abc import Callable, Mapping, Sequence
78
from enum import StrEnum
89
from pydoc import locate
9-
from typing import Any, ClassVar, Self, TypeAlias, assert_never, cast
10+
from typing import (
11+
Any,
12+
ClassVar,
13+
Protocol,
14+
Self,
15+
TypeAlias,
16+
assert_never,
17+
cast,
18+
runtime_checkable,
19+
)
1020

1121
import anyio
1222
from aviary.core import Tool, ToolSelector
@@ -55,6 +65,7 @@
5565
summary_prompt,
5666
)
5767
from paperqa.readers import PDFParserFn
68+
from paperqa.types import Context
5869
from paperqa.utils import hexdigest, pqa_directory
5970
from paperqa.version import __version__
6071

@@ -63,8 +74,21 @@
6374
_EnvironmentState: TypeAlias = Any
6475

6576

77+
@runtime_checkable
78+
class AsyncContextSerializer(Protocol):
79+
"""Protocol for generating a context string from settings and context."""
80+
81+
async def __call__(
82+
self,
83+
settings: "Settings",
84+
contexts: Sequence[Context],
85+
question: str,
86+
pre_str: str | None,
87+
) -> str: ...
88+
89+
6690
class AnswerSettings(BaseModel):
67-
model_config = ConfigDict(extra="forbid")
91+
model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
6892

6993
evidence_k: int = Field(
7094
default=10, description="Number of evidence pieces to retrieve."
@@ -791,6 +815,14 @@ class Settings(BaseSettings):
791815
exclude=True,
792816
frozen=True,
793817
)
818+
custom_context_serializer: AsyncContextSerializer | None = Field(
819+
default=None,
820+
description=(
821+
"Function to turn settings and contexts into an answer context str."
822+
" If not populated, the default context serializer will be used."
823+
),
824+
exclude=True,
825+
)
794826

795827
@model_validator(mode="after")
796828
def _deprecated_field(self) -> Self:
@@ -1026,6 +1058,88 @@ def adjust_tools_for_agent_llm(self, tools: list[Tool]) -> None:
10261058
# Gemini fixed this server-side by mid-April 2025,
10271059
# so this method is now just available for use
10281060

1061+
async def context_serializer(
1062+
self, contexts: Sequence[Context], question: str, pre_str: str | None
1063+
) -> str:
1064+
"""Default function for sorting ranked contexts and inserting into a context string."""
1065+
if self.custom_context_serializer:
1066+
return await self.custom_context_serializer(
1067+
settings=self, contexts=contexts, question=question, pre_str=pre_str
1068+
)
1069+
1070+
answer_config = self.answer
1071+
prompt_config = self.prompts
1072+
1073+
# sort by first score, then name
1074+
filtered_contexts = sorted(
1075+
contexts,
1076+
key=lambda x: (-x.score, x.text.name),
1077+
)[: answer_config.answer_max_sources]
1078+
# remove any contexts with a score below the cutoff
1079+
filtered_contexts = [
1080+
c
1081+
for c in filtered_contexts
1082+
if c.score >= answer_config.evidence_relevance_score_cutoff
1083+
]
1084+
1085+
# shim deprecated flag
1086+
# TODO: remove in v6
1087+
context_inner_prompt = prompt_config.context_inner
1088+
if (
1089+
not answer_config.evidence_detailed_citations
1090+
and "\nFrom {citation}" in context_inner_prompt
1091+
):
1092+
# Only keep "\nFrom {citation}" if we are showing detailed citations
1093+
context_inner_prompt = context_inner_prompt.replace("\nFrom {citation}", "")
1094+
1095+
context_str_body = ""
1096+
if answer_config.group_contexts_by_question:
1097+
contexts_by_question: dict[str, list[Context]] = defaultdict(list)
1098+
for c in filtered_contexts:
1099+
# Fallback to the main session question if not available.
1100+
# question attribute is optional, so if a user
1101+
# sets contexts externally, it may not have a question.
1102+
context_question = getattr(c, "question", question)
1103+
contexts_by_question[context_question].append(c)
1104+
1105+
context_sections = []
1106+
for context_question, contexts_in_group in contexts_by_question.items():
1107+
inner_strs = [
1108+
context_inner_prompt.format(
1109+
name=c.id,
1110+
text=c.context,
1111+
citation=c.text.doc.formatted_citation,
1112+
**(c.model_extra or {}),
1113+
)
1114+
for c in contexts_in_group
1115+
]
1116+
# Create a section with a question heading
1117+
section_header = (
1118+
f'Contexts related to the question: "{context_question}"'
1119+
)
1120+
section = f"{section_header}\n\n" + "\n\n".join(inner_strs)
1121+
context_sections.append(section)
1122+
context_str_body = "\n\n----\n\n".join(context_sections)
1123+
else:
1124+
inner_context_strs = [
1125+
context_inner_prompt.format(
1126+
name=c.id,
1127+
text=c.context,
1128+
citation=c.text.doc.formatted_citation,
1129+
**(c.model_extra or {}),
1130+
)
1131+
for c in filtered_contexts
1132+
]
1133+
context_str_body = "\n\n".join(inner_context_strs)
1134+
1135+
if pre_str:
1136+
context_str_body += f"\n\nExtra background information: {pre_str}"
1137+
1138+
return prompt_config.context_outer.format(
1139+
context_str=context_str_body,
1140+
valid_keys=", ".join([c.id for c in filtered_contexts]),
1141+
)
1142+
10291143

10301144
# Settings: already Settings
10311145
# dict[str, Any]: serialized Settings

tests/test_paperqa.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
from paperqa.prompts import CANNOT_ANSWER_PHRASE
5353
from paperqa.prompts import qa_prompt as default_qa_prompt
5454
from paperqa.readers import PDFParserFn, read_doc
55+
from paperqa.settings import AsyncContextSerializer
5556
from paperqa.types import ChunkMetadata, Context
5657
from paperqa.utils import (
5758
clean_possessives,
@@ -566,6 +567,32 @@ async def test_query(docs_fixture) -> None:
566567
await docs_fixture.aquery("Is XAI usable in chemistry?", settings=settings)
567568

568569

570+
@pytest.mark.asyncio
571+
async def test_custom_context_str_fn(docs_fixture) -> None:
572+
573+
async def custom_context_str_fn( # noqa: RUF029
574+
settings: Settings, # noqa: ARG001
575+
contexts: list[Context], # noqa: ARG001
576+
question: str, # noqa: ARG001
577+
pre_str: str | None = None, # noqa: ARG001
578+
) -> str:
579+
return "TEST OVERRIDE"
580+
581+
assert isinstance(custom_context_str_fn, AsyncContextSerializer)
582+
583+
settings = Settings(
584+
custom_context_serializer=custom_context_str_fn,
585+
prompts={"answer_iteration_prompt": None},
586+
)
587+
588+
session = await docs_fixture.aquery(
589+
"Is XAI usable in chemistry?", settings=settings
590+
)
591+
assert (
592+
session.context == "TEST OVERRIDE"
593+
), "Expected custom context string to be returned."
594+
595+
569596
@pytest.mark.asyncio
570597
async def test_aquery_groups_contexts_by_question(docs_fixture) -> None:
571598

0 commit comments

Comments (0)