Skip to content

Commit 4ef5622

Browse files
committed
Addressed some of my comments
1 parent 9d2dbf1 commit 4ef5622

File tree

5 files changed

+143
-139
lines changed

5 files changed

+143
-139
lines changed

app/backend/prepdocslib/blobmanager.py

Lines changed: 42 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,16 @@
33
import logging
44
import os
55
import re
6-
from typing import List, Optional, Union, NamedTuple, Tuple
6+
from enum import Enum
7+
from typing import List, Optional, Union
78

89
import fitz # type: ignore
910
from azure.core.credentials_async import AsyncTokenCredential
1011
from azure.storage.blob import (
12+
BlobClient,
1113
BlobSasPermissions,
1214
UserDelegationKey,
13-
generate_blob_sas,
14-
BlobClient
15+
generate_blob_sas,
1516
)
1617
from azure.storage.blob.aio import BlobServiceClient, ContainerClient
1718
from PIL import Image, ImageDraw, ImageFont
@@ -21,6 +22,7 @@
2122

2223
logger = logging.getLogger("scripts")
2324

25+
2426
class BlobManager:
2527
"""
2628
Class to manage uploading and deleting blobs containing citation information from a blob storage account
@@ -45,58 +47,60 @@ def __init__(
4547
self.subscriptionId = subscriptionId
4648
self.user_delegation_key: Optional[UserDelegationKey] = None
4749

48-
#async def upload_blob(self, file: File, container_client:ContainerClient) -> Optional[List[str]]:
49-
50-
async def _create_new_blob(self, file: File, container_client:ContainerClient) -> BlobClient:
50+
async def _create_new_blob(self, file: File, container_client: ContainerClient) -> BlobClient:
5151
with open(file.content.name, "rb") as reopened_file:
52-
blob_name = BlobManager.blob_name_from_file_name(file.content.name)
53-
logger.info("Uploading blob for whole file -> %s", blob_name)
54-
return await container_client.upload_blob(blob_name, reopened_file, overwrite=True, metadata=file.metadata)
52+
blob_name = BlobManager.blob_name_from_file_name(file.content.name)
53+
logger.info("Uploading blob for whole file -> %s", blob_name)
54+
return await container_client.upload_blob(blob_name, reopened_file, overwrite=True, metadata=file.metadata)
5555

56-
async def _file_blob_update_needed(self, blob_client: BlobClient, file : File) -> bool:
57-
md5_check : int = 0 # 0= not done, 1, positive,. 2 negative
56+
async def _file_blob_update_needed(self, blob_client: BlobClient, file: File) -> bool:
5857
# Get existing blob properties
5958
blob_properties = await blob_client.get_blob_properties()
6059
blob_metadata = blob_properties.metadata
61-
60+
6261
# Check if the md5 values are the same
63-
file_md5 = file.metadata.get('md5')
64-
blob_md5 = blob_metadata.get('md5')
65-
66-
# Remove md5 from file metadata if it matches the blob metadata
67-
if file_md5 and file_md5 != blob_md5:
68-
return True
69-
else:
70-
return False
71-
62+
file_md5 = file.metadata.get("md5")
63+
blob_md5 = blob_metadata.get("md5")
64+
65+
# If the file has an md5 value, check if it is different from the blob
66+
return file_md5 and file_md5 != blob_md5
67+
7268
async def upload_blob(self, file: File) -> Optional[List[str]]:
7369
async with BlobServiceClient(
7470
account_url=self.endpoint, credential=self.credential, max_single_put_size=4 * 1024 * 1024
7571
) as service_client, service_client.get_container_client(self.container) as container_client:
7672
if not await container_client.exists():
7773
await container_client.create_container()
78-
79-
# Re-open and upload the original file
80-
md5_check : int = 0 # 0= not done, 1, positive,. 2 negative
81-
82-
# upload the file local storage zu azure storage
74+
75+
# Re-open and upload the original file if the blob does not exist or the md5 values do not match
76+
class MD5Check(Enum):
77+
NOT_DONE = 0
78+
MATCH = 1
79+
NO_MATCH = 2
80+
81+
md5_check = MD5Check.NOT_DONE
82+
83+
# Upload the file to Azure Storage
8384
# file.url is only None if files are not uploaded yet, for datalake it is set
8485
if file.url is None:
8586
blob_client = container_client.get_blob_client(file.url)
8687

8788
if not await blob_client.exists():
89+
logger.info("Blob %s does not exist, uploading", file.url)
8890
blob_client = await self._create_new_blob(file, container_client)
8991
else:
9092
if self._blob_update_needed(blob_client, file):
91-
md5_check = 2
93+
logger.info("Blob %s exists but md5 values do not match, updating", file.url)
94+
md5_check = MD5Check.NO_MATCH
9295
# Upload the file with the updated metadata
9396
with open(file.content.name, "rb") as data:
9497
await blob_client.upload_blob(data, overwrite=True, metadata=file.metadata)
9598
else:
96-
md5_check = 1
99+
logger.info("Blob %s exists and md5 values match, skipping upload", file.url)
100+
md5_check = MD5Check.MATCH
97101
file.url = blob_client.url
98-
99-
if md5_check!=1 and self.store_page_images:
102+
103+
if md5_check != MD5Check.MATCH and self.store_page_images:
100104
if os.path.splitext(file.content.name)[1].lower() == ".pdf":
101105
return await self.upload_pdf_blob_images(service_client, container_client, file)
102106
else:
@@ -127,20 +131,19 @@ async def upload_pdf_blob_images(
127131

128132
for i in range(page_count):
129133
blob_name = BlobManager.blob_image_name_from_file_page(file.content.name, i)
130-
134+
131135
blob_client = container_client.get_blob_client(blob_name)
132-
do_upload : bool = True
133136
if await blob_client.exists():
134137
# Get existing blob properties
135138
blob_properties = await blob_client.get_blob_properties()
136139
blob_metadata = blob_properties.metadata
137-
140+
138141
# Check if the md5 values are the same
139-
file_md5 = file.metadata.get('md5')
140-
blob_md5 = blob_metadata.get('md5')
142+
file_md5 = file.metadata.get("md5")
143+
blob_md5 = blob_metadata.get("md5")
141144
if file_md5 == blob_md5:
142-
continue # documemt already uploaded
143-
145+
continue  # document already uploaded
146+
144147
logger.debug("Converting page %s to image and uploading -> %s", i, blob_name)
145148

146149
doc = fitz.open(file.content.name)
@@ -167,7 +170,7 @@ async def upload_pdf_blob_images(
167170
output = io.BytesIO()
168171
new_img.save(output, format="PNG")
169172
output.seek(0)
170-
173+
171174
await blob_client.upload_blob(data=output, overwrite=True, metadata=file.metadata)
172175
if not self.user_delegation_key:
173176
self.user_delegation_key = await service_client.get_user_delegation_key(start_time, expiry_time)
@@ -181,7 +184,7 @@ async def upload_pdf_blob_images(
181184
permission=BlobSasPermissions(read=True),
182185
expiry=expiry_time,
183186
start=start_time,
184-
)
187+
)
185188
sas_uris.append(f"{blob_client.url}?{sas_token}")
186189

187190
return sas_uris

app/backend/prepdocslib/filestrategy.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,6 @@
11
import logging
2-
import asyncio
3-
from concurrent.futures import ThreadPoolExecutor
42
from typing import List, Optional
5-
from concurrent.futures import ThreadPoolExecutor
6-
from typing import List, Optional
7-
from tqdm.asyncio import tqdm
3+
84
from .blobmanager import BlobManager
95
from .embeddings import ImageEmbeddings, OpenAIEmbeddings
106
from .fileprocessor import FileProcessor
@@ -36,6 +32,7 @@ async def parse_file(
3632
]
3733
return sections
3834

35+
3936
class FileStrategy(Strategy):
4037
"""
4138
Strategy for ingesting documents into a search service from files stored either locally or in a data lake storage account
@@ -96,7 +93,9 @@ async def run(self):
9693
blob_image_embeddings: Optional[List[List[float]]] = None
9794
if self.image_embeddings and blob_sas_uris:
9895
blob_image_embeddings = await self.image_embeddings.create_embeddings(blob_sas_uris)
99-
await search_manager.update_content(sections=sections, file=file, image_embeddings=blob_image_embeddings)
96+
await search_manager.update_content(
97+
sections=sections, file=file, image_embeddings=blob_image_embeddings
98+
)
10099
finally:
101100
if file:
102101
file.close()
@@ -128,7 +127,9 @@ async def process_file(self, file, search_manager):
128127
blob_image_embeddings: Optional[List[List[float]]] = None
129128
if self.image_embeddings and blob_sas_uris:
130129
blob_image_embeddings = await self.image_embeddings.create_embeddings(blob_sas_uris)
131-
await search_manager.update_content(sections=sections, file=file, image_embeddings=blob_image_embeddings)
130+
await search_manager.update_content(
131+
sections=sections, file=file, image_embeddings=blob_image_embeddings
132+
)
132133
finally:
133134
if file:
134135
file.close()

