Commit a608e57

add scan option
1 parent 023dc1b commit a608e57

File tree

1 file changed

scripts/adlsgen2setup.py

Lines changed: 113 additions & 13 deletions
@@ -1,8 +1,10 @@
 import argparse
 import asyncio
+import datetime
 import json
 import logging
 import os
+import hashlib
 from typing import Any, Optional

 import aiohttp
@@ -56,7 +58,7 @@ def __init__(
         self.data_access_control_format = data_access_control_format
         self.graph_headers: Optional[dict[str, str]] = None

-    async def run(self):
+    async def run(self, scandirs: bool = False):
         async with self.create_service_client() as service_client:
             logger.info(f"Ensuring {self.filesystem_name} exists...")
             async with service_client.get_file_system_client(self.filesystem_name) as filesystem_client:
@@ -80,15 +82,17 @@ async def run(self):
                     )
                     directories[directory] = directory_client

+                logger.info("Uploading scanned files...")
+                if scandirs and directory != "/":
+                    await self.scan_and_upload_directories(directories, filesystem_client)
+
                 logger.info("Uploading files...")
                 for file, file_info in self.data_access_control_format["files"].items():
                     directory = file_info["directory"]
                     if directory not in directories:
                         logger.error(f"File {file} has unknown directory {directory}, exiting...")
                         return
-                    await self.upload_file(
-                        directory_client=directories[directory], file_path=os.path.join(self.data_directory, file)
-                    )
+                    await self.upload_file(directory_client=directories[directory], file_path=os.path.join(self.data_directory, file))

                 logger.info("Setting access control...")
                 for directory, access_control in self.data_access_control_format["directories"].items():
@@ -100,8 +104,7 @@ async def run(self):
                                     f"Directory {directory} has unknown group {group_name} in access control list, exiting"
                                 )
                                 return
-                            await directory_client.update_access_control_recursive(
-                                acl=f"group:{groups[group_name]}:r-x"
+                            await directory_client.update_access_control_recursive(acl=f"group:{groups[group_name]}:r-x"
                             )
                     if "oids" in access_control:
                         for oid in access_control["oids"]:
@@ -110,15 +113,110 @@ async def run(self):
                 for directory_client in directories.values():
                     await directory_client.close()

+    async def walk_files(self, src_filepath = "."):
+        filepath_list = []
+
+        # This for loop uses the os.walk() function to walk through the files and directories
+        # and records the filepaths of the files to a list
+        for root, dirs, files in os.walk(src_filepath):
+
+            # Iterate through the files currently obtained by os.walk() and
+            # create the filepath string for that file and add it to the filepath_list list
+            for file in files:
+                # Checks to see if the root is '.' and changes it to the correct current
+                # working directory by calling os.getcwd(). Otherwise root_path will just be the root variable value.
+                if root == '.':
+                    root_path = os.getcwd() + "/"
+                else:
+                    root_path = root
+
+                filepath = os.path.join(root_path, file)
+
+                # Appends filepath to filepath_list if filepath does not currently exist in filepath_list
+                if filepath not in filepath_list:
+                    filepath_list.append(filepath)
+
+        # Return filepath_list
+        return filepath_list
+
+    async def scan_and_upload_directories(self, directories: dict[str, DataLakeDirectoryClient], filesystem_client):
+        logger.info("Scanning and uploading files from directories recursively...")
+        for directory, directory_client in directories.items():
+            directory_path = os.path.join(self.data_directory, directory)
+
+            # Check whether 'scandir' exists and is set to False
+            if not self.data_access_control_format["directories"][directory].get("scandir", True):
+                logger.info(f"Skipping directory {directory} as 'scandir' is set to False")
+                continue
+
+            groups = self.data_access_control_format["directories"][directory].get("groups", [])
+
+            # Check if the directory exists before walking it
+            if not os.path.exists(directory_path):
+                logger.warning(f"Directory does not exist: {directory_path}")
+                continue
+
+            # Get all file paths using the walk_files function
+            file_paths = await self.walk_files(directory_path)
+
+            # Upload each file collected
+            for file_path in file_paths:
+                await self.upload_file(directory_client, file_path)
+                logger.info(f"Uploaded {file_path} to {directory}")
+
     def create_service_client(self):
         return DataLakeServiceClient(
             account_url=f"https://{self.storage_account_name}.dfs.core.windows.net", credential=self.credentials
         )

-    async def upload_file(self, directory_client: DataLakeDirectoryClient, file_path: str):
+    async def calc_md5(self, path: str) -> str:
+        with open(path, "rb") as file:
+            return hashlib.md5(file.read()).hexdigest()
+
+    async def check_md5(self, path: str, md5_hash: str) -> bool:
+        # If the filename ends in .md5, skip it
+        if path.endswith(".md5"):
+            return True
+
+        # If there is a .md5 file alongside this file, see if it is up to date
+        stored_hash = None
+        hash_path = f"{path}.md5"
+        if os.path.exists(hash_path):
+            with open(hash_path, encoding="utf-8") as md5_f:
+                stored_hash = md5_f.read()
+
+        if stored_hash and stored_hash.strip() == md5_hash.strip():
+            logger.info("Skipping %s, no changes detected.", path)
+            return True
+
+        # Write the hash
+        with open(hash_path, "w", encoding="utf-8") as md5_f:
+            md5_f.write(md5_hash)
+
+        return False
+
+    async def upload_file(self, directory_client: DataLakeDirectoryClient, file_path: str, category: str = ""):
+        # Calculate MD5 hash once
+        md5_hash = await self.calc_md5(file_path)
+
+        # Check if the file has already been uploaded or if it has changed
+        if await self.check_md5(file_path, md5_hash):
+            logger.info("File %s has already been uploaded, skipping upload.", file_path)
+            return  # Skip uploading if the MD5 check indicates no changes
+
+        # Proceed with the upload since the file has changed
         with open(file=file_path, mode="rb") as f:
             file_client = directory_client.get_file_client(file=os.path.basename(file_path))
-            await file_client.upload_data(f, overwrite=True)
+            last_modified = datetime.datetime.fromtimestamp(os.path.getmtime(file_path)).isoformat()
+            title = os.path.splitext(os.path.basename(file_path))[0]
+            metadata = {
+                "md5": md5_hash,
+                "category": category,
+                "updated": last_modified,
+                "title": title
+            }
+            await file_client.upload_data(f, overwrite=True, metadata=metadata)
+            logger.info("File %s uploaded with metadata %s.", file_path, metadata)

     async def create_or_get_group(self, group_name: str):
         group_id = None
@@ -144,6 +242,7 @@ async def create_or_get_group(self, group_name: str):
                 # If Unified does not work for you, then you may need the following settings instead:
                 # "mailEnabled": False,
                 # "mailNickname": group_name,
+
             }
             async with session.post("https://graph.microsoft.com/v1.0/groups", json=group) as response:
                 content = await response.json()
@@ -165,19 +264,19 @@ async def main(args: Any):
             data_access_control_format = json.load(f)
         command = AdlsGen2Setup(
             data_directory=args.data_directory,
-            storage_account_name=os.environ["AZURE_ADLS_GEN2_STORAGE_ACCOUNT"],
-            filesystem_name="gptkbcontainer",
+            storage_account_name=os.environ["AZURE_ADLS_GEN2_STORAGE_ACCOUNT"],
+            filesystem_name=os.environ["AZURE_ADLS_GEN2_FILESYSTEM"],
             security_enabled_groups=args.create_security_enabled_groups,
             credentials=credentials,
             data_access_control_format=data_access_control_format,
         )
-        await command.run()
+        await command.run(args.scandirs)


 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
-        description="Upload sample data to a Data Lake Storage Gen2 account and associate sample access control lists with it using sample groups",
-        epilog="Example: ./scripts/adlsgen2setup.py ./data --data-access-control ./scripts/sampleacls.json --create-security-enabled-groups <true|false>",
+        description="Upload data to a Data Lake Storage Gen2 account and associate access control lists with it using sample groups",
+        epilog="Example: ./scripts/adlsgen2setup.py ./data --data-access-control .azure/${AZURE_ENV_NAME}/docs_acls.json --create-security-enabled-groups <true|false> --scandirs",
     )
     parser.add_argument("data_directory", help="Data directory that contains sample PDFs")
     parser.add_argument(
@@ -190,6 +289,7 @@ async def main(args: Any):
         "--data-access-control", required=True, help="JSON file describing access control for the sample data"
     )
     parser.add_argument("--verbose", "-v", required=False, action="store_true", help="Verbose output")
+    parser.add_argument("--scandirs", required=False, action="store_true", help="Scan and upload all files from directories recursively")
     args = parser.parse_args()
     if args.verbose:
         logging.basicConfig()
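
For context, a minimal sketch of the data access control JSON this script consumes, showing the per-directory "scandir" flag that the new --scandirs path checks. The directory, group, and file names below are illustrative placeholders, not taken from this commit:

    {
        "groups": ["GPTKB_AdminTest"],
        "directories": {
            "employeeinfo": { "groups": ["GPTKB_AdminTest"], "scandir": true },
            "/": { "groups": ["GPTKB_AdminTest"], "scandir": false }
        },
        "files": {
            "Benefit_Options.pdf": { "directory": "employeeinfo" }
        }
    }

Directories where "scandir" is false are skipped by scan_and_upload_directories; all others are walked recursively via walk_files and their files uploaded with MD5-based change detection. Invocation then follows the epilog above, for example: ./scripts/adlsgen2setup.py ./data --data-access-control .azure/${AZURE_ENV_NAME}/docs_acls.json --create-security-enabled-groups false --scandirs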
