Skip to content

Commit 7ce600b

Browse files
Added the get_manifest_kwargs function to get more than doi and title… (#848)
Co-authored-by: pre-commit-ci-lite[bot] <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com>
1 parent f8d364a commit 7ce600b

File tree

8 files changed

+101
-47
lines changed

8 files changed

+101
-47
lines changed

paperqa/agents/search.py

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import warnings
1111
import zlib
1212
from collections import Counter
13-
from collections.abc import AsyncIterator, Callable, Collection, Sequence
13+
from collections.abc import AsyncIterator, Callable, Sequence
1414
from datetime import datetime
1515
from enum import StrEnum, auto
1616
from typing import TYPE_CHECKING, Any, ClassVar
@@ -44,7 +44,7 @@
4444

4545
from paperqa.docs import Docs
4646
from paperqa.settings import IndexSettings, get_settings
47-
from paperqa.types import DocDetails
47+
from paperqa.types import VAR_MATCH_LOOKUP, DocDetails
4848
from paperqa.utils import ImpossibleParsingError, hexdigest
4949

5050
from .models import SupportsPickle
@@ -108,14 +108,12 @@ def read_from_string(self, data: str | bytes) -> BaseModel | SupportsPickle:
108108
return pickle.loads(data) # type: ignore[arg-type] # noqa: S301
109109

110110

111-
ENV_VAR_MATCH: Collection[str] = {"1", "true"}
112-
113111
# Cache keys are a two-tuple of index name and absolute index directory
114112
# Cache values are a two-tuple of an opened Index instance and the count
115113
# of SearchIndex instances currently referencing that Index
116114
_OPENED_INDEX_CACHE: dict[tuple[str, str], tuple[Index, int]] = {}
117115
DONT_USE_OPENED_INDEX_CACHE = (
118-
os.environ.get("PQA_INDEX_DONT_CACHE_INDEXES", "").lower() in ENV_VAR_MATCH
116+
os.environ.get("PQA_INDEX_DONT_CACHE_INDEXES", "").lower() in VAR_MATCH_LOOKUP
119117
)
120118

121119

@@ -429,6 +427,17 @@ async def query(
429427
]
430428

431429

430+
def fetch_kwargs_from_manifest(
    file_location: str, manifest: dict[str, Any], manifest_fallback_location: str
) -> dict[str, Any]:
    """Collect keyword arguments for a file from its manifest entry.

    Args:
        file_location: Primary key to look the file up under in the manifest.
        manifest: Mapping of file locations to manifest entries. Entries are
            expected to expose ``model_dump()`` (annotated as DocDetails below).
        manifest_fallback_location: Alternate key tried when the primary key
            has no (truthy) entry.

    Returns:
        The result of ``model_dump()`` on the matched entry, or an empty dict
        when neither location yields an entry.
    """
    # Try the primary location first; if it is absent, perhaps the manifest
    # used the opposite pathing scheme, so retry with the fallback location.
    manifest_entry: DocDetails | None = manifest.get(file_location) or manifest.get(
        manifest_fallback_location
    )
    if manifest_entry:
        return manifest_entry.model_dump()
    return {}
439+
440+
432441
async def maybe_get_manifest(
433442
filename: anyio.Path | None = None,
434443
) -> dict[str, DocDetails]:
@@ -491,23 +500,17 @@ async def process_file(
491500
if not await search_index.filecheck(filename=file_location):
492501
logger.info(f"New file to index: {file_location}...")
493502

494-
doi, title = None, None
495-
if file_location in manifest:
496-
manifest_entry = manifest[file_location]
497-
doi, title = manifest_entry.doi, manifest_entry.title
498-
elif manifest_fallback_location in manifest:
499-
# Perhaps manifest used the opposite pathing scheme
500-
manifest_entry = manifest[manifest_fallback_location]
501-
doi, title = manifest_entry.doi, manifest_entry.title
503+
kwargs = fetch_kwargs_from_manifest(
504+
file_location, manifest, manifest_fallback_location
505+
)
502506

503507
tmp_docs = Docs()
504508
try:
505509
await tmp_docs.aadd(
506510
path=abs_file_path,
507-
title=title,
508-
doi=doi,
509511
fields=["title", "author", "journal", "year"],
510512
settings=settings,
513+
**kwargs,
511514
)
512515
except Exception as e:
513516
# We handle any exception here because we want to save_index so we
@@ -569,10 +572,10 @@ def _make_progress_bar_update(
569572
) -> tuple[contextlib.AbstractContextManager, Callable[[], Any] | None]:
570573
# Disable should override enable
571574
env_var_disable = (
572-
os.environ.get("PQA_INDEX_DISABLE_PROGRESS_BAR", "").lower() in ENV_VAR_MATCH
575+
os.environ.get("PQA_INDEX_DISABLE_PROGRESS_BAR", "").lower() in VAR_MATCH_LOOKUP
573576
)
574577
env_var_enable = (
575-
os.environ.get("PQA_INDEX_ENABLE_PROGRESS_BAR", "").lower() in ENV_VAR_MATCH
578+
os.environ.get("PQA_INDEX_ENABLE_PROGRESS_BAR", "").lower() in VAR_MATCH_LOOKUP
576579
)
577580
try:
578581
is_cli = is_running_under_cli() # pylint: disable=used-before-assignment

paperqa/clients/__init__.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ async def upgrade_doc_to_doc_details(self, doc: Doc, **kwargs) -> DocDetails:
201201
# note we have some extra fields which may have come from reading the doc text,
202202
# but aren't in the doc object, we add them here too.
203203
extra_fields = {
204-
k: v for k, v in kwargs.items() if k in {"title", "authors", "doi"}
204+
k: v for k, v in kwargs.items() if k in set(DocDetails.model_fields)
205205
}
206206
# abuse our doc_details object to be an int if it's empty
207207
# our __add__ operation supports int by doing nothing
@@ -210,18 +210,21 @@ async def upgrade_doc_to_doc_details(self, doc: Doc, **kwargs) -> DocDetails:
210210
)
211211

212212
if doc_details := await self.query(**kwargs):
213-
if doc.overwrite_fields_from_metadata:
214-
return extra_doc + doc_details
215213

216214
# hard overwrite the details from the prior object
217-
doc_details.dockey = doc.dockey
218-
doc_details.doc_id = doc.dockey
219-
doc_details.docname = doc.docname
220-
doc_details.key = doc.docname
221-
doc_details.citation = doc.citation
215+
if "dockey" in doc.fields_to_overwrite_from_metadata:
216+
doc_details.dockey = doc.dockey
217+
if "doc_id" in doc.fields_to_overwrite_from_metadata:
218+
doc_details.doc_id = doc.dockey
219+
if "docname" in doc.fields_to_overwrite_from_metadata:
220+
doc_details.docname = doc.docname
221+
if "key" in doc.fields_to_overwrite_from_metadata:
222+
doc_details.key = doc.docname
223+
if "citation" in doc.fields_to_overwrite_from_metadata:
224+
doc_details.citation = doc.citation
222225
return extra_doc + doc_details
223226

224227
# if we can't get metadata, just return the doc, but don't overwrite any fields
225228
prior_doc = doc.model_dump()
226-
prior_doc["overwrite_fields_from_metadata"] = False
229+
prior_doc["fields_to_overwrite_from_metadata"] = set()
227230
return DocDetails(**(prior_doc | extra_fields))

paperqa/docs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,7 @@ async def aadd( # noqa: PLR0912
347347
)
348348
# see if we can upgrade to DocDetails
349349
# if not, we can progress with a normal Doc
350-
# if "overwrite_fields_from_metadata" is used:
350+
# if "fields_to_overwrite_from_metadata" is used:
351351
# will map "docname" to "key", and "dockey" to "doc_id"
352352
if (title or doi) and parse_config.use_doc_details:
353353
if kwargs.get("metadata_client"):

paperqa/sources/clinical_trials.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ def format_to_doc_details(trial_data: dict) -> DocDetails:
170170
year=year or None,
171171
citation=citation,
172172
other={"client_source": [CLINICAL_TRIALS_BASE]},
173-
overwrite_fields_from_metadata=False,
173+
fields_to_overwrite_from_metadata=set(),
174174
)
175175

176176

@@ -311,7 +311,7 @@ async def add_clinical_trials_to_docs(
311311
year=datetime.now().year,
312312
citation=f"Clinical Trials Search via ClinicalTrials.gov: {query}",
313313
other={"client_source": [CLINICAL_TRIALS_BASE]},
314-
overwrite_fields_from_metadata=False,
314+
fields_to_overwrite_from_metadata=set(),
315315
)
316316

317317
await docs.aadd_texts(

paperqa/types.py

Lines changed: 42 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -40,17 +40,26 @@
4040
logger = logging.getLogger(__name__)
4141

4242

43+
VAR_MATCH_LOOKUP: Collection[str] = {"1", "true"}
44+
VAR_MISMATCH_LOOKUP: Collection[str] = {"0", "false"}
45+
DEFAULT_FIELDS_TO_OVERWRITE_FROM_METADATA: Collection[str] = {
46+
"key",
47+
"doc_id",
48+
"docname",
49+
"dockey",
50+
"citation",
51+
}
52+
53+
4354
class Doc(Embeddable):
4455
model_config = ConfigDict(extra="forbid")
4556

4657
docname: str
4758
dockey: DocKey
4859
citation: str
49-
overwrite_fields_from_metadata: bool = Field(
50-
default=True,
51-
description=(
52-
"flag to overwrite fields from metadata when upgrading to a DocDetails"
53-
),
60+
fields_to_overwrite_from_metadata: set[str] = Field(
61+
default_factory=lambda: set(DEFAULT_FIELDS_TO_OVERWRITE_FROM_METADATA),
62+
description="fields from metadata to overwrite when upgrading to a DocDetails",
5463
)
5564

5665
@model_validator(mode="before")
@@ -410,7 +419,10 @@ def lowercase_doi_and_populate_doc_id(cls, data: dict[str, Any]) -> dict[str, An
410419
else:
411420
data["doc_id"] = encode_id(uuid4())
412421

413-
if data.get("overwrite_fields_from_metadata", True):
422+
if "dockey" in data.get(
423+
"fields_to_overwrite_from_metadata",
424+
DEFAULT_FIELDS_TO_OVERWRITE_FROM_METADATA,
425+
):
414426
data["dockey"] = data["doc_id"]
415427

416428
return data
@@ -502,10 +514,13 @@ def overwrite_docname_dockey_for_compatibility_w_doc(
502514
) -> dict[str, Any]:
503515
"""Overwrite fields from metadata if specified."""
504516
overwrite_fields = {"key": "docname", "doc_id": "dockey"}
505-
if data.get("overwrite_fields_from_metadata", True):
506-
for field, old_field in overwrite_fields.items():
507-
if data.get(field):
508-
data[old_field] = data[field]
517+
fields_to_overwrite = data.get(
518+
"fields_to_overwrite_from_metadata",
519+
DEFAULT_FIELDS_TO_OVERWRITE_FROM_METADATA,
520+
)
521+
for field in overwrite_fields.keys() & fields_to_overwrite:
522+
if data.get(field):
523+
data[overwrite_fields[field]] = data[field]
509524
return data
510525

511526
@classmethod
@@ -516,7 +531,7 @@ def populate_bibtex_key_citation( # noqa: PLR0912
516531
517532
Missing values, 'unknown' keys, and incomplete bibtex entries are regenerated.
518533
519-
When overwrite_fields_from_metadata:
534+
When fields_to_overwrite_from_metadata:
520535
If bibtex is regenerated, the citation field is also regenerated.
521536
522537
Otherwise we keep the citation field as is.
@@ -529,7 +544,10 @@ def populate_bibtex_key_citation( # noqa: PLR0912
529544
data.get("year") or CITATION_FALLBACK_DATA["year"], # type: ignore[arg-type]
530545
data.get("title") or CITATION_FALLBACK_DATA["title"], # type: ignore[arg-type]
531546
)
532-
if data.get("overwrite_fields_from_metadata", True):
547+
if "docname" in data.get(
548+
"fields_to_overwrite_from_metadata",
549+
DEFAULT_FIELDS_TO_OVERWRITE_FROM_METADATA,
550+
):
533551
data["docname"] = data["key"]
534552

535553
# even if we have a bibtex, it may not be complete, thus we need to add to it
@@ -591,26 +609,35 @@ def populate_bibtex_key_citation( # noqa: PLR0912
591609
entries={data["key"]: new_entry}
592610
).to_string("bibtex")
593611
# clear out the citation, since it will be regenerated
594-
if data.get("overwrite_fields_from_metadata", True):
612+
if "citation" in data.get(
613+
"fields_to_overwrite_from_metadata",
614+
DEFAULT_FIELDS_TO_OVERWRITE_FROM_METADATA,
615+
):
595616
data["citation"] = None
596617
except Exception:
597618
logger.warning(
598619
"Failed to generate bibtex for"
599620
f" {data.get('docname') or data.get('citation')}"
600621
)
601-
if not data.get("citation") and data.get("bibtex") is not None:
622+
if data.get("citation") is None and data.get("bibtex") is not None:
602623
data["citation"] = format_bibtex(
603624
data["bibtex"], missing_replacements=CITATION_FALLBACK_DATA # type: ignore[arg-type]
604625
)
605-
elif not data.get("citation"):
626+
elif data.get("citation") is None:
606627
data["citation"] = data.get("title") or CITATION_FALLBACK_DATA["title"]
607628
return data
608629

609630
@model_validator(mode="before")
610631
@classmethod
611632
def validate_all_fields(cls, data: Mapping[str, Any]) -> dict[str, Any]:
633+
612634
data = deepcopy(data) # Avoid mutating input
613635
data = dict(data)
636+
if isinstance(data.get("fields_to_overwrite_from_metadata"), str):
637+
data["fields_to_overwrite_from_metadata"] = {
638+
s.strip()
639+
for s in data.get("fields_to_overwrite_from_metadata", "").split(",")
640+
}
614641
data = cls.lowercase_doi_and_populate_doc_id(data)
615642
data = cls.remove_invalid_authors(data)
616643
data = cls.misc_string_cleaning(data)
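stub_manifest_nocitation.csv (new file; its path header is missing from this capture — name inferred from the test in tests/test_agents.py below, directory not shown)

Lines changed: 2 additions & 0 deletions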
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
file_location,doi,title,citation,fields_to_overwrite_from_metadata
2+
"bates.txt",,"Frederick Bates (Wikipedia article)","","key, doc_id, docname, dockey"

tests/test_agents.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,10 +243,31 @@ async def test_get_directory_index_w_manifest(agent_test_settings: Settings) ->
243243

244244
results = await index.query(query="who is Frederick Bates?")
245245
top_result = next(iter(results[0].docs.values()))
246-
assert top_result.dockey == md5sum(abs_paper_dir / "bates.txt")
246+
247+
# note: we get every possible field from the manifest constructed in maybe_get_manifest,
248+
# and then DocDetails construction sets the dockey to the doc_id.
249+
assert top_result.dockey == top_result.doc_id
247250
# note: this title comes from the manifest, so we know it worked
248251
assert top_result.title == "Frederick Bates (Wikipedia article)"
249252

253+
assert "wikipedia article" in top_result.citation.lower(), (
254+
"Other tests check we can override citation,"
255+
" so here we check here it's actually populated"
256+
)
257+
258+
259+
@pytest.mark.asyncio
async def test_get_directory_index_w_no_citations(
    agent_test_settings: Settings,
) -> None:
    """Index with the no-citation stub manifest and check the citation stays empty.

    Points the index settings at "stub_manifest_nocitation.csv", builds the
    directory index, and asserts the top query result has a falsy citation.
    """
    agent_test_settings.agent.index.manifest_file = "stub_manifest_nocitation.csv"
    directory_index = await get_directory_index(settings=agent_test_settings)

    hits = await directory_index.query(query="who is Frederick Bates?")
    # Only the first document of the first hit is inspected
    first_doc = next(iter(hits[0].docs.values()))

    assert not first_doc.citation
270+
250271

251272
@pytest.mark.flaky(reruns=2, only_rerun=["AssertionError", "httpx.RemoteProtocolError"])
252273
@pytest.mark.parametrize("agent_type", [FAKE_AGENT_TYPE, ToolSelector, SimpleAgent])

tests/test_paperqa.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1324,7 +1324,6 @@ def test_docdetails_deserialization() -> None:
13241324
"docname": "Stub",
13251325
"embedding": None,
13261326
"formatted_citation": "stub",
1327-
"overwrite_fields_from_metadata": True,
13281327
}
13291328
deepcopy_deserialize_to_doc = deepcopy(deserialize_to_doc)
13301329
doc = Doc(**deserialize_to_doc)
@@ -1338,7 +1337,6 @@ def test_docdetails_deserialization() -> None:
13381337
for key, value in {
13391338
"docname": "unknownauthorsUnknownyearunknowntitle",
13401339
"citation": "Unknown authors. Unknown title. Unknown journal, Unknown year.",
1341-
"overwrite_fields_from_metadata": True,
13421340
"key": "unknownauthorsUnknownyearunknowntitle",
13431341
"bibtex": (
13441342
'@article{unknownauthorsUnknownyearunknowntitle,\n author = "authors,'

0 commit comments

Comments
 (0)