app/backend/prepdocslib/listfilestrategy.py

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from azure.storage.filedatalake import DataLakeServiceClient
2-
from azure.storage.blob import BlobServiceClient
31
import base64
42
import hashlib
53
import logging
@@ -10,12 +8,10 @@
108
from glob import glob
119
from typing import IO, AsyncGenerator, Dict, List, Optional, Union
1210

13-
from azure.identity import DefaultAzureCredential
14-
1511
from azure.core.credentials_async import AsyncTokenCredential
16-
from azure.storage.filedatalake.aio import (
17-
DataLakeServiceClient,
18-
)
12+
from azure.identity import DefaultAzureCredential
13+
from azure.storage.blob import BlobServiceClient
14+
from azure.storage.filedatalake.aio import DataLakeServiceClient
1915

2016
logger = logging.getLogger("scripts")
2117

@@ -26,11 +22,17 @@ class File:
2622
This file might contain access control information about which users or groups can access it
2723
"""
2824

29-
def __init__(self, content: IO, acls: Optional[dict[str, list]] = None, url: Optional[str] = None, metadata : Dict[str, str]= None):
25+
def __init__(
26+
self,
27+
content: IO,
28+
acls: Optional[dict[str, list]] = None,
29+
url: Optional[str] = None,
30+
metadata: Dict[str, str] = None,
31+
):
3032
self.content = content
3133
self.acls = acls or {}
3234
self.url = url
33-
self.metadata = metadata
35+
self.metadata = metadata
3436

3537
def filename(self):
3638
return os.path.basename(self.content.name)
@@ -63,11 +65,12 @@ async def list(self) -> AsyncGenerator[File, None]:
6365
async def list_paths(self) -> AsyncGenerator[str, None]:
6466
if False: # pragma: no cover - this is necessary for mypy to type check
6567
yield
66-
68+
6769
def count_docs(self) -> int:
6870
if False: # pragma: no cover - this is necessary for mypy to type check
6971
yield
7072

73+
7174
class LocalListFileStrategy(ListFileStrategy):
7275
"""
7376
Concrete strategy for listing files that are located in a local filesystem
@@ -117,7 +120,6 @@ def check_md5(self, path: str) -> bool:
117120
md5_f.write(existing_hash)
118121

