Speed up prepdocs for file strategy with parallel async pools #2553

Open · wants to merge 5 commits into base: main
9 changes: 8 additions & 1 deletion app/backend/prepdocs.py
@@ -311,13 +311,19 @@ async def main(strategy: Strategy, setup_index: bool = True):
required=False,
help="Search service system assigned Identity (Managed identity) (used for integrated vectorization)",
)
+ parser.add_argument(
+     "--concurrency",
+     type=int,
+     default=FileStrategy.DEFAULT_CONCURRENCY,
+     help="Max. number of concurrent tasks to run for processing files (file strategy only) (default: 4)",
+ )

parser.add_argument("--verbose", "-v", action="store_true", help="Verbose output")
args = parser.parse_args()

if args.verbose:
logging.basicConfig(format="%(message)s", datefmt="[%X]", handlers=[RichHandler(rich_tracebacks=True)])
- # We only set the level to INFO for our logger,
+ # We only set the level to DEBUG for our logger,
# to avoid seeing the noisy INFO level logs from the Azure SDKs
logger.setLevel(logging.DEBUG)

@@ -467,6 +473,7 @@ async def main(strategy: Strategy, setup_index: bool = True):
category=args.category,
use_content_understanding=use_content_understanding,
content_understanding_endpoint=os.getenv("AZURE_CONTENTUNDERSTANDING_ENDPOINT"),
+ concurrency=args.concurrency,
)

loop.run_until_complete(main(ingestion_strategy, setup_index=not args.remove and not args.removeall))
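For anyone trying this out locally, a hypothetical invocation of the new flag might look like the following (the positional data glob and the environment setup the script already expects are assumptions, not part of this diff):

python app/backend/prepdocs.py "./data/*" --verbose --concurrency 8

If the flag is omitted, FileStrategy.DEFAULT_CONCURRENCY (4) applies; larger values may speed up ingestion but put more simultaneous load on Document Intelligence, the embeddings endpoint, and the search service.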
4 changes: 2 additions & 2 deletions app/backend/prepdocslib/blobmanager.py
@@ -56,15 +56,15 @@ async def upload_blob(self, file: File) -> Optional[list[str]]:
if file.url is None:
with open(file.content.name, "rb") as reopened_file:
blob_name = BlobManager.blob_name_from_file_name(file.content.name)
logger.info("Uploading blob for whole file -> %s", blob_name)
logger.info("'%s': Uploading blob for file to '%s'", file.content.name, blob_name)
blob_client = await container_client.upload_blob(blob_name, reopened_file, overwrite=True)
file.url = blob_client.url

if self.store_page_images:
if os.path.splitext(file.content.name)[1].lower() == ".pdf":
return await self.upload_pdf_blob_images(service_client, container_client, file)
else:
logger.info("File %s is not a PDF, skipping image upload", file.content.name)
logger.info("'%s': File is not a PDF, skipping image upload", file.content.name)

return None

6 changes: 2 additions & 4 deletions app/backend/prepdocslib/embeddings.py
@@ -114,12 +114,11 @@ async def create_embedding_batch(self, texts: list[str], dimensions_args: ExtraArgs
model=self.open_ai_model_name, input=batch.texts, **dimensions_args
)
embeddings.extend([data.embedding for data in emb_response.data])
- logger.info(
+ logger.debug(
"Computed embeddings in batch. Batch size: %d, Token count: %d",
len(batch.texts),
batch.token_length,
)

return embeddings

async def create_embedding_single(self, text: str, dimensions_args: ExtraArgs) -> list[float]:
@@ -134,8 +133,7 @@ async def create_embedding_single(self, text: str, dimensions_args: ExtraArgs) -> list[float]:
emb_response = await client.embeddings.create(
model=self.open_ai_model_name, input=text, **dimensions_args
)
logger.info("Computed embedding for text section. Character count: %d", len(text))

logger.debug("Computed embedding for text section. Character count: %d", len(text))
return emb_response.data[0].embedding

async def create_embeddings(self, texts: list[str]) -> list[list[float]]:
26 changes: 20 additions & 6 deletions app/backend/prepdocslib/filestrategy.py
@@ -1,3 +1,4 @@
+ import asyncio
import logging
from typing import Optional

@@ -23,11 +24,11 @@ async def parse_file(
key = file.file_extension().lower()
processor = file_processors.get(key)
if processor is None:
logger.info("Skipping '%s', no parser found.", file.filename())
logger.info("'%s': Skipping, no parser found.", file.content.name)
return []
logger.info("Ingesting '%s'", file.filename())
logger.info("'%s': Starting ingestion process", file.content.name)
pages = [page async for page in processor.parser.parse(content=file.content)]
logger.info("Splitting '%s' into sections", file.filename())
logger.info("'%s': Splitting into sections", file.content.name)
if image_embeddings:
logger.warning("Each page will be split into smaller chunks of text, but images will be of the entire page.")
sections = [
@@ -41,6 +42,8 @@ class FileStrategy(Strategy):
Strategy for ingesting documents into a search service from files stored either locally or in a data lake storage account
"""

+ DEFAULT_CONCURRENCY = 4

def __init__(
self,
list_file_strategy: ListFileStrategy,
@@ -56,6 +59,7 @@ def __init__(
category: Optional[str] = None,
use_content_understanding: bool = False,
content_understanding_endpoint: Optional[str] = None,
+ concurrency: int = DEFAULT_CONCURRENCY,
):
self.list_file_strategy = list_file_strategy
self.blob_manager = blob_manager
@@ -70,6 +74,7 @@ def __init__(
self.category = category
self.use_content_understanding = use_content_understanding
self.content_understanding_endpoint = content_understanding_endpoint
+ self.concurrency = concurrency

def setup_search_manager(self):
self.search_manager = SearchManager(
@@ -98,20 +103,29 @@ async def setup(self):

async def run(self):
self.setup_search_manager()
- if self.document_action == DocumentAction.Add:
-     files = self.list_file_strategy.list()
-     async for file in files:

+ async def process_file_worker(semaphore: asyncio.Semaphore, file: File):
+     async with semaphore:
try:
sections = await parse_file(file, self.file_processors, self.category, self.image_embeddings)
if sections:
blob_sas_uris = await self.blob_manager.upload_blob(file)
blob_image_embeddings: Optional[list[list[float]]] = None
if self.image_embeddings and blob_sas_uris:
blob_image_embeddings = await self.image_embeddings.create_embeddings(blob_sas_uris)
logger.info("'%s': Computing embeddings and updating search index", file.content.name)
await self.search_manager.update_content(sections, blob_image_embeddings, url=file.url)
finally:
if file:
logger.info("'%s': Finished processing file", file.content.name)
file.close()

+ if self.document_action == DocumentAction.Add:
+     files = self.list_file_strategy.list()
+     logger.info("Running with concurrency: %d", self.concurrency)
+     semaphore = asyncio.Semaphore(self.concurrency)
+     tasks = [process_file_worker(semaphore, file) async for file in files]
+     await asyncio.gather(*tasks)

Copilot AI Jul 8, 2025

Consider using asyncio.gather(*tasks, return_exceptions=True) or handling exceptions within process_file_worker so that a single task failure doesn't cancel the entire batch.

Suggested change
- await asyncio.gather(*tasks)
+ results = await asyncio.gather(*tasks, return_exceptions=True)
+ for result in results:
+     if isinstance(result, Exception):
+         logger.error("Task failed with exception: %s", str(result), exc_info=True)

Collaborator


@tonybaloney Thoughts on this suggestion from Copilot? Is it correct?
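As a point of reference, a minimal, self-contained sketch of the in-worker handling Copilot alludes to (all names, the sleep, and the simulated failure are placeholders, not code from this PR): one bad file is logged and skipped while the sibling tasks in the same gather() still run to completion.

import asyncio
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

async def process_file_worker(semaphore: asyncio.Semaphore, filename: str) -> None:
    async with semaphore:
        try:
            await asyncio.sleep(0.1)  # stand-in for parse_file / upload_blob / update_content
            if filename == "bad.pdf":
                raise ValueError("simulated parser failure")
            logger.info("'%s': Finished processing file", filename)
        except Exception:
            # Logged and swallowed, so the other tasks are unaffected by this failure.
            logger.exception("'%s': Failed to process file", filename)

async def main() -> None:
    semaphore = asyncio.Semaphore(4)  # mirrors FileStrategy.DEFAULT_CONCURRENCY
    tasks = [process_file_worker(semaphore, name) for name in ["a.pdf", "bad.pdf", "c.pdf"]]
    await asyncio.gather(*tasks)

asyncio.run(main())

Without either in-worker handling or return_exceptions=True, the first exception propagates out of gather() and the remaining files are effectively abandoned when the program exits, which is the failure mode the Copilot comment warns about.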

elif self.document_action == DocumentAction.Remove:
paths = self.list_file_strategy.list_paths()
async for path in paths:
2 changes: 1 addition & 1 deletion app/backend/prepdocslib/htmlparser.py
@@ -39,7 +39,7 @@ async def parse(self, content: IO) -> AsyncGenerator[Page, None]:
Returns:
Page: The parsed html Page.
"""
logger.info("Extracting text from '%s' using local HTML parser (BeautifulSoup)", content.name)
logger.info("'%s': Extracting text using local HTML parser (BeautifulSoup)", content.name)

data = content.read()
soup = BeautifulSoup(data, "html.parser")
2 changes: 1 addition & 1 deletion app/backend/prepdocslib/integratedvectorizerstrategy.py
@@ -129,7 +129,7 @@ async def create_embedding_skill(self, index_name: str) -> SearchIndexerSkillset
return skillset

async def setup(self):
logger.info("Setting up search index using integrated vectorization...")
logger.info("Setting up search index using integrated vectorization")
search_manager = SearchManager(
search_info=self.search_info,
search_analyzer_name=self.search_analyzer_name,
2 changes: 1 addition & 1 deletion app/backend/prepdocslib/listfilestrategy.py
@@ -102,7 +102,7 @@ def check_md5(self, path: str) -> bool:
stored_hash = md5_f.read()

if stored_hash and stored_hash.strip() == existing_hash.strip():
logger.info("Skipping %s, no changes detected.", path)
logger.info("'%s': Skipping, no changes detected.", path)
return True

# Write the hash
4 changes: 2 additions & 2 deletions app/backend/prepdocslib/mediadescriber.py
@@ -58,7 +58,7 @@ async def poll():
return await poll()

async def create_analyzer(self):
logger.info("Creating analyzer '%s'...", self.analyzer_schema["analyzerId"])
logger.info("Creating analyzer '%s'", self.analyzer_schema["analyzerId"])

token_provider = get_bearer_token_provider(self.credential, "https://cognitiveservices.azure.com/.default")
token = await token_provider()
@@ -84,7 +84,7 @@ async def create_analyzer(self):
await self.poll_api(session, poll_url, headers)

async def describe_image(self, image_bytes: bytes) -> str:
logger.info("Sending image to Azure Content Understanding service...")
logger.info("Sending image to Azure Content Understanding service")
async with aiohttp.ClientSession() as session:
token = await self.credential.get_token("https://cognitiveservices.azure.com/.default")
headers = {"Authorization": "Bearer " + token.token}
4 changes: 2 additions & 2 deletions app/backend/prepdocslib/pdfparser.py
@@ -33,7 +33,7 @@ class LocalPdfParser(Parser):
"""

async def parse(self, content: IO) -> AsyncGenerator[Page, None]:
logger.info("Extracting text from '%s' using local PDF parser (pypdf)", content.name)
logger.info("'%s': Extracting text using local PDF parser (pypdf)", content.name)

reader = PdfReader(content)
pages = reader.pages
@@ -65,7 +65,7 @@ def __init__(
self.content_understanding_endpoint = content_understanding_endpoint

async def parse(self, content: IO) -> AsyncGenerator[Page, None]:
logger.info("Extracting text from '%s' using Azure Document Intelligence", content.name)
logger.info("'%s': Extracting text using Azure Document Intelligence", content.name)

async with DocumentIntelligenceClient(
endpoint=self.endpoint, credential=self.credential
6 changes: 3 additions & 3 deletions app/backend/prepdocslib/searchmanager.py
@@ -77,7 +77,7 @@ def __init__(
self.search_images = search_images

async def create_index(self):
logger.info("Checking whether search index %s exists...", self.search_info.index_name)
logger.info("Checking whether search index '%s' exists", self.search_info.index_name)

async with self.search_info.create_search_index_client() as search_index_client:

@@ -280,10 +280,10 @@ async def create_index(self):

await search_index_client.create_index(index)
else:
logger.info("Search index %s already exists", self.search_info.index_name)
logger.info("Search index '%s' already exists", self.search_info.index_name)
existing_index = await search_index_client.get_index(self.search_info.index_name)
if not any(field.name == "storageUrl" for field in existing_index.fields):
logger.info("Adding storageUrl field to index %s", self.search_info.index_name)
logger.info("Adding storageUrl field to index '%s'", self.search_info.index_name)
existing_index.fields.append(
SimpleField(
name="storageUrl",