|
| 1 | +import os |
| 2 | +import tempfile |
| 3 | +import threading |
| 4 | +from concurrent.futures import ThreadPoolExecutor, as_completed |
| 5 | +from contextlib import suppress |
| 6 | + |
| 7 | +import azure.storage.filedatalake as azurelake # type: ignore |
| 8 | +import config |
| 9 | +import pydicom |
| 10 | + |
# Number of parallel workers (one per subfolder)
MAX_WORKERS = 8
# Serializes stdout writes across worker threads (used by _safe_print).
_print_lock = threading.Lock()


# Source folder on the "stage-one" file system containing the original files.
BASE_ORIGINAL = "AI-READI/year3-fix/retinal_octa/enface/zeiss_cirrus"
# Destination folder for the corrected copies; mirrors BASE_ORIGINAL's layout.
BASE_OUTPUT = "AI-READI/year3-fix/retinal_octa/enface/zeiss_cirrus_fixed"
| 18 | + |
| 19 | + |
def update_cirrus_enface(inputfile, outputfile):
    """Rewrite a Zeiss Cirrus enface DICOM with corrected spatial metadata.

    Clears ImageOrientationPatient and pins PixelSpacing to 0.005859375 in
    both dimensions (DICOM PixelSpacing — presumably mm; confirm against the
    acquisition spec), then writes the dataset to *outputfile* as a
    standards-conformant DICOM file rather than mirroring the input encoding.

    Args:
        inputfile: Path of the DICOM file to read.
        outputfile: Path the fixed DICOM file is written to.
    """
    dataset = pydicom.dcmread(inputfile)

    dataset.ImageOrientationPatient = ""
    dataset.PixelSpacing = [0.005859375, 0.005859375]

    dataset.save_as(outputfile, write_like_original=False)
| 27 | + |
| 28 | + |
def get_file_system_client():
    """Connect to Azure Data Lake (fairhubproduction) using project config.

    Returns:
        A FileSystemClient bound to the "stage-one" file system.
    """
    connection_string = config.AZURE_STORAGE_PRODUCTION_DANGEROUS_CONNECTION_STRING
    client = azurelake.FileSystemClient.from_connection_string(
        connection_string,
        file_system_name="stage-one",
    )
    return client
| 35 | + |
| 36 | + |
def list_subfolders(fs_client, folder_path):
    """Return the sorted names of the immediate subfolders of *folder_path*.

    Args:
        fs_client: Azure FileSystemClient (anything exposing ``get_paths``).
        folder_path: Folder path on the file system, e.g.
            ``"AI-READI/year3-fix/retinal_octa/enface/zeiss_cirrus"``.

    Returns:
        Sorted list of immediate child names, e.g. ``['1269', '1270']``.
    """
    prefix = folder_path.rstrip("/") + "/"
    # recursive=False suffices: we only keep the first path segment below
    # folder_path, so enumerating every descendant is wasted round trips.
    paths = fs_client.get_paths(path=folder_path, recursive=False)
    subfolders = set()
    for item in paths:
        name = (item.name or "").strip()
        # Skip empty names and the folder itself.
        if not name or name == folder_path.strip("/"):
            continue
        # Names normally come back fully qualified; tolerate either form.
        rest = name[len(prefix):] if name.startswith(prefix) else name
        parts = rest.split("/")
        if parts and parts[0]:
            subfolders.add(parts[0])
    return sorted(subfolders)
| 51 | + |
| 52 | + |
def list_files_in_folder(fs_client, folder_path):
    """List all files (blobs) under *folder_path*, recursively.

    Directory entries are skipped three ways: trailing-slash names, the
    ``is_directory`` flag reported by ``get_paths`` (cheap, no extra call),
    and the blob ``hdi_isfolder`` metadata marker as a fallback.

    Args:
        fs_client: Azure FileSystemClient.
        folder_path: Folder whose descendants should be listed.

    Returns:
        List of fully qualified file paths.
    """
    paths = fs_client.get_paths(path=folder_path, recursive=True)
    files = []
    for item in paths:
        name = (item.name or "").strip()
        if not name or name == folder_path.strip("/"):
            continue
        if name.endswith("/"):
            continue
        # Cheap check first: get_paths items carry an is_directory flag,
        # which avoids one get_file_properties round trip per path.
        if getattr(item, "is_directory", False):
            continue
        # Fallback: blob-backed "folders" carry hdi_isfolder metadata.
        # Best-effort — if the properties call fails we still include the
        # path rather than silently dropping a real file.
        with suppress(Exception):
            fc = fs_client.get_file_client(file_path=name)
            props = fc.get_file_properties()
            if getattr(props, "metadata", {}) and props.metadata.get("hdi_isfolder"):
                continue
        files.append(name)
    return files
