Skip to content

Commit e5b6447

Browse files
Ensure fields are lists before combining in DocDetails class (#801)
1 parent 8b41c1a commit e5b6447

File tree

3 files changed

+81
-0
lines changed

3 files changed

+81
-0
lines changed

.mailmap

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,5 @@ Odhran O'Donoghue <[email protected]> <[email protected]
88
99
1010
Mayk Caldas <[email protected]> maykcaldas <[email protected]>
11+
12+
Harry Vu <[email protected]> harryvu-futurehouse

paperqa/types.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -686,6 +686,16 @@ def __add__(self, other: DocDetails | int) -> DocDetails: # noqa: PLR0912
686686
merged_data[field] = {**self.other, **other.other}
687687
# handle the bibtex / sources as special fields
688688
for field_to_combine in ("bibtex_source", "client_source"):
689+
# Ensure the fields are lists before combining
690+
if self.other.get(field_to_combine) and not isinstance(
691+
self.other[field_to_combine], list
692+
):
693+
self.other[field_to_combine] = [self.other[field_to_combine]]
694+
if other.other.get(field_to_combine) and not isinstance(
695+
other.other[field_to_combine], list
696+
):
697+
other.other[field_to_combine] = [other.other[field_to_combine]]
698+
689699
if self.other.get(field_to_combine) and other.other.get(
690700
field_to_combine
691701
):

tests/test_paperqa.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import textwrap
77
from collections.abc import AsyncIterable, Sequence
88
from copy import deepcopy
9+
from datetime import datetime, timedelta
910
from io import BytesIO
1011
from pathlib import Path
1112
from typing import cast
@@ -1239,6 +1240,74 @@ def test_dois_resolve_to_correct_journals(doi_journals):
12391240
assert details.journal == doi_journals["journal"]
12401241

12411242

1243+
def test_docdetails_merge_with_non_list_fields() -> None:
1244+
"""Check republication where the source metadata has different shapes."""
1245+
initial_date = datetime(2023, 1, 1)
1246+
doc1 = DocDetails(
1247+
citation="Citation 1",
1248+
publication_date=initial_date,
1249+
docname="Document 1",
1250+
dockey="key1",
1251+
# NOTE: doc1 has non-list bibtex_source and list client_source
1252+
other={"bibtex_source": "source1", "client_source": ["client1"]},
1253+
)
1254+
1255+
later_publication_date = initial_date + timedelta(weeks=13)
1256+
doc2 = DocDetails(
1257+
citation=doc1.citation,
1258+
publication_date=later_publication_date,
1259+
docname=doc1.docname,
1260+
dockey=doc1.dockey,
1261+
# NOTE: doc2 has list bibtex_source and non-list client_source
1262+
other={"bibtex_source": ["source2"], "client_source": "client2"},
1263+
)
1264+
1265+
# Merge the two DocDetails instances
1266+
merged_doc = doc1 + doc2
1267+
1268+
assert {"source1", "source2"}.issubset(
1269+
merged_doc.other["bibtex_source"]
1270+
), "Expected merge to keep both bibtex sources"
1271+
assert {"client1", "client2"}.issubset(
1272+
merged_doc.other["client_source"]
1273+
), "Expected merge to keep both client sources"
1274+
assert isinstance(merged_doc, DocDetails), "Merged doc should also be DocDetails"
1275+
1276+
1277+
def test_docdetails_merge_with_list_fields() -> None:
1278+
"""Check republication where the source metadata is the same shape."""
1279+
initial_date = datetime(2023, 1, 1)
1280+
doc1 = DocDetails(
1281+
citation="Citation 1",
1282+
publication_date=initial_date,
1283+
docname="Document 1",
1284+
dockey="key1",
1285+
# NOTE: doc1 has list bibtex_source and list client_source
1286+
other={"bibtex_source": ["source1"], "client_source": ["client1"]},
1287+
)
1288+
1289+
later_publication_date = initial_date + timedelta(weeks=13)
1290+
doc2 = DocDetails(
1291+
citation=doc1.citation,
1292+
publication_date=later_publication_date,
1293+
docname=doc1.docname,
1294+
dockey=doc1.dockey,
1295+
# NOTE: doc2 has list bibtex_source and list client_source
1296+
other={"bibtex_source": ["source2"], "client_source": ["client2"]},
1297+
)
1298+
1299+
# Merge the two DocDetails instances
1300+
merged_doc = doc1 + doc2
1301+
1302+
assert {"source1", "source2"}.issubset(
1303+
merged_doc.other["bibtex_source"]
1304+
), "Expected merge to keep both bibtex sources"
1305+
assert {"client1", "client2"}.issubset(
1306+
merged_doc.other["client_source"]
1307+
), "Expected merge to keep both client sources"
1308+
assert isinstance(merged_doc, DocDetails), "Merged doc should also be DocDetails"
1309+
1310+
12421311
@pytest.mark.vcr
12431312
@pytest.mark.parametrize("use_partition", [True, False])
12441313
@pytest.mark.asyncio

0 commit comments

Comments
 (0)