 import argparse
 import asyncio
+import datetime
+import hashlib
 import json
 import logging
 import os
 from typing import Any, Optional

 import aiohttp
@@ -56,7 +58,7 @@ def __init__(
         self.data_access_control_format = data_access_control_format
         self.graph_headers: Optional[dict[str, str]] = None

-    async def run(self):
+    async def run(self, scandirs: bool = False):
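+        """Ensure the filesystem, groups, and directories exist, upload the
+        configured files, and apply the access control lists.
+
+        When scandirs is True, the configured directories are also walked
+        recursively (see scan_and_upload_directories below) and every file
+        found is uploaded.
+        """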
         async with self.create_service_client() as service_client:
             logger.info(f"Ensuring {self.filesystem_name} exists...")
             async with service_client.get_file_system_client(self.filesystem_name) as filesystem_client:
@@ -80,15 +82,17 @@ async def run(self):
                     )
                     directories[directory] = directory_client

+                if scandirs:
+                    logger.info("Uploading scanned files...")
+                    await self.scan_and_upload_directories(directories, filesystem_client)
+
                 logger.info("Uploading files...")
                 for file, file_info in self.data_access_control_format["files"].items():
                     directory = file_info["directory"]
                     if directory not in directories:
                         logger.error(f"File {file} has unknown directory {directory}, exiting...")
                         return
                     await self.upload_file(
                         directory_client=directories[directory], file_path=os.path.join(self.data_directory, file)
                     )

                 logger.info("Setting access control...")
                 for directory, access_control in self.data_access_control_format["directories"].items():
@@ -100,8 +104,7 @@ async def run(self):
100104 f"Directory { directory } has unknown group { group_name } in access control list, exiting"
101105 )
102106 return
                         await directory_client.update_access_control_recursive(
                             acl=f"group:{groups[group_name]}:r-x"
                         )
                     if "oids" in access_control:
                         for oid in access_control["oids"]:
@@ -110,15 +113,110 @@ async def run(self):
             for directory_client in directories.values():
                 await directory_client.close()

+    async def walk_files(self, src_filepath="."):
+        """Walk src_filepath recursively and return the paths of all files found."""
+        filepath_list = []
+
+        # os.walk() yields a (root, dirs, files) tuple for every directory in the tree
+        for root, dirs, files in os.walk(src_filepath):
+            for file in files:
+                # Expand '.' to the current working directory so the collected
+                # paths are unambiguous; otherwise keep the root as given
+                if root == ".":
+                    root_path = os.getcwd()
+                else:
+                    root_path = root
+
+                filepath = os.path.join(root_path, file)
+
+                # Guard against duplicates before recording the path
+                if filepath not in filepath_list:
+                    filepath_list.append(filepath)
+
+        return filepath_list
+
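+    # The access-control JSON referenced below is assumed to look roughly like
+    # this (names and values are illustrative, not taken from the repository):
+    # {
+    #     "groups": ["GPTKB_HRTest"],
+    #     "directories": {
+    #         "employeeinfo": {"groups": ["GPTKB_HRTest"], "scandir": true},
+    #         "/": {"groups": ["GPTKB_HRTest"]}
+    #     },
+    #     "files": {"role_library.pdf": {"directory": "employeeinfo"}}
+    # }
+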
+    async def scan_and_upload_directories(self, directories: dict[str, DataLakeDirectoryClient], filesystem_client):
+        logger.info("Scanning and uploading files from directories recursively...")
+        for directory, directory_client in directories.items():
+            directory_path = os.path.join(self.data_directory, directory)
+
+            # Skip this directory if its 'scandir' flag is explicitly set to False
+            if not self.data_access_control_format["directories"][directory].get("scandir", True):
+                logger.info(f"Skipping directory {directory} as 'scandir' is set to False")
+                continue
+
+            # Note: 'groups' is read here but not yet used in this method
+            groups = self.data_access_control_format["directories"][directory].get("groups", [])
+
+            # Check that the directory exists before walking it
+            if not os.path.exists(directory_path):
+                logger.warning(f"Directory does not exist: {directory_path}")
+                continue
+
+            # Collect all file paths with walk_files, then upload each one
+            file_paths = await self.walk_files(directory_path)
+            for file_path in file_paths:
+                await self.upload_file(directory_client, file_path)
+                logger.info(f"Uploaded {file_path} to {directory}")
+
     def create_service_client(self):
         return DataLakeServiceClient(
             account_url=f"https://{self.storage_account_name}.dfs.core.windows.net", credential=self.credentials
         )

-    async def upload_file(self, directory_client: DataLakeDirectoryClient, file_path: str):
+    async def calc_md5(self, path: str) -> str:
+        # Read the whole file and return its MD5 hex digest. For very large
+        # files, updating the hash in fixed-size chunks would use less memory.
+        with open(path, "rb") as file:
+            return hashlib.md5(file.read()).hexdigest()
+
+    async def check_md5(self, path: str, md5_hash: str) -> bool:
+        # Skip the .md5 sidecar files themselves
+        if path.endswith(".md5"):
+            return True
+
+        # If a .md5 sidecar exists for this file, read the hash stored in it
+        stored_hash = None
+        hash_path = f"{path}.md5"
+        if os.path.exists(hash_path):
+            with open(hash_path, encoding="utf-8") as md5_f:
+                stored_hash = md5_f.read()
+
+        if stored_hash and stored_hash.strip() == md5_hash.strip():
+            logger.info("Skipping %s, no changes detected.", path)
+            return True
+
+        # Write the new hash so unchanged files can be skipped on the next run
+        with open(hash_path, "w", encoding="utf-8") as md5_f:
+            md5_f.write(md5_hash)
+
+        return False
+
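+    # Example of the sidecar convention above (file names are illustrative):
+    # uploading data/report.pdf writes data/report.pdf.md5 beside it, and a
+    # later run with an unchanged data/report.pdf is skipped.
+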
+    async def upload_file(self, directory_client: DataLakeDirectoryClient, file_path: str, category: str = ""):
+        # 'category' defaults to "" because run() and scan_and_upload_directories
+        # call this method without passing one.
+        # Calculate the MD5 hash once
+        md5_hash = await self.calc_md5(file_path)
+
+        # Skip the upload if the MD5 check indicates the file has not changed
+        if await self.check_md5(file_path, md5_hash):
+            logger.info("File %s has already been uploaded, skipping upload.", file_path)
+            return
+
+        # The file is new or has changed, so upload it along with its metadata
+        with open(file=file_path, mode="rb") as f:
+            file_client = directory_client.get_file_client(file=os.path.basename(file_path))
+            last_modified = datetime.datetime.fromtimestamp(os.path.getmtime(file_path)).isoformat()
+            title = os.path.splitext(os.path.basename(file_path))[0]
+            metadata = {
+                "md5": md5_hash,
+                "category": category,
+                "updated": last_modified,
+                "title": title,
+            }
+            await file_client.upload_data(f, overwrite=True, metadata=metadata)
+            logger.info("File %s uploaded with metadata %s.", file_path, metadata)

     async def create_or_get_group(self, group_name: str):
         group_id = None
@@ -144,6 +242,7 @@ async def create_or_get_group(self, group_name: str):
                     # If Unified does not work for you, then you may need the following settings instead:
                     # "mailEnabled": False,
                     # "mailNickname": group_name,
+
                 }
                 async with session.post("https://graph.microsoft.com/v1.0/groups", json=group) as response:
                     content = await response.json()
@@ -165,19 +264,19 @@ async def main(args: Any):
         data_access_control_format = json.load(f)
     command = AdlsGen2Setup(
         data_directory=args.data_directory,
         storage_account_name=os.environ["AZURE_ADLS_GEN2_STORAGE_ACCOUNT"],
-        filesystem_name="gptkbcontainer",
+        filesystem_name=os.environ["AZURE_ADLS_GEN2_FILESYSTEM"],
         security_enabled_groups=args.create_security_enabled_groups,
         credentials=credentials,
         data_access_control_format=data_access_control_format,
     )
-    await command.run()
+    await command.run(args.scandirs)


 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
-        description="Upload sample data to a Data Lake Storage Gen2 account and associate sample access control lists with it using sample groups",
-        epilog="Example: ./scripts/adlsgen2setup.py ./data --data-access-control ./scripts/sampleacls.json --create-security-enabled-groups <true|false>",
+        description="Upload data to a Data Lake Storage Gen2 account and associate access control lists with it using sample groups",
+        epilog="Example: ./scripts/adlsgen2setup.py ./data --data-access-control .azure/${AZURE_ENV_NAME}/docs_acls.json --create-security-enabled-groups <true|false> --scandirs",
     )
     parser.add_argument("data_directory", help="Data directory that contains sample PDFs")
     parser.add_argument(
@@ -190,6 +289,7 @@ async def main(args: Any):
190289 "--data-access-control" , required = True , help = "JSON file describing access control for the sample data"
191290 )
192291 parser .add_argument ("--verbose" , "-v" , required = False , action = "store_true" , help = "Verbose output" )
+    parser.add_argument(
+        "--scandirs", required=False, action="store_true", help="Scan and upload all files from directories recursively"
+    )
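+    # Example invocation (assumes AZURE_ADLS_GEN2_STORAGE_ACCOUNT and
+    # AZURE_ADLS_GEN2_FILESYSTEM are set in the environment):
+    #   python ./scripts/adlsgen2setup.py ./data \
+    #       --data-access-control .azure/${AZURE_ENV_NAME}/docs_acls.json \
+    #       --create-security-enabled-groups false --scandirs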
     args = parser.parse_args()
     if args.verbose:
         logging.basicConfig()