diff --git a/src/paperqa/clients/__init__.py b/src/paperqa/clients/__init__.py
index 64ee54a8..81adc588 100644
--- a/src/paperqa/clients/__init__.py
+++ b/src/paperqa/clients/__init__.py
@@ -208,16 +208,16 @@ async def bulk_query(
         )
 
     async def upgrade_doc_to_doc_details(self, doc: Doc, **kwargs) -> DocDetails:
-
-        # note we have some extra fields which may have come from reading the doc text,
-        # but aren't in the doc object, we add them here too.
-        extra_fields = {
+        # Collect fields (e.g. title, DOI, or authors) that have been externally
+        # specified (e.g. by a caller, or inferred from the document's contents)
+        # but are not on the input `doc` object
+        provided_fields = {
             k: v for k, v in kwargs.items() if k in set(DocDetails.model_fields)
         }
-        # abuse our doc_details object to be an int if it's empty
-        # our __add__ operation supports int by doing nothing
-        extra_doc: int | DocDetails = (
-            0 if not extra_fields else DocDetails(**extra_fields)
+        # DocDetails.__add__ supports `int` as a no-op route, so if we have no
+        # provided fields, let's use that no-op route
+        provided_doc_details: int | DocDetails = (
+            0 if not provided_fields else DocDetails(**provided_fields)
         )
 
         if doc_details := await self.query(**kwargs):
@@ -233,9 +233,10 @@ async def upgrade_doc_to_doc_details(self, doc: Doc, **kwargs) -> DocDetails:
                 doc_details.key = doc.docname
             if "citation" in doc.fields_to_overwrite_from_metadata:
                 doc_details.citation = doc.citation
-            return extra_doc + doc_details
+            if "content_hash" in doc.fields_to_overwrite_from_metadata:
+                doc_details.content_hash = doc.content_hash
+            return provided_doc_details + doc_details
 
         # if we can't get metadata, just return the doc, but don't overwrite any fields
-        prior_doc = doc.model_dump()
-        prior_doc["fields_to_overwrite_from_metadata"] = set()
-        return DocDetails(**(prior_doc | extra_fields))
+        orig_fields = doc.model_dump() | {"fields_to_overwrite_from_metadata": set()}
+        return DocDetails(**(orig_fields | provided_fields))
diff --git a/src/paperqa/docs.py b/src/paperqa/docs.py
index 37a1bbb5..d7dcdffd 100644
--- a/src/paperqa/docs.py
+++ b/src/paperqa/docs.py
@@ -265,10 +265,10 @@ async def aadd(  # noqa: PLR0912
         """Add a document to the collection."""
         all_settings = get_settings(settings)
         parse_config = all_settings.parsing
+        content_hash = md5sum(path)
         dockey_is_content_hash = False
         if dockey is None:
-            # md5 sum of file contents (not path!)
-            dockey = md5sum(path)
+            dockey = content_hash
             dockey_is_content_hash = True
         if llm_model is None:
             llm_model = all_settings.get_llm()
@@ -276,7 +276,9 @@ async def aadd(  # noqa: PLR0912
         # Peek first chunk
         texts = await read_doc(
             path,
-            Doc(docname="", citation="", dockey=dockey),  # Fake doc
+            Doc(  # Fake doc
+                docname="", citation="", dockey=dockey, content_hash=content_hash
+            ),
             chunk_chars=parse_config.chunk_size,
             overlap=parse_config.overlap,
             page_size_limit=parse_config.page_size_limit,
@@ -307,6 +309,7 @@ async def aadd(  # noqa: PLR0912
             ),
             citation=citation,
             dockey=dockey,
+            content_hash=content_hash,
         )
 
         # try to extract DOI / title from the citation
@@ -360,6 +363,7 @@ async def aadd(  # noqa: PLR0912
                 clients=kwargs.pop("clients", DEFAULT_CLIENTS),
            )
 
+            # Query here means a query to a metadata provider
             query_kwargs: dict[str, Any] = {}
 
             if doi:
diff --git a/src/paperqa/llms.py b/src/paperqa/llms.py
index 0d547a50..5676a170 100644
--- a/src/paperqa/llms.py
+++ b/src/paperqa/llms.py
@@ -30,7 +30,7 @@
 )
 from typing_extensions import override
 
-from paperqa.types import Doc, Text
+from paperqa.types import AUTOPOPULATE_VALUE, Doc, Text
 
 if TYPE_CHECKING:
     from qdrant_client.http.models import Record
@@ -495,6 +495,7 @@ async def fetch_batch_with_semaphore(offset: int) -> None:
                 docname=doc_data.get("docname", ""),
                 citation=doc_data.get("citation", ""),
                 dockey=doc_data["dockey"],
+                content_hash=doc_data.get("content_hash", AUTOPOPULATE_VALUE),
             )
             docs.docnames.add(doc_data.get("docname", ""))
 
diff --git a/src/paperqa/types.py b/src/paperqa/types.py
index cabd8202..9be3465d 100644
--- a/src/paperqa/types.py
+++ b/src/paperqa/types.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import ast
+import contextlib
 import csv
 import hashlib
 import json
@@ -37,11 +38,13 @@
 from paperqa.utils import (
     bytes_to_string,
+    compute_unique_doc_id,
     create_bibtex_key,
     encode_id,
     format_bibtex,
     get_citation_ids,
     maybe_get_date,
+    md5sum,
     string_to_bytes,
 )
 from paperqa.version import __version__ as pqa_version
 
@@ -61,7 +64,10 @@
     "docname",
     "dockey",
     "citation",
+    "content_hash",  # Metadata providers won't give this
 }
+# Sentinel to autopopulate a field within model_validator
+AUTOPOPULATE_VALUE = ""  # NOTE: this is falsy by design
 
 
 class Doc(Embeddable):
@@ -70,6 +76,13 @@ class Doc(Embeddable):
     docname: str
     dockey: DocKey
     citation: str
+    content_hash: str | None = Field(
+        default=AUTOPOPULATE_VALUE,
+        description=(
+            "Optional hash of the document's contents (to reiterate, not a file path to"
+            " the document, but the document's contents themselves)."
+        ),
+    )
     # Sort the serialization to minimize the diff of serialized objects
     fields_to_overwrite_from_metadata: Annotated[set[str], PlainSerializer(sorted)] = (
         Field(
@@ -171,10 +184,6 @@ def __hash__(self) -> int:
         return hash((self.name, self.text))
 
 
-# Sentinel to autopopulate a field within model_validator
-AUTOPOPULATE_VALUE = ""  # NOTE: this is falsy by design
-
-
 class Context(BaseModel):
     """A class to hold the context of a question."""
 
@@ -660,11 +669,17 @@ class DocDetails(Doc):
     doc_id: str | None = Field(
         default=None,
         description=(
-            "Unique ID for this document. Simple ways to acquire one include"
-            " hashing the DOI or a stringifying a UUID."
+            "Unique ID for this document. A simple and robust way to acquire one is"
+            " hashing the lowercased DOI concatenated with the paper's content hash."
+ ), + ) + file_location: str | os.PathLike | None = Field( + default=None, + description=( + "Optional path to the stored paper, if stored locally" + " or in a mountable location such as a cloud bucket." ), ) - file_location: str | os.PathLike | None = None license: str | None = Field( default=None, description=( @@ -713,10 +728,10 @@ def lowercase_doi_and_populate_doc_id(cls, data: dict[str, Any]) -> dict[str, An if doi.startswith(url_prefix_to_remove): doi = doi.replace(url_prefix_to_remove, "") data["doi"] = doi.lower() - if "doc_id" not in data or not data["doc_id"]: # keep user defined doc_ids - data["doc_id"] = encode_id(doi.lower()) - elif "doc_id" not in data or not data["doc_id"]: # keep user defined doc_ids - data["doc_id"] = encode_id(uuid4()) + if not data.get("doc_id"): # keep user defined doc_ids + data["doc_id"] = compute_unique_doc_id(doi, data.get("content_hash")) + elif not data.get("doc_id"): # keep user defined doc_ids + data["doc_id"] = compute_unique_doc_id(doi, data.get("content_hash")) if "dockey" in data.get( "fields_to_overwrite_from_metadata", @@ -927,6 +942,17 @@ def populate_bibtex_key_citation(cls, data: dict[str, Any]) -> dict[str, Any]: data["citation"] = data.get("title") or CITATION_FALLBACK_DATA["title"] return data + @classmethod + def populate_content_hash(cls, data: dict[str, Any]) -> dict[str, Any]: + if ( # Check for missing or autopopulate value, but preserve `None` + data.get("content_hash", AUTOPOPULATE_VALUE) == AUTOPOPULATE_VALUE + ): + data["content_hash"] = None # Assume we don't have it + if data.get("file_location"): # Try to update it + with contextlib.suppress(FileNotFoundError): + data["content_hash"] = md5sum(data["file_location"]) + return data + @model_validator(mode="before") @classmethod def validate_all_fields(cls, data: Mapping[str, Any]) -> dict[str, Any]: @@ -946,6 +972,7 @@ def validate_all_fields(cls, data: Mapping[str, Any]) -> dict[str, Any]: data[possibly_str_field], str ): data[possibly_str_field] = ast.literal_eval(data[possibly_str_field]) + data = cls.populate_content_hash(data) data = cls.lowercase_doi_and_populate_doc_id(data) data = cls.remove_invalid_authors(data) data = cls.misc_string_cleaning(data) @@ -1046,7 +1073,7 @@ def __add__(self, other: DocDetails | int) -> DocDetails: # noqa: PLR0912 if self.publication_date and other.publication_date: PREFER_OTHER = self.publication_date <= other.publication_date - merged_data = {} + merged_data: dict[str, Any] = {} # pylint: disable-next=not-an-iterable # pylint bug: https://github.com/pylint-dev/pylint/issues/10144 for field in type(self).model_fields: self_value = getattr(self, field) @@ -1086,11 +1113,11 @@ def __add__(self, other: DocDetails | int) -> DocDetails: # noqa: PLR0912 ) else other.authors ) - merged_data[field] = best_authors or None # type: ignore[assignment] + merged_data[field] = best_authors or None elif field == "key" and self_value is not None and other_value is not None: # if we have multiple keys, we wipe them and allow regeneration - merged_data[field] = None # type: ignore[assignment] + merged_data[field] = None elif field in {"citation_count", "year", "publication_date"}: # get the latest data @@ -1106,6 +1133,12 @@ def __add__(self, other: DocDetails | int) -> DocDetails: # noqa: PLR0912 ) else: merged_data[field] = max(self_value, other_value) + elif field == "content_hash" and ( + self_value and other_value and self_value != other_value + ): + # If hashes are both present but differ, + # we don't know which to pick, so just discard the 
+                merged_data[field] = None
             else:
                 # Prefer non-null values, default preference for 'other' object.
@@ -1120,10 +1153,13 @@ def __add__(self, other: DocDetails | int) -> DocDetails:  # noqa: PLR0912
                     else self_value
                 )
 
-        # Recalculate doc_id if doi has changed
-        if merged_data["doi"] != self.doi:
-            merged_data["doc_id"] = (
-                encode_id(merged_data["doi"].lower()) if merged_data["doi"] else None  # type: ignore[attr-defined,assignment]
+        if (
+            merged_data["doi"] != self.doi
+            or merged_data["content_hash"] != self.content_hash
+        ):
+            # Recalculate doc_id if doi or content hash has changed
+            merged_data["doc_id"] = compute_unique_doc_id(
+                merged_data["doi"], merged_data.get("content_hash")
             )
 
         # Create and return new DocDetails instance
diff --git a/src/paperqa/utils.py b/src/paperqa/utils.py
index 25397828..951ebe08 100644
--- a/src/paperqa/utils.py
+++ b/src/paperqa/utils.py
@@ -17,7 +17,7 @@
 from http import HTTPStatus
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, BinaryIO, ClassVar, TypeVar
-from uuid import UUID
+from uuid import UUID, uuid4
 
 import aiohttp
 import httpx
@@ -104,7 +104,7 @@ def strings_similarity(s1: str, s2: str, case_insensitive: bool = True) -> float
 
 def hexdigest(data: str | bytes) -> str:
     if isinstance(data, str):
-        return hashlib.md5(data.encode("utf-8")).hexdigest()  # noqa: S324
+        data = data.encode("utf-8")
     return hashlib.md5(data).hexdigest()  # noqa: S324
 
 
@@ -217,6 +217,14 @@ def encode_id(value: str | bytes | UUID, maxsize: int | None = 16) -> str:
     return hashlib.md5(value).hexdigest()[:maxsize]  # noqa: S324
 
 
+def compute_unique_doc_id(doi: str | None, content_hash: str | None) -> str:
+    if doi:
+        value_to_encode: str = doi.lower() + (content_hash or "")
+    else:
+        value_to_encode = content_hash or str(uuid4())
+    return encode_id(value_to_encode)
+
+
 def get_year(ts: datetime | None = None) -> str:
     """Get the year from the input datetime, otherwise using the current datetime."""
     if ts is None:
diff --git a/tests/test_agents.py b/tests/test_agents.py
index 430834e1..7ed1893d 100644
--- a/tests/test_agents.py
+++ b/tests/test_agents.py
@@ -60,7 +60,7 @@
 from paperqa.prompts import CANNOT_ANSWER_PHRASE, CONTEXT_INNER_PROMPT_NOT_DETAILED
 from paperqa.settings import AgentSettings, IndexSettings, Settings
 from paperqa.types import Context, Doc, PQASession, Text
-from paperqa.utils import encode_id, extract_thought, get_year, md5sum
+from paperqa.utils import compute_unique_doc_id, extract_thought, get_year, md5sum
 
 
 @pytest.mark.asyncio
@@ -117,11 +117,16 @@ async def test_get_directory_index(
     results = await index.query(query="what is a gravity hill?", min_score=5)
     assert results
     first_result = results[0]
+    assert len(first_result.docs) == 1, "Expected one result (gravity_hill.md)"
     target_doc_path = (paper_dir / "gravity_hill.md").absolute()
     expected_ids = {
-        md5sum(target_doc_path),  # What we actually expect
-        encode_id(
-            "10.2307/j.ctt5vkfh7.11"  # Crossref may match this Gravity Hill poem, lol
+        compute_unique_doc_id(
+            None,
+            md5sum(target_doc_path),  # What we actually expect
+        ),
+        compute_unique_doc_id(
+            "10.2307/j.ctt5vkfh7.11",  # Crossref may match this Gravity Hill poem, lol
+            next(iter(first_result.docs.values())).content_hash,
         ),
     }
     for expected_id in expected_ids:
@@ -132,9 +137,9 @@ async def test_get_directory_index(
             f"Failed to match an ID in {expected_ids}, got citations"
             f" {[d.formatted_citation for d in first_result.docs.values()]}."
         )
-    assert all(
-        x in first_result.docs[expected_id].formatted_citation
-        for x in ("Wikipedia", "Gravity")
+    assert (
+        "gravity hill"
+        in first_result.docs[expected_id].formatted_citation.lower()
     )
 
     # Check getting the same index name will not reprocess files
diff --git a/tests/test_paperqa.py b/tests/test_paperqa.py
index 8cb21e2d..7f729ea9 100644
--- a/tests/test_paperqa.py
+++ b/tests/test_paperqa.py
@@ -17,6 +17,7 @@
 from unittest.mock import MagicMock, call
 from uuid import UUID
 
+import anyio
 import httpx
 import litellm
 import numpy as np
@@ -752,9 +753,11 @@ async def test_get_reasoning(docs_fixture: Docs, llm: str, llm_settings: dict) -
 
 
 @pytest.mark.asyncio
-async def test_duplicate(stub_data_dir: Path) -> None:
+async def test_duplicate(stub_data_dir: Path, tmp_path) -> None:
     """Check Docs doesn't store duplicates, while checking nonduplicate docs are stored."""
     docs = Docs()
+
+    # First, check adding a straight-up duplicate doc
     assert await docs.aadd(
         stub_data_dir / "bates.txt",
         citation="WikiMedia Foundation, 2023, Accessed now",
@@ -767,16 +770,44 @@ async def test_duplicate(stub_data_dir: Path) -> None:
             dockey="test1",
         )
         is None
-    )
+    ), "Expected duplicate add to indicate no new doc was added"
     assert len(docs.docs) == 1, "Should have added only one document"
+
+    # Next, check adding a different doc works, and also check citation inference
+    common_doi = "10.1234/flag"
     assert await docs.aadd(
-        stub_data_dir / "flag_day.html",
-        citation="WikiMedia Foundation, 2023, Accessed now",
-        dockey="test2",
+        stub_data_dir / "flag_day.html", dockey="flag_day", doi=common_doi
     )
     assert (
-        len(set(docs.docs.values())) == 2
+        len(set(docs.docs.keys())) == 2
     ), "Unique documents should be hashed as unique"
+    flag_day = docs.docs["flag_day"]
+    assert isinstance(flag_day, DocDetails)
+    assert flag_day.doi == common_doi
+    assert all(
+        x in flag_day.citation.lower() for x in ("wikipedia", "flag")
+    ), "Expected citation to be inferred"
+    assert flag_day.content_hash
+
+    # Now, check that adding a different file with the same metadata
+    # (emulating main text vs. supplemental information)
+    # is seen as a different doc
+    flag_day_content = await anyio.Path(stub_data_dir / "flag_day.html").read_bytes()
+    assert len(flag_day_content) >= 1000, "Expected long file to test truncation"
+    await anyio.Path(tmp_path / "flag_day.html").write_bytes(flag_day_content[:-100])
+    assert await docs.aadd(
+        tmp_path / "flag_day.html", dockey="flag_day_shorter", doi=common_doi
+    )
+    assert len(set(docs.docs.keys())) == 3, "Expected a third document to be added"
+    shorter_flag_day = docs.docs["flag_day_shorter"]
+    assert isinstance(shorter_flag_day, DocDetails)
+    assert shorter_flag_day.doi == common_doi
+    assert all(
+        x in shorter_flag_day.citation.lower() for x in ("wikipedia", "flag")
+    ), "Expected citation to be inferred"
+    assert shorter_flag_day.content_hash
+    assert flag_day.content_hash != shorter_flag_day.content_hash
+    assert flag_day.doc_id != shorter_flag_day.doc_id
 
 
 @pytest.mark.asyncio
@@ -1058,6 +1089,7 @@ async def test_pdf_reader_w_no_match_doc_details(stub_data_dir: Path) -> None:
         "Wellawatte et al, XAI Review, 2023",
     )
     (doc_details,) = docs.docs.values()
+    assert doc_details.content_hash == "41f786fcc56d27ff0c1507153fae3774"
     assert doc_details.docname == docname, "Added name should match between details"
     # doc will be a DocDetails object, but nothing can be found
     # thus, we retain the prior citation data
@@ -1147,10 +1179,12 @@ async def test_pdf_reader_match_doc_details(stub_data_dir: Path) -> None:
         fields=["author", "journal", "citation_count"],
     )
     (doc_details,) = docs.docs.values()
+    assert doc_details.content_hash == "41f786fcc56d27ff0c1507153fae3774"
     assert doc_details.docname == docname, "Added name should match between details"
     # Crossref is non-deterministic in its ordering for results
+    # (it can give DOI '10.1021/acs.jctc.2c01235' or DOI '10.26434/chemrxiv-2022-qfv02')
     # thus we need to capture both possible dockeys
-    assert doc_details.dockey in {"d7763485f06aabde", "5300ef1d5fb960d7"}
+    assert doc_details.dockey in {"8ce7ddba9c9dcae6", "a353fa2478475c9c"}
     assert isinstance(doc_details, DocDetails)
     # note year is unknown because citation string is only parsed for authors/title/doi
     # AND we do not request it back from the metadata sources