Skip to content

Commit 641f583

Browse files
authored
Fixed type problems from llmclient (#770)
Co-authored-by: James Braza <[email protected]> - Corrected typing in `DocDetails` - Corrected the typing of `Text.doc` from `Doc` to `Doc | DocDetails` and updated how `Text.doc` is created - Defined `AUTOPOPULATE_VALUE` to declare default values for `DocDetails` entries - Changed input data typing on `@model_validator` from `dict` to `Mapping` to ensure the immutability of the input. - Tests for serialization/deserialization and immutability of input data were included.
1 parent 7bb570c commit 641f583

File tree

14 files changed

+108
-33
lines changed

14 files changed

+108
-33
lines changed

paperqa/agents/helpers.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -92,11 +92,13 @@ def table_formatter(
9292
table.add_column("Title", style="cyan")
9393
table.add_column("File", style="magenta")
9494
for obj, filename in objects:
95-
try:
96-
display_name = cast(DocDetails, cast(Docs, obj).texts[0].doc).title
97-
except AttributeError:
98-
display_name = cast(Docs, obj).texts[0].doc.formatted_citation
99-
table.add_row(cast(str, display_name)[:max_chars_per_column], filename)
95+
docs = cast(Docs, obj) # Assume homogeneous objects
96+
doc = docs.texts[0].doc
97+
if isinstance(doc, DocDetails) and doc.title:
98+
display_name: str = doc.title # Prefer title if available
99+
else:
100+
display_name = doc.formatted_citation
101+
table.add_row(display_name[:max_chars_per_column], filename)
100102
return table
101103
raise NotImplementedError(
102104
f"Object type {type(example_object)} can not be converted to table."

paperqa/agents/search.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import warnings
1111
import zlib
1212
from collections.abc import Callable, Collection, Sequence
13+
from datetime import datetime
1314
from enum import StrEnum, auto
1415
from typing import TYPE_CHECKING, Any, ClassVar
1516
from uuid import UUID
@@ -70,6 +71,8 @@ def default(self, o):
7071
return list(o)
7172
if isinstance(o, os.PathLike):
7273
return str(o)
74+
if isinstance(o, datetime):
75+
return o.isoformat()
7376
return json.JSONEncoder.default(self, o)
7477

7578

paperqa/clients/crossref.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ async def parse_crossref_to_doc_details(
197197
elif len(date_parts) == 1:
198198
publication_date = datetime(date_parts[0], 1, 1)
199199

200-
doc_details = DocDetails( # type: ignore[call-arg]
200+
doc_details = DocDetails(
201201
key=None if not bibtex else bibtex.split("{")[1].split(",")[0],
202202
bibtex_type=CROSSREF_CONTENT_TYPE_TO_BIBTEX_MAPPING.get(
203203
message.get("type", "other"), "misc"

paperqa/clients/journal_quality.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ async def _process(
4444
# docname can be blank since the validation will add it
4545
# remember, if both have docnames (i.e. key) they are
4646
# wiped and re-generated with resultant data
47-
return doc_details + DocDetails( # type: ignore[call-arg]
47+
return doc_details + DocDetails(
4848
source_quality=max(
4949
[
5050
self.data.get(query.journal.casefold(), DocDetails.UNDEFINED_JOURNAL_QUALITY), # type: ignore[union-attr]

paperqa/clients/openalex.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ def parse_openalex_to_doc_details(message: dict[str, Any]) -> DocDetails:
178178

179179
bibtex_type = BIBTEX_MAPPING.get(message.get("type") or "other", "misc")
180180

181-
return DocDetails( # type: ignore[call-arg]
181+
return DocDetails(
182182
key=None,
183183
bibtex_type=bibtex_type,
184184
bibtex=None,

paperqa/clients/retractions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ async def _process(self, query: DOIQuery, doc_details: DocDetails) -> DocDetails
7171
if not self.doi_set:
7272
await self.load_data()
7373

74-
return doc_details + DocDetails(is_retracted=query.doi in self.doi_set) # type: ignore[call-arg]
74+
return doc_details + DocDetails(is_retracted=query.doi in self.doi_set)
7575

7676
def query_creator(self, doc_details: DocDetails, **kwargs) -> DOIQuery | None:
7777
try:

paperqa/clients/semantic_scholar.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ async def parse_s2_to_doc_details(
182182

183183
journal_data = paper_data.get("journal") or {}
184184

185-
doc_details = DocDetails( # type: ignore[call-arg]
185+
doc_details = DocDetails(
186186
key=None if not bibtex else bibtex.split("{")[1].split(",")[0],
187187
bibtex_type="article", # s2 should be basically all articles
188188
bibtex=bibtex,

paperqa/clients/unpaywall.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ def _create_doc_details(self, data: UnpaywallResponse) -> DocDetails:
166166
if data.best_oa_location:
167167
pdf_url = data.best_oa_location.url_for_pdf
168168
license = data.best_oa_location.license # noqa: A001
169-
return DocDetails( # type: ignore[call-arg]
169+
return DocDetails(
170170
authors=[
171171
f"{author.given} {author.family}" for author in (data.z_authors or [])
172172
],

paperqa/core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ async def map_fxn_summary(
108108
text=Text(
109109
text=text.text,
110110
name=text.name,
111-
doc=text.doc.__class__(**text.doc.model_dump(exclude={"embedding"})),
111+
doc=text.doc.model_dump(exclude={"embedding"}),
112112
),
113113
score=score, # pylint: disable=possibly-used-before-assignment
114114
**extras,

paperqa/docs.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -470,10 +470,12 @@ async def aadd_texts(
470470
# 3. Update self
471471
# NOTE: we defer adding texts to the texts index to retrieval time
472472
# (e.g. `self.texts_index.add_texts_and_embeddings(texts)`)
473-
self.docs[doc.dockey] = doc
474-
self.texts += texts
475-
self.docnames.add(doc.docname)
476-
return True
473+
if doc.docname and doc.dockey:
474+
self.docs[doc.dockey] = doc
475+
self.texts += texts
476+
self.docnames.add(doc.docname)
477+
return True
478+
return False
477479

478480
def delete(
479481
self,
@@ -489,8 +491,9 @@ def delete(
489491
doc = next((doc for doc in self.docs.values() if doc.docname == name), None)
490492
if doc is None:
491493
return
492-
self.docnames.remove(doc.docname)
493-
dockey = doc.dockey
494+
if doc.docname and doc.dockey:
495+
self.docnames.remove(doc.docname)
496+
dockey = doc.dockey
494497
del self.docs[dockey]
495498
self.deleted_dockeys.add(dockey)
496499
self.texts = list(filter(lambda x: x.doc.dockey != dockey, self.texts))

0 commit comments

Comments (0)