| 70 | + |
| 71 | + |
def process_one_file(fs_client, remote_path):
    """Download one file, run update_cirrus_enface, upload to BASE_OUTPUT.

    The output path mirrors the file's path relative to BASE_ORIGINAL,
    rooted at BASE_OUTPUT instead.

    Args:
        fs_client: Azure FileSystemClient.
        remote_path: Fully qualified path of the source file.

    Returns:
        ``(True, uploaded_path)`` on success, ``(False, reason)`` on failure.
    """
    prefix = BASE_ORIGINAL.rstrip("/") + "/"
    if not remote_path.startswith(prefix):
        return False, f"Path not under BASE_ORIGINAL: {remote_path}"
    relative = remote_path[len(prefix):]
    out_blob_path = f"{BASE_OUTPUT.rstrip('/')}/{relative}"

    # Create (and immediately close) two named temp files so pydicom and
    # open() can reopen them by path; delete=False keeps them on disk until
    # the finally block below removes them.
    with tempfile.NamedTemporaryFile(
        delete=False, suffix=".dcm", prefix="cirrus_dl_"
    ) as dl:
        download_path = dl.name
    with tempfile.NamedTemporaryFile(
        delete=False, suffix=".dcm", prefix="cirrus_out_"
    ) as out:
        write_path = out.name

    # One outer try/finally guarantees both temp files are removed on every
    # success and failure path (the original duplicated this cleanup three
    # times).
    try:
        try:
            file_client = fs_client.get_file_client(file_path=remote_path)
            with open(download_path, "wb") as f:
                f.write(file_client.download_file().readall())
        except Exception as e:
            return False, f"Download failed {remote_path}: {e}"

        try:
            update_cirrus_enface(download_path, write_path)
        except Exception as e:
            return False, f"Transform failed {remote_path}: {e}"

        try:
            out_client = fs_client.get_file_client(file_path=out_blob_path)
            with open(write_path, "rb") as f:
                out_client.upload_data(f.read(), overwrite=True)
            return True, out_blob_path
        except Exception as e:
            return False, f"Upload failed {out_blob_path}: {e}"
    finally:
        for p in (download_path, write_path):
            with suppress(FileNotFoundError):
                os.unlink(p)
| 122 | + |
| 123 | + |
def _safe_print(msg: str) -> None:
    """Print *msg* under the module print lock so that output lines from
    concurrent worker threads do not interleave."""
    _print_lock.acquire()
    try:
        print(msg)
    finally:
        _print_lock.release()
| 127 | + |
| 128 | + |
def process_one_folder(fs_client, subfolder_name):
    """Download, fix, and re-upload every file in one subfolder.

    Args:
        fs_client: Azure FileSystemClient.
        subfolder_name: Immediate child of BASE_ORIGINAL to process.

    Returns:
        ``(subfolder_name, ok_count, skip_count)``.
    """
    folder_path = f"{BASE_ORIGINAL.rstrip('/')}/{subfolder_name}"
    ok_count = 0
    skip_count = 0
    for remote_path in list_files_in_folder(fs_client, folder_path):
        success, detail = process_one_file(fs_client, remote_path)
        if not success:
            skip_count += 1
            _safe_print(f" [SKIP] {detail}")
        else:
            ok_count += 1
            _safe_print(f" [OK] {detail}")
    return subfolder_name, ok_count, skip_count
| 143 | + |
| 144 | + |
def main():
    """Fix the cirrus data files.

    Lists the subfolders of BASE_ORIGINAL, processes each folder on its own
    worker thread (download, fix, upload to BASE_OUTPUT), and prints running
    totals as folders complete.
    """
    fs_client = get_file_system_client()
    subfolders = list_subfolders(fs_client, BASE_ORIGINAL)
    if not subfolders:
        print(f"No subfolders found under {BASE_ORIGINAL}")
        return
    print(f"Found {len(subfolders)} subfolder(s) under {BASE_ORIGINAL}: {subfolders}")

    total_ok = 0
    total_skip = 0
    with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
        future_to_name = {
            executor.submit(process_one_folder, fs_client, name): name
            for name in subfolders
        }
        for done in as_completed(future_to_name):
            folder = future_to_name[done]
            try:
                name, ok, skip = done.result()
            except Exception as e:
                # A folder-level failure is reported but does not stop the run.
                _safe_print(f" [ERROR] folder {folder}: {e}")
            else:
                total_ok += ok
                total_skip += skip
                _safe_print(f" Done folder {name}: {ok} OK, {skip} skip")
    print(f"Total: {total_ok} OK, {total_skip} skip")
| 169 | + |
| 170 | + |
| 171 | +if __name__ == "__main__": |
| 172 | + main() |
0 commit comments