|
3 | 3 | import os
|
4 | 4 | import pathlib
|
5 | 5 | import warnings
|
| 6 | +from collections import defaultdict |
6 | 7 | from collections.abc import Callable, Mapping, Sequence
|
7 | 8 | from enum import StrEnum
|
8 | 9 | from pydoc import locate
|
9 |
| -from typing import Any, ClassVar, Self, TypeAlias, assert_never, cast |
| 10 | +from typing import ( |
| 11 | + Any, |
| 12 | + ClassVar, |
| 13 | + Protocol, |
| 14 | + Self, |
| 15 | + TypeAlias, |
| 16 | + assert_never, |
| 17 | + cast, |
| 18 | + runtime_checkable, |
| 19 | +) |
10 | 20 |
|
11 | 21 | import anyio
|
12 | 22 | from aviary.core import Tool, ToolSelector
|
|
55 | 65 | summary_prompt,
|
56 | 66 | )
|
57 | 67 | from paperqa.readers import PDFParserFn
|
| 68 | +from paperqa.types import Context |
58 | 69 | from paperqa.utils import hexdigest, pqa_directory
|
59 | 70 | from paperqa.version import __version__
|
60 | 71 |
|
|
63 | 74 | _EnvironmentState: TypeAlias = Any
|
64 | 75 |
|
65 | 76 |
|
@runtime_checkable
class AsyncContextSerializer(Protocol):
    """Structural type for async callables that render contexts into one string.

    An implementation receives the active settings, the selected contexts,
    the user's question, and an optional pre-answer string, and returns the
    fully formatted context block used to answer the question.
    """

    async def __call__(
        self,
        settings: "Settings",
        contexts: Sequence[Context],
        question: str,
        pre_str: str | None,
    ) -> str: ...
| 88 | + |
| 89 | + |
66 | 90 | class AnswerSettings(BaseModel):
|
67 |
| - model_config = ConfigDict(extra="forbid") |
| 91 | + model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) |
68 | 92 |
|
69 | 93 | evidence_k: int = Field(
|
70 | 94 | default=10, description="Number of evidence pieces to retrieve."
|
@@ -791,6 +815,14 @@ class Settings(BaseSettings):
|
791 | 815 | exclude=True,
|
792 | 816 | frozen=True,
|
793 | 817 | )
|
    # Optional async hook: when set, Settings.context_serializer delegates to it
    # instead of running the default serialization logic. exclude=True keeps the
    # callable out of model serialization.
    custom_context_serializer: AsyncContextSerializer | None = Field(
        default=None,
        description=(
            "Function to turn settings and contexts into an answer context str."
            " If not populated, the default context serializer will be used."
        ),
        exclude=True,
    )
794 | 826 |
|
795 | 827 | @model_validator(mode="after")
|
796 | 828 | def _deprecated_field(self) -> Self:
|
@@ -1026,6 +1058,88 @@ def adjust_tools_for_agent_llm(self, tools: list[Tool]) -> None:
|
1026 | 1058 | # Gemini fixed this server-side by mid-April 2025,
|
1027 | 1059 | # so this method is now just available for use
|
1028 | 1060 |
|
| 1061 | + async def context_serializer( |
| 1062 | + self, contexts: Sequence[Context], question: str, pre_str: str | None |
| 1063 | + ) -> str: |
| 1064 | + """Default function for sorting ranked contexts and inserting into a context string.""" |
| 1065 | + if self.custom_context_serializer: |
| 1066 | + return await self.custom_context_serializer( |
| 1067 | + settings=self, contexts=contexts, question=question, pre_str=pre_str |
| 1068 | + ) |
| 1069 | + |
| 1070 | + answer_config = self.answer |
| 1071 | + prompt_config = self.prompts |
| 1072 | + |
| 1073 | + # sort by first score, then name |
| 1074 | + filtered_contexts = sorted( |
| 1075 | + contexts, |
| 1076 | + key=lambda x: (-x.score, x.text.name), |
| 1077 | + )[: answer_config.answer_max_sources] |
| 1078 | + # remove any contexts with a score below the cutoff |
| 1079 | + filtered_contexts = [ |
| 1080 | + c |
| 1081 | + for c in filtered_contexts |
| 1082 | + if c.score >= answer_config.evidence_relevance_score_cutoff |
| 1083 | + ] |
| 1084 | + |
| 1085 | + # shim deprecated flag |
| 1086 | + # TODO: remove in v6 |
| 1087 | + context_inner_prompt = prompt_config.context_inner |
| 1088 | + if ( |
| 1089 | + not answer_config.evidence_detailed_citations |
| 1090 | + and "\nFrom {citation}" in context_inner_prompt |
| 1091 | + ): |
| 1092 | + # Only keep "\nFrom {citation}" if we are showing detailed citations |
| 1093 | + context_inner_prompt = context_inner_prompt.replace("\nFrom {citation}", "") |
| 1094 | + |
| 1095 | + context_str_body = "" |
| 1096 | + if answer_config.group_contexts_by_question: |
| 1097 | + contexts_by_question: dict[str, list[Context]] = defaultdict(list) |
| 1098 | + for c in filtered_contexts: |
| 1099 | + # Fallback to the main session question if not available. |
| 1100 | + # question attribute is optional, so if a user |
| 1101 | + # sets contexts externally, it may not have a question. |
| 1102 | + context_question = getattr(c, "question", question) |
| 1103 | + contexts_by_question[context_question].append(c) |
| 1104 | + |
| 1105 | + context_sections = [] |
| 1106 | + for context_question, contexts_in_group in contexts_by_question.items(): |
| 1107 | + inner_strs = [ |
| 1108 | + context_inner_prompt.format( |
| 1109 | + name=c.id, |
| 1110 | + text=c.context, |
| 1111 | + citation=c.text.doc.formatted_citation, |
| 1112 | + **(c.model_extra or {}), |
| 1113 | + ) |
| 1114 | + for c in contexts_in_group |
| 1115 | + ] |
| 1116 | + # Create a section with a question heading |
| 1117 | + section_header = ( |
| 1118 | + f'Contexts related to the question: "{context_question}"' |
| 1119 | + ) |
| 1120 | + section = f"{section_header}\n\n" + "\n\n".join(inner_strs) |
| 1121 | + context_sections.append(section) |
| 1122 | + context_str_body = "\n\n----\n\n".join(context_sections) |
| 1123 | + else: |
| 1124 | + inner_context_strs = [ |
| 1125 | + context_inner_prompt.format( |
| 1126 | + name=c.id, |
| 1127 | + text=c.context, |
| 1128 | + citation=c.text.doc.formatted_citation, |
| 1129 | + **(c.model_extra or {}), |
| 1130 | + ) |
| 1131 | + for c in filtered_contexts |
| 1132 | + ] |
| 1133 | + context_str_body = "\n\n".join(inner_context_strs) |
| 1134 | + |
| 1135 | + if pre_str: |
| 1136 | + context_str_body += f"\n\nExtra background information: {pre_str}" |
| 1137 | + |
| 1138 | + return prompt_config.context_outer.format( |
| 1139 | + context_str=context_str_body, |
| 1140 | + valid_keys=", ".join([c.id for c in filtered_contexts]), |
| 1141 | + ) |
| 1142 | + |
1029 | 1143 |
|
1030 | 1144 | # Settings: already Settings
|
1031 | 1145 | # dict[str, Any]: serialized Settings
|
|
0 commit comments