Skip to content

Commit 15f34a5

Browse files
authored
support checksum
1 parent 66438f2 commit 15f34a5

File tree

1 file changed

+50
-5
lines changed

1 file changed

+50
-5
lines changed

app/backend/prepdocslib/listfilestrategy.py

Lines changed: 50 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from azure.storage.filedatalake import DataLakeServiceClient
2+
from azure.storage.blob import BlobServiceClient
13
import base64
24
import hashlib
35
import logging
@@ -8,6 +10,8 @@
810
from glob import glob
911
from typing import IO, AsyncGenerator, Dict, List, Optional, Union
1012

13+
from azure.identity import DefaultAzureCredential
14+
1115
from azure.core.credentials_async import AsyncTokenCredential
1216
from azure.storage.filedatalake.aio import (
1317
DataLakeServiceClient,
@@ -22,10 +26,11 @@ class File:
2226
This file might contain access control information about which users or groups can access it
2327
"""
2428

25-
def __init__(self, content: IO, acls: Optional[dict[str, list]] = None, url: Optional[str] = None):
29+
def __init__(self, content: IO, acls: Optional[dict[str, list]] = None, url: Optional[str] = None, metadata : Dict[str, str]= None):
2630
self.content = content
2731
self.acls = acls or {}
2832
self.url = url
33+
self.metadata = metadata
2934

3035
def filename(self):
3136
return os.path.basename(self.content.name)
@@ -58,7 +63,10 @@ async def list(self) -> AsyncGenerator[File, None]:
5863
async def list_paths(self) -> AsyncGenerator[str, None]:
5964
if False: # pragma: no cover - this is necessary for mypy to type check
6065
yield
61-
66+
67+
def count_docs(self) -> int:
68+
if False: # pragma: no cover - this is necessary for mypy to type check
69+
yield
6270

6371
class LocalListFileStrategy(ListFileStrategy):
6472
"""
@@ -109,7 +117,23 @@ def check_md5(self, path: str) -> bool:
109117
md5_f.write(existing_hash)
110118

111119
return False
112-
120+
121+
122+
def count_docs(self) -> int:
123+
"""
124+
Return the number of files that match the path pattern.
125+
"""
126+
return sum(1 for _ in self._list_paths_sync(self.path_pattern))
127+
128+
def _list_paths_sync(self, path_pattern: str):
129+
"""
130+
Synchronous version of _list_paths to be used for counting files.
131+
"""
132+
for path in glob(path_pattern):
133+
if os.path.isdir(path):
134+
yield from self._list_paths_sync(f"{path}/*")
135+
else:
136+
yield path
113137

114138
class ADLSGen2ListFileStrategy(ListFileStrategy):
115139
"""
@@ -167,11 +191,32 @@ async def list(self) -> AsyncGenerator[File, None]:
167191
if acl_parts[0] == "user" and "r" in acl_parts[2]:
168192
acls["oids"].append(acl_parts[1])
169193
if acl_parts[0] == "group" and "r" in acl_parts[2]:
170-
acls["groups"].append(acl_parts[1])
171-
yield File(content=open(temp_file_path, "rb"), acls=acls, url=file_client.url)
194+
acls["groups"].append(acl_parts[1])
195+
properties = await file_client.get_file_properties()
196+
yield File(content=open(temp_file_path, "rb"), acls=acls, url=file_client.url, metadata=properties.metadata)
172197
except Exception as data_lake_exception:
173198
logger.error(f"\tGot an error while reading {path} -> {data_lake_exception} --> skipping file")
174199
try:
175200
os.remove(temp_file_path)
176201
except Exception as file_delete_exception:
177202
logger.error(f"\tGot an error while deleting {temp_file_path} -> {file_delete_exception}")
203+
204+
def count_docs(self) -> int:
205+
"""
206+
Return the number of blobs in the specified folder within the Azure Blob Storage container.
207+
"""
208+
209+
# Create a BlobServiceClient using account URL and credentials
210+
service_client = BlobServiceClient(
211+
account_url=f"https://{self.data_lake_storage_account}.blob.core.windows.net",
212+
credential=DefaultAzureCredential())
213+
214+
# Get the container client
215+
container_client = service_client.get_container_client(self.data_lake_filesystem)
216+
217+
# Count blobs within the specified folder
218+
if self.data_lake_path != "/":
219+
return sum(1 for _ in container_client.list_blobs(name_starts_with= self.data_lake_path))
220+
else:
221+
return sum(1 for _ in container_client.list_blobs())
222+

0 commit comments

Comments
 (0)