Commit f1f7d5b

Ensuring manifest deserialization works for non-primitive fields (#1011)
1 parent 31e01bc commit f1f7d5b

4 files changed: 72 additions, 7 deletions

README.md

Lines changed: 5 additions & 1 deletion

```diff
@@ -760,7 +760,8 @@ The paper directory is not modified in any way, it's just read from.
 The indexing process attempts to infer paper metadata like title and DOI
 using LLM-powered text processing.
 You can avoid this point of uncertainty using a "manifest" file,
-which is a CSV containing three columns (order doesn't matter):
+which is a CSV containing `DocDetails` fields (order doesn't matter).
+For example:
 
 - `file_location`: relative path to the paper's PDF within the index directory
 - `doi`: DOI of the paper
@@ -769,6 +770,9 @@ which is a CSV containing three columns (order doesn't matter):
 By providing this information,
 we ensure queries to metadata providers like Crossref are accurate.
 
+To ease creating a manifest, there is a helper class method `Doc.to_csv`,
+which also works when called on `DocDetails`.
+
 ### Reusing Index
 
 The local search indexes are built based on a hash of the current `Settings` object.
```
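To make the helper concrete, here is a rough usage sketch (the import path and constructor fields are plausible assumptions, not verbatim from this commit):

```python
from paperqa.types import DocDetails

# Constructor fields below are illustrative; any `DocDetails` fields work
docs = [
    DocDetails(
        citation="Anderson, 2024",
        docname="anderson2024stub",
        dockey="stub",
        doi="10.1000/stub",  # Hypothetical DOI
        authors=["Thomas Anderson"],
    )
]

# Writes the `doi` and `file_location` columns first (CSV_FIELDS_UP_FRONT),
# then the remaining `DocDetails` fields in alphabetical order
DocDetails.to_csv(docs, target_csv_path="manifest.csv")
```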

paperqa/agents/search.py

Lines changed: 5 additions & 1 deletion

```diff
@@ -8,6 +8,7 @@
 import pathlib
 import pickle
 import re
+import sys
 import warnings
 import zlib
 from collections import Counter
@@ -451,9 +452,12 @@ async def maybe_get_manifest(
     try:
         async with await anyio.open_file(filename, mode="r") as file:
             content = await file.read()
+            reader_kwargs: dict[str, Any] = {}
+            if sys.version_info >= (3, 12):  # Unlocks `bool | None` fields
+                reader_kwargs["quoting"] = csv.QUOTE_NOTNULL
             file_loc_to_records = {
                 str(r.get("file_location")): r
-                for r in csv.DictReader(content.splitlines())
+                for r in csv.DictReader(content.splitlines(), **reader_kwargs)
                 if r.get("file_location")
             }
             if not file_loc_to_records:
```
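For context on the version gate above, here is a minimal standalone sketch (not part of the commit; the column names are made up) of what `csv.QUOTE_NOTNULL` unlocks on Python 3.12+: unquoted empty fields deserialize as `None` instead of `""`, so `bool | None` manifest fields can survive the round trip.

```python
import csv
import io

# Requires Python 3.12+, where QUOTE_NOTNULL gained reader-side support
rows = io.StringIO("doi,flag\n10.1000/stub,\n")
record = next(csv.DictReader(rows, quoting=csv.QUOTE_NOTNULL))
assert record["flag"] is None  # Without QUOTE_NOTNULL, this would be ""
```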

paperqa/types.py

Lines changed: 46 additions & 4 deletions

```diff
@@ -1,14 +1,16 @@
 from __future__ import annotations
 
+import ast
+import csv
 import logging
 import os
 import re
 import warnings
-from collections.abc import Collection, Mapping
+from collections.abc import Collection, Iterable, Mapping, Sequence
 from copy import deepcopy
 from datetime import datetime
 from enum import StrEnum
-from typing import Annotated, Any, ClassVar, cast
+from typing import Annotated, Any, ClassVar, Self, cast
 from uuid import UUID, uuid4
 
 import tiktoken
@@ -99,6 +101,33 @@ def matches_filter_criteria(self, filter_criteria: Mapping[str, Any]) -> bool:
                 return False
         return True
 
+    FIELDS_TO_EXCLUDE_FROM_CSV: ClassVar[set[str]] = {
+        "embedding",  # Don't store to allow for configuration of embedding models
+    }
+    CSV_FIELDS_UP_FRONT: ClassVar[Sequence[str]] = ()
+
+    @classmethod
+    def to_csv(cls, values: Iterable[Self], target_csv_path: str | os.PathLike) -> None:
+        """Dump many instances into a CSV, for later use as a manifest."""
+        headers = set(cls.model_fields) - cls.FIELDS_TO_EXCLUDE_FROM_CSV
+        with open(target_csv_path, "w", encoding="utf-8") as f:
+            writer = csv.DictWriter(
+                f,
+                fieldnames=[
+                    *sorted(cls.CSV_FIELDS_UP_FRONT),  # Make easy reading
+                    *sorted(headers - set(cls.CSV_FIELDS_UP_FRONT)),
+                ],
+            )
+            writer.writeheader()
+            writer.writerows(
+                [
+                    v.model_dump(
+                        exclude={"formatted_citation"} | cls.FIELDS_TO_EXCLUDE_FROM_CSV
+                    )
+                    for v in values
+                ]
+            )
+
 
 class Text(Embeddable):
     """A text chunk ready for use in retrieval with a linked document."""
@@ -565,6 +594,11 @@ class DocDetails(Doc):
         "http://dx.doi.org/",
     }
     AUTHOR_NAMES_TO_REMOVE: ClassVar[Collection[str]] = {"et al", "et al."}
+    FIELDS_TO_EXCLUDE_FROM_CSV: ClassVar[set[str]] = {
+        "bibtex",  # Let this be autogenerated, to avoid dealing with newlines
+        "embedding",  # Don't store to allow for configuration of embedding models
+    }
+    CSV_FIELDS_UP_FRONT: ClassVar[Sequence[str]] = ("doi", "file_location")
 
     @field_validator("key")
     @classmethod
@@ -805,10 +839,18 @@ def validate_all_fields(cls, data: Mapping[str, Any]) -> dict[str, Any]:
         data = deepcopy(data)  # Avoid mutating input
         data = dict(data)
         if isinstance(data.get("fields_to_overwrite_from_metadata"), str):
+            raw_value = data["fields_to_overwrite_from_metadata"]
+            if (raw_value[0], raw_value[-1]) in {("[", "]"), ("{", "}")}:
+                # If string-ified set or list, remove brackets before split
+                raw_value = raw_value[1:-1]
             data["fields_to_overwrite_from_metadata"] = {
-                s.strip()
-                for s in data.get("fields_to_overwrite_from_metadata", "").split(",")
+                s.strip("\"' ") for s in raw_value.split(",")
             }
+        for possibly_str_field in ("authors", "other"):
+            if data.get(possibly_str_field) and isinstance(
+                data[possibly_str_field], str
+            ):
+                data[possibly_str_field] = ast.literal_eval(data[possibly_str_field])
         data = cls.lowercase_doi_and_populate_doc_id(data)
         data = cls.remove_invalid_authors(data)
         data = cls.misc_string_cleaning(data)
```
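A small standalone sketch (assumed values, not repo code) of the round trip `validate_all_fields` now handles: `csv.DictWriter` string-ifies non-primitive values into their `repr`, and the validator restores them.

```python
import ast

# A list-valued field like `authors` lands in a CSV cell as its repr...
authors_cell = str(["Thomas Anderson", "Agent Smith"])
# ...and ast.literal_eval safely turns it back into a list
assert ast.literal_eval(authors_cell) == ["Thomas Anderson", "Agent Smith"]

# `fields_to_overwrite_from_metadata` may arrive as a string-ified set
raw_value = "{'key', 'doc_id'}"
if (raw_value[0], raw_value[-1]) in {("[", "]"), ("{", "}")}:
    raw_value = raw_value[1:-1]  # Strip brackets before splitting on commas
assert {s.strip("\"' ") for s in raw_value.split(",")} == {"key", "doc_id"}
```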

tests/test_paperqa.py

Lines changed: 16 additions & 1 deletion

```diff
@@ -1,8 +1,10 @@
 import contextlib
+import csv
 import os
 import pathlib
 import pickle
 import re
+import sys
 from collections.abc import AsyncIterable, Sequence
 from copy import deepcopy
 from datetime import datetime, timedelta
@@ -1473,13 +1475,15 @@ def test_docdetails_merge_with_list_fields() -> None:
     assert isinstance(merged_doc, DocDetails), "Merged doc should also be DocDetails"
 
 
-def test_docdetails_deserialization() -> None:
+@pytest.mark.skipif(sys.version_info < (3, 12), reason="Uses `csv.QUOTE_NOTNULL`.")
+def test_docdetails_deserialization(tmp_path) -> None:
     deserialize_to_doc = {
         "citation": "stub",
         "dockey": "stub",
         "docname": "Stub",
         "embedding": None,
         "formatted_citation": "stub",
+        "fields_to_overwrite_from_metadata": {"key", "doc_id", "docname", "citation"},
     }
     deepcopy_deserialize_to_doc = deepcopy(deserialize_to_doc)
     doc = Doc(**deserialize_to_doc)
@@ -1510,6 +1514,17 @@ def test_docdetails_deserialization() -> None:
         deserialize_to_doc == deepcopy_deserialize_to_doc
     ), "Deserialization should not mutate input"
 
+    doc_details = DocDetails(
+        **deserialize_to_doc, other={"apple": "sauce"}, authors=["Thomas Anderson"]
+    )
+    DocDetails.to_csv([doc_details], target_csv_path=Path(tmp_path) / "manifest.csv")
+    with open(tmp_path / "manifest.csv", encoding="utf-8") as f:
+        csv_deserialized = DocDetails(
+            # type ignore comments are here since mypy can't recognize pytest skip
+            **next(csv.DictReader(f.readlines(), quoting=csv.QUOTE_NOTNULL))  # type: ignore[attr-defined,unused-ignore]
+        )
+    assert doc_details == csv_deserialized, "Round-trip CSV deserialization failed"
+
 
 def test_docdetails_doc_id_roundtrip() -> None:
     """Test that DocDetails can be initialized with doc_id or doi inputs."""
```
