
Commit f8d364a

Support parallel indexing (#849)

Authored by JoaquinPolonuer
Co-authored-by: pre-commit-ci-lite[bot] <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com>
Co-authored-by: James Braza <[email protected]>

1 parent 8fb3691 commit f8d364a

File tree

7 files changed: +131 -43 lines changed


paperqa/agents/search.py

Lines changed: 46 additions & 18 deletions
@@ -9,7 +9,8 @@
 import pickle
 import warnings
 import zlib
-from collections.abc import Callable, Collection, Sequence
+from collections import Counter
+from collections.abc import AsyncIterator, Callable, Collection, Sequence
 from datetime import datetime
 from enum import StrEnum, auto
 from typing import TYPE_CHECKING, Any, ClassVar
@@ -152,6 +153,7 @@ def __init__(
         self._schema: Schema | None = None
         self._index: Index | None = None
         self._searcher: Searcher | None = None
+        self._writer: IndexWriter | None = None
         self._index_files: dict[str, str] = {}
         self.changed = False
         self.storage = storage
@@ -237,6 +239,15 @@ async def searcher(self) -> Searcher:
             self._searcher = index.searcher()
         return self._searcher

+    @contextlib.asynccontextmanager
+    async def writer(self, reset: bool = False) -> AsyncIterator[IndexWriter]:
+        if not self._writer:
+            index = await self.index
+            self._writer = index.writer()
+        yield self._writer
+        if reset:
+            self._writer = None
+
     @property
     async def count(self) -> int:
         return (await self.searcher).num_docs
@@ -295,10 +306,9 @@ async def add_document(
         async def _add_document() -> None:
             if not await self.filecheck(index_doc["file_location"], index_doc["body"]):
                 try:
-                    writer: IndexWriter = (await self.index).writer()
-                    writer.add_document(Document.from_dict(index_doc))  # type: ignore[call-arg]
-                    writer.commit()
-                    writer.wait_merging_threads()
+                    async with self.writer() as writer:
+                        # Let caller handle commit to allow for batching
+                        writer.add_document(Document.from_dict(index_doc))  # type: ignore[call-arg]

                     filehash = self.filehash(index_doc["body"])
                     (await self.index_files)[index_doc["file_location"]] = filehash
@@ -326,19 +336,17 @@ async def _add_document() -> None:
             )
             raise

-    @staticmethod
     @retry(
         stop=stop_after_attempt(1000),
         wait=wait_random_exponential(multiplier=0.25, max=60),
         retry=retry_if_exception_type(AsyncRetryError),
         reraise=True,
     )
-    def delete_document(index: Index, file_location: str) -> None:
+    async def delete_document(self, file_location: str) -> None:
         try:
-            writer: IndexWriter = index.writer()
-            writer.delete_documents("file_location", file_location)
-            writer.commit()
-            writer.wait_merging_threads()
+            async with self.writer() as writer:
+                writer.delete_documents("file_location", file_location)
+            await self.save_index()
         except ValueError as e:
             if "Failed to acquire Lockfile: LockBusy." in str(e):
                 raise AsyncRetryError("Failed to acquire lock") from e
@@ -347,7 +355,7 @@ def delete_document(index: Index, file_location: str) -> None:
     async def remove_from_index(self, file_location: str) -> None:
         index_files = await self.index_files
         if index_files.get(file_location):
-            self.delete_document(await self.index, file_location)
+            await self.delete_document(file_location)
             filehash = index_files.pop(file_location)
             docs_index_dir = await self.docs_index_directory
             # TODO: since the directory is part of the filehash these
@@ -359,6 +367,9 @@ async def remove_from_index(self, file_location: str) -> None:
         self.changed = True

     async def save_index(self) -> None:
+        async with self.writer(reset=True) as writer:
+            writer.commit()
+            writer.wait_merging_threads()
         file_index_path = await self.file_index_filename
         async with await anyio.open_file(file_index_path, "wb") as f:
             await f.write(zlib.compress(pickle.dumps(await self.index_files)))
@@ -461,8 +472,10 @@ async def process_file(
     manifest: dict[str, Any],
     semaphore: anyio.Semaphore,
     settings: Settings,
+    processed_counter: Counter[str],
     progress_bar_update: Callable[[], Any] | None = None,
 ) -> None:
+
     abs_file_path = (
         pathlib.Path(settings.agent.index.paper_directory).absolute() / rel_file_path
     )
@@ -496,16 +509,23 @@ async def process_file(
                 fields=["title", "author", "journal", "year"],
                 settings=settings,
             )
-        except (ValueError, ImpossibleParsingError):
+        except Exception as e:
+            # We handle any exception here because we want to save_index so we
+            # 1. can resume the build without rebuilding this file if a separate
+            #    process_file invocation leads to a segfault or crash.
+            # 2. don't have deadlock issues after.
             logger.exception(
                 f"Error parsing {file_location}, skipping index for this file."
             )
             await search_index.mark_failed_document(file_location)
-            # Save so we can resume the build without rebuilding this file if a
-            # separate process_file invocation leads to a segfault or crash
             await search_index.save_index()
             if progress_bar_update:
                 progress_bar_update()
+
+            if not isinstance(e, ValueError | ImpossibleParsingError):
+                # ImpossibleParsingError: parsing failure, don't retry
+                # ValueError: TODOC
+                raise
             return

         this_doc = next(iter(tmp_docs.docs.values()))
@@ -525,9 +545,15 @@ async def process_file(
                 },
                 document=tmp_docs,
             )
-            # Save so we can resume the build without rebuilding this file if a
-            # separate process_file invocation leads to a segfault or crash
-            await search_index.save_index()
+
+            processed_counter["batched_save_counter"] += 1
+            if (
+                processed_counter["batched_save_counter"]
+                == settings.agent.index.batch_size
+            ):
+                await search_index.save_index()
+                processed_counter["batched_save_counter"] = 0
+
             logger.info(f"Complete ({title}).")

         # Update progress bar for either a new or previously indexed file
@@ -674,6 +700,7 @@ async def get_directory_index(  # noqa: PLR0912
     )
     with progress_bar:
         async with anyio.create_task_group() as tg:
+            processed_counter: Counter[str] = Counter()
             for rel_file_path in valid_papers_rel_file_paths:
                 if index_settings.sync_with_paper_directory:
                     tg.start_soon(
@@ -683,6 +710,7 @@ async def get_directory_index(  # noqa: PLR0912
                         manifest,
                         semaphore,
                         _settings,
+                        processed_counter,
                         progress_bar_update_fn,
                     )
                 else:
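
The core change above is that SearchIndex now caches a single IndexWriter, reuses it across add/delete calls, and defers commit() to save_index(), so many documents are committed in one batch. Below is a minimal, self-contained sketch of that pattern; FakeWriter and BatchedIndex are hypothetical stand-ins (not paperqa or tantivy classes) used only to make the example runnable.

import asyncio
import contextlib
from collections.abc import AsyncIterator


class FakeWriter:
    """Hypothetical stand-in for tantivy's IndexWriter."""

    def __init__(self) -> None:
        self.pending: list[dict] = []
        self.committed: list[dict] = []

    def add_document(self, doc: dict) -> None:
        self.pending.append(doc)

    def commit(self) -> None:
        self.committed.extend(self.pending)
        self.pending.clear()


class BatchedIndex:
    def __init__(self) -> None:
        self._writer: FakeWriter | None = None

    @contextlib.asynccontextmanager
    async def writer(self, reset: bool = False) -> AsyncIterator[FakeWriter]:
        # Lazily create one writer and reuse it across add_document() calls,
        # so one commit can cover a whole batch instead of one file at a time.
        if not self._writer:
            self._writer = FakeWriter()
        yield self._writer
        if reset:
            # Dropping the cached writer forces the next batch to start fresh.
            self._writer = None

    async def add_document(self, doc: dict) -> None:
        async with self.writer() as writer:
            writer.add_document(doc)  # commit is deferred to save_index()

    async def save_index(self) -> None:
        async with self.writer(reset=True) as writer:
            writer.commit()


async def main() -> None:
    index = BatchedIndex()
    for i in range(3):
        await index.add_document({"file_location": f"paper_{i}.pdf"})
    await index.save_index()  # one commit covers all three adds
    print(len(index._writer.committed if index._writer else []))  # noqa: SLF001


asyncio.run(main())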

paperqa/docs.py

Lines changed: 2 additions & 2 deletions
@@ -285,7 +285,7 @@ async def aadd(  # noqa: PLR0912
             llm_model = all_settings.get_llm()
         if citation is None:
             # Peek first chunk
-            texts = read_doc(
+            texts = await read_doc(
                 path,
                 Doc(docname="", citation="", dockey=dockey),  # Fake doc
                 chunk_chars=parse_config.chunk_size,
@@ -370,7 +370,7 @@
                 doc, **(query_kwargs | kwargs)
             )

-        texts = read_doc(
+        texts = await read_doc(
             path,
             doc,
             chunk_chars=parse_config.chunk_size,
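
Because read_doc is now a coroutine, every caller has to await it, as Docs.aadd does above. Here is a minimal sketch of the new calling convention from outside an event loop; the file name, citation text, and chunking values are illustrative only, not paperqa defaults.

import asyncio

from paperqa import Doc
from paperqa.readers import read_doc


async def main() -> None:
    # Assumes a local example.pdf; the Doc fields here are placeholder metadata.
    texts = await read_doc(
        "example.pdf",
        Doc(docname="example", citation="Example et al., 2024", dockey="example"),
        chunk_chars=3000,
        overlap=100,
    )
    print(f"Parsed {len(texts)} chunks")


asyncio.run(main())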

paperqa/readers.py

Lines changed: 29 additions & 14 deletions
@@ -1,5 +1,6 @@
 from __future__ import annotations

+import asyncio
 import os
 from math import ceil
 from pathlib import Path
@@ -252,7 +253,7 @@ def chunk_code_text(


 @overload
-def read_doc(
+async def read_doc(
     path: str | os.PathLike,
     doc: Doc,
     parsed_text_only: Literal[False],
@@ -264,7 +265,7 @@ def read_doc(


 @overload
-def read_doc(
+async def read_doc(
     path: str | os.PathLike,
     doc: Doc,
     parsed_text_only: Literal[False] = ...,
@@ -276,7 +277,7 @@ def read_doc(


 @overload
-def read_doc(
+async def read_doc(
     path: str | os.PathLike,
     doc: Doc,
     parsed_text_only: Literal[True],
@@ -288,7 +289,7 @@ def read_doc(


 @overload
-def read_doc(
+async def read_doc(
     path: str | os.PathLike,
     doc: Doc,
     parsed_text_only: Literal[False],
@@ -299,7 +300,7 @@ def read_doc(
 ) -> tuple[list[Text], ParsedMetadata]: ...


-def read_doc(
+async def read_doc(
     path: str | os.PathLike,
     doc: Doc,
     parsed_text_only: bool = False,
@@ -311,7 +312,6 @@ def read_doc(
     """Parse a document and split into chunks.

     Optionally can include just the parsing as well as metadata about the parsing/chunking
-
     Args:
         path: local document path
         doc: object with document metadata
@@ -322,18 +322,29 @@ def read_doc(
         page_size_limit: optional limit on the number of characters per page
     """
     str_path = str(path)
-    parsed_text = None

     # start with parsing -- users may want to store this separately
     if str_path.endswith(".pdf"):
-        parsed_text = parse_pdf_to_pages(path, page_size_limit=page_size_limit)
+        # TODO: Make parse_pdf_to_pages async
+        parsed_text = await asyncio.to_thread(
+            parse_pdf_to_pages, path, page_size_limit=page_size_limit
+        )
     elif str_path.endswith(".txt"):
-        parsed_text = parse_text(path, page_size_limit=page_size_limit)
+        # TODO: Make parse_text async
+        parsed_text = await asyncio.to_thread(
+            parse_text, path, page_size_limit=page_size_limit
+        )
     elif str_path.endswith(".html"):
-        parsed_text = parse_text(path, html=True, page_size_limit=page_size_limit)
+        parsed_text = await asyncio.to_thread(
+            parse_text, path, html=True, page_size_limit=page_size_limit
+        )
     else:
-        parsed_text = parse_text(
-            path, split_lines=True, use_tiktoken=False, page_size_limit=page_size_limit
+        parsed_text = await asyncio.to_thread(
+            parse_text,
+            path,
+            split_lines=True,
+            use_tiktoken=False,
+            page_size_limit=page_size_limit,
         )

     if parsed_text_only:
@@ -352,7 +363,9 @@ def read_doc(
             parsed_text, doc, chunk_chars=chunk_chars, overlap=overlap
         )
         chunk_metadata = ChunkMetadata(
-            chunk_chars=chunk_chars, overlap=overlap, chunk_type="overlap_pdf_by_page"
+            chunk_chars=chunk_chars,
+            overlap=overlap,
+            chunk_type="overlap_pdf_by_page",
        )
     elif str_path.endswith((".txt", ".html")):
         chunked_text = chunk_text(
@@ -366,7 +379,9 @@ def read_doc(
             parsed_text, doc, chunk_chars=chunk_chars, overlap=overlap
         )
         chunk_metadata = ChunkMetadata(
-            chunk_chars=chunk_chars, overlap=overlap, chunk_type="overlap_code_by_line"
+            chunk_chars=chunk_chars,
+            overlap=overlap,
+            chunk_type="overlap_code_by_line",
         )

     if include_metadata:
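
read_doc itself becomes async, but the underlying parsers (parse_pdf_to_pages, parse_text) stay blocking, so the diff pushes them onto worker threads with asyncio.to_thread to keep the event loop free during parallel indexing. A minimal sketch of that offloading pattern, using a hypothetical parse_blocking function rather than paperqa's real parsers:

import asyncio
import time


def parse_blocking(path: str) -> str:
    # Hypothetical stand-in for a slow, synchronous parser such as parse_pdf_to_pages.
    time.sleep(0.5)  # simulate blocking parse work
    return f"contents of {path}"


async def parse_many(paths: list[str]) -> list[str]:
    # Each blocking call runs in a worker thread, so the parses overlap
    # instead of serializing on the event loop.
    return await asyncio.gather(
        *(asyncio.to_thread(parse_blocking, p) for p in paths)
    )


async def main() -> None:
    start = time.perf_counter()
    results = await parse_many(["a.pdf", "b.pdf", "c.pdf"])
    # Roughly 0.5 s total rather than 1.5 s, since the three parses run concurrently.
    print(results, f"{time.perf_counter() - start:.2f}s")


asyncio.run(main())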

paperqa/settings.py

Lines changed: 5 additions & 0 deletions
@@ -408,6 +408,11 @@ class IndexSettings(BaseModel):
         default=5,  # low default for folks without S2/Crossref keys
         description="Number of concurrent filesystem reads for indexing",
     )
+    batch_size: int = Field(
+        default=1,
+        ge=1,
+        description="Number of files to process before committing to the index.",
+    )
     sync_with_paper_directory: bool = Field(
         default=True,
         description=(
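
The new IndexSettings.batch_size works together with the existing concurrency setting: files are read and parsed concurrently, and the index is committed once per batch_size processed files rather than after every file. A minimal sketch of opting into this from user code; the paper directory path and the value 3 are illustrative, not recommendations.

import asyncio

from paperqa import Settings
from paperqa.agents.search import get_directory_index


async def main() -> None:
    settings = Settings()
    settings.agent.index.paper_directory = "my_papers"  # illustrative path
    settings.agent.index.concurrency = 3  # parallel filesystem reads
    settings.agent.index.batch_size = 3  # commit once per 3 processed files
    index = await get_directory_index(settings=settings)
    print(len(await index.index_files), "files indexed")


asyncio.run(main())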

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -232,6 +232,7 @@ disable = [
   "too-many-lines", # Don't care to enforce this
   "too-many-locals", # Rely on ruff PLR0914 for this
   "too-many-positional-arguments", # Don't care to enforce this
+  "too-many-public-methods", # Rely on ruff PLR0904 for this
   "too-many-return-statements", # Rely on ruff PLR0911 for this
   "too-many-statements", # Rely on ruff PLR0915 for this
   "undefined-loop-variable", # Don't care to enforce this

tests/test_agents.py

Lines changed: 40 additions & 3 deletions
@@ -161,10 +161,11 @@ async def crashing_aadd(*args, **kwargs) -> str | None:
         ) as mock_aadd,
     ):
         index = await get_directory_index(settings=agent_test_settings)
+
+    assert len(await index.index_files) == num_source_files
     assert (
-        mock_aadd.await_count <= crash_threshold
-    ), "Should have been able to resume build"
-    assert len(await index.index_files) > crash_threshold
+        mock_aadd.await_count < num_source_files
+    ), "Should not rebuild the whole index"


 @pytest.mark.asyncio
@@ -1117,3 +1118,39 @@ async def test_continuation(self) -> None:
         # Check continuation of the search
         result = await tool.clinical_trials_search("Covid-19 vaccines", state)
         assert len(state.docs.docs) > trial_count, "Search was unable to continue"
+
+
+@pytest.mark.asyncio
+async def test_index_build_concurrency(agent_test_settings: Settings) -> None:
+
+    high_concurrency_settings = agent_test_settings.model_copy(deep=True)
+    high_concurrency_settings.agent.index.name = "high_concurrency"
+    high_concurrency_settings.agent.index.concurrency = 3
+    high_concurrency_settings.agent.index.batch_size = 3
+    with patch.object(
+        SearchIndex, "save_index", side_effect=SearchIndex.save_index, autospec=True
+    ) as mock_save_index:
+        start_time = time.perf_counter()
+        await get_directory_index(settings=high_concurrency_settings)
+        high_concurrency_duration = time.perf_counter() - start_time
+        high_batch_save_count = mock_save_index.call_count
+
+    low_concurrency_settings = agent_test_settings.model_copy(deep=True)
+    low_concurrency_settings.agent.index.name = "low_concurrency"
+    low_concurrency_settings.agent.index.concurrency = 1
+    low_concurrency_settings.agent.index.batch_size = 1
+    with patch.object(
+        SearchIndex, "save_index", side_effect=SearchIndex.save_index, autospec=True
+    ) as mock_save_index:
+        start_time = time.perf_counter()
+        await get_directory_index(settings=low_concurrency_settings)
+        low_concurrency_duration = time.perf_counter() - start_time
+        low_batch_save_count = mock_save_index.call_count
+
+    assert high_concurrency_duration * 1.1 < low_concurrency_duration, (
+        f"Expected high concurrency to be faster, but took {high_concurrency_duration:.2f}s "
+        f"compared to {low_concurrency_duration:.2f}s"
+    )
+    assert (
+        high_batch_save_count < low_batch_save_count
+    ), f"Expected fewer save_index with high batch size, but got {high_batch_save_count} vs {low_batch_save_count}"
