1+ from azure .storage .filedatalake import DataLakeServiceClient
2+ from azure .storage .blob import BlobServiceClient
13import base64
24import hashlib
35import logging
810from glob import glob
911from typing import IO , AsyncGenerator , Dict , List , Optional , Union
1012
13+ from azure .identity import DefaultAzureCredential
14+
1115from azure .core .credentials_async import AsyncTokenCredential
1216from azure .storage .filedatalake .aio import (
1317 DataLakeServiceClient ,
@@ -22,10 +26,11 @@ class File:
2226 This file might contain access control information about which users or groups can access it
2327 """
2428
25- def __init__ (self , content : IO , acls : Optional [dict [str , list ]] = None , url : Optional [str ] = None ):
29+ def __init__ (self , content : IO , acls : Optional [dict [str , list ]] = None , url : Optional [str ] = None , metadata : Dict [ str , str ] = None ):
2630 self .content = content
2731 self .acls = acls or {}
2832 self .url = url
33+ self .metadata = metadata
2934
def filename(self):
    """Return the final path component of the wrapped file object's ``name``."""
    full_path = self.content.name
    return os.path.basename(full_path)
@@ -58,7 +63,10 @@ async def list(self) -> AsyncGenerator[File, None]:
async def list_paths(self) -> AsyncGenerator[str, None]:
    """
    Placeholder async generator of paths; this implementation yields nothing.

    NOTE(review): presumably concrete strategies override this to yield the
    paths they would process - confirm against the subclasses.
    """
    # The branch is never taken; the mere presence of `yield` is what makes
    # this function an async generator, matching the declared return type.
    if False:  # pragma: no cover - this is necessary for mypy to type check
        yield
61-
66+
def count_docs(self) -> int:
    """
    Return the number of documents this strategy would list.

    This base implementation is only a placeholder and always raises;
    concrete strategies are expected to override it.

    Raises:
        NotImplementedError: Always, because the base class has nothing to count.
    """
    # A `yield` placeholder here would turn this method into a generator
    # function, so callers would receive a generator object instead of an int.
    raise NotImplementedError
6270
6371class LocalListFileStrategy (ListFileStrategy ):
6472 """
@@ -109,7 +117,23 @@ def check_md5(self, path: str) -> bool:
109117 md5_f .write (existing_hash )
110118
111119 return False
112-
120+
121+
122+ def count_docs (self ) -> int :
123+ """
124+ Return the number of files that match the path pattern.
125+ """
126+ return sum (1 for _ in self ._list_paths_sync (self .path_pattern ))
127+
128+ def _list_paths_sync (self , path_pattern : str ):
129+ """
130+ Synchronous version of _list_paths to be used for counting files.
131+ """
132+ for path in glob (path_pattern ):
133+ if os .path .isdir (path ):
134+ yield from self ._list_paths_sync (f"{ path } /*" )
135+ else :
136+ yield path
113137
114138class ADLSGen2ListFileStrategy (ListFileStrategy ):
115139 """
@@ -167,11 +191,32 @@ async def list(self) -> AsyncGenerator[File, None]:
167191 if acl_parts [0 ] == "user" and "r" in acl_parts [2 ]:
168192 acls ["oids" ].append (acl_parts [1 ])
169193 if acl_parts [0 ] == "group" and "r" in acl_parts [2 ]:
170- acls ["groups" ].append (acl_parts [1 ])
171- yield File (content = open (temp_file_path , "rb" ), acls = acls , url = file_client .url )
194+ acls ["groups" ].append (acl_parts [1 ])
195+ properties = await file_client .get_file_properties ()
196+ yield File (content = open (temp_file_path , "rb" ), acls = acls , url = file_client .url , metadata = properties .metadata )
172197 except Exception as data_lake_exception :
173198 logger .error (f"\t Got an error while reading { path } -> { data_lake_exception } --> skipping file" )
174199 try :
175200 os .remove (temp_file_path )
176201 except Exception as file_delete_exception :
177202 logger .error (f"\t Got an error while deleting { temp_file_path } -> { file_delete_exception } " )
203+
def count_docs(self) -> int:
    """
    Return the number of blobs under the configured folder of the storage container.

    Connects to the Blob endpoint of ``self.data_lake_storage_account`` with
    ``DefaultAzureCredential``, opens the container named by
    ``self.data_lake_filesystem`` and counts the blobs whose names fall under
    ``self.data_lake_path``. A path of ``"/"`` (or an empty/``None`` path)
    counts every blob in the container.

    NOTE(review): on accounts with a hierarchical namespace the blob listing
    may also include directory entries - confirm whether those should count.
    NOTE(review): if this strategy already holds a credential, consider
    reusing it here instead of creating a new DefaultAzureCredential.
    """
    account_url = f"https://{self.data_lake_storage_account}.blob.core.windows.net"

    # Leading/trailing slashes are dropped because blob names do not start with
    # "/", and the trailing "/" is re-added so that a folder named "docs" does
    # not also match sibling names such as "docs-old/...".
    folder = (self.data_lake_path or "").strip("/")
    prefix = f"{folder}/" if folder else None

    # Both the credential and the service client hold network resources;
    # the context managers make sure they are closed once counting is done.
    with DefaultAzureCredential() as credential, BlobServiceClient(
        account_url=account_url, credential=credential
    ) as service_client:
        container_client = service_client.get_container_client(self.data_lake_filesystem)
        # name_starts_with=None applies no prefix filter (whole container).
        return sum(1 for _ in container_client.list_blobs(name_starts_with=prefix))
222+
0 commit comments