Skip to content

Commit c07f2a8

Browse files
JoaquinPolonuerpre-commit-ci-lite[bot]jamesbraza
authored
Move doc details line (#856)
Co-authored-by: pre-commit-ci-lite[bot] <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com> Co-authored-by: James Braza <[email protected]>
1 parent 4819b50 commit c07f2a8

File tree

2 files changed

+8
-7
lines changed

2 files changed

+8
-7
lines changed

paperqa/agents/search.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -441,34 +441,35 @@ async def query(
441441
def fetch_kwargs_from_manifest(
442442
file_location: str, manifest: dict[str, Any], manifest_fallback_location: str
443443
) -> dict[str, Any]:
444-
manifest_entry: DocDetails | None = manifest.get(file_location) or manifest.get(
444+
manifest_entry: dict[str, Any] | None = manifest.get(file_location) or manifest.get(
445445
manifest_fallback_location
446446
)
447447
if manifest_entry:
448-
return manifest_entry.model_dump()
448+
return DocDetails(**manifest_entry).model_dump()
449449
return {}
450450

451451

452452
async def maybe_get_manifest(
453453
filename: anyio.Path | None = None,
454-
) -> dict[str, DocDetails]:
454+
) -> dict[str, dict[str, Any]]:
455455
if not filename:
456456
return {}
457457
if filename.suffix == ".csv":
458458
try:
459459
async with await anyio.open_file(filename, mode="r") as file:
460460
content = await file.read()
461-
records = [DocDetails(**r) for r in csv.DictReader(content.splitlines())]
462461
file_loc_to_records = {
463-
str(r.file_location): r for r in records if r.file_location
462+
str(r.get("file_location")): r
463+
for r in csv.DictReader(content.splitlines())
464+
if r.get("file_location")
464465
}
465466
if not file_loc_to_records:
466467
raise ValueError( # noqa: TRY301
467468
"No mapping of file location to details extracted from manifest"
468469
f" file {filename}."
469470
)
470471
logger.debug(
471-
f"Found manifest file at {filename}, read {len(records)} records"
472+
f"Found manifest file at {filename}, read {len(file_loc_to_records)} records"
472473
f" from it, which maps to {len(file_loc_to_records)} locations."
473474
)
474475
except FileNotFoundError:

paperqa/agents/task.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def __init__(
9696
self._rewards = rewards
9797

9898
async def validate_sources(
99-
self, manifest_or_index: dict[str, DocDetails] | SearchIndex | None = None
99+
self, manifest_or_index: dict[str, dict[str, Any]] | SearchIndex | None = None
100100
) -> None:
101101
"""Validate the sources can be found in the input manifest or index."""
102102
if not self.sources:

0 commit comments

Comments
 (0)