119122
return False
120-
121123

122124
def count_docs(self) -> int:
123125
"""
@@ -135,6 +137,7 @@ def _list_paths_sync(self, path_pattern: str):
135137
else:
136138
yield path
137139

140+
138141
class ADLSGen2ListFileStrategy(ListFileStrategy):
139142
"""
140143
Concrete strategy for listing files that are located in a data lake storage account
@@ -191,9 +194,11 @@ async def list(self) -> AsyncGenerator[File, None]:
191194
if acl_parts[0] == "user" and "r" in acl_parts[2]:
192195
acls["oids"].append(acl_parts[1])
193196
if acl_parts[0] == "group" and "r" in acl_parts[2]:
194-
acls["groups"].append(acl_parts[1])
197+
acls["groups"].append(acl_parts[1])
195198
properties = await file_client.get_file_properties()
196-
yield File(content=open(temp_file_path, "rb"), acls=acls, url=file_client.url, metadata=properties.metadata)
199+
yield File(
200+
content=open(temp_file_path, "rb"), acls=acls, url=file_client.url, metadata=properties.metadata
201+
)
197202
except Exception as data_lake_exception:
198203
logger.error(f"\tGot an error while reading {path} -> {data_lake_exception} --> skipping file")
199204
try:
@@ -205,18 +210,18 @@ def count_docs(self) -> int:
205210
"""
206211
Return the number of blobs in the specified folder within the Azure Blob Storage container.
207212
"""
208-
213+
209214
# Create a BlobServiceClient using account URL and credentials
210215
service_client = BlobServiceClient(
211216
account_url=f"https://{self.data_lake_storage_account}.blob.core.windows.net",
212-
credential=DefaultAzureCredential())
217+
credential=DefaultAzureCredential(),
218+
)
213219

214220
# Get the container client
215221
container_client = service_client.get_container_client(self.data_lake_filesystem)
216222

217223
# Count blobs within the specified folder
218224
if self.data_lake_path != "/":
219-
return sum(1 for _ in container_client.list_blobs(name_starts_with= self.data_lake_path))
225+
return sum(1 for _ in container_client.list_blobs(name_starts_with=self.data_lake_path))
220226
else:
221227
return sum(1 for _ in container_client.list_blobs())
222-

0 commit comments

Comments
 (0)