Skip to content

Commit 95a70d8

Browse files
anna-grim authored
Feat save merge sites (#106)
* refactor: improved new features * improved new feature * refactor: merge site save * bug: mkdir output_dir * feat: read zips of swcs on gcs --------- Co-authored-by: anna-grim <[email protected]>
1 parent ea049f1 commit 95a70d8

File tree

2 files changed

+174
-0
lines changed

2 files changed

+174
-0
lines changed

src/segmentation_skeleton_metrics/utils/swc_util.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
ProcessPoolExecutor,
2929
ThreadPoolExecutor,
3030
)
31+
from google.cloud import storage
3132
from io import StringIO
3233
from tqdm import tqdm
3334
from zipfile import ZipFile
@@ -95,6 +96,10 @@ def read(self, swc_pointer):
9596
- "swc_id": name of SWC file, minus the ".swc".
9697
9798
"""
99+
# Dictionary with GCS specs
100+
if isinstance(swc_pointer, dict):
101+
return self.read_from_gcs(swc_pointer)
102+
98103
# List of paths to SWC files
99104
if isinstance(swc_pointer, list):
100105
return self.read_from_paths(swc_pointer)
@@ -278,6 +283,97 @@ def read_from_zipped_file(self, zipfile, path):
278283
filename = os.path.basename(path)
279284
return self.parse(content, filename)
280285

286+
def read_from_gcs(self, gcs_dict):
    """
    Reads SWC files from a GCS bucket location that contains either
    loose SWC files or ZIP archives of SWC files.

    Parameters
    ----------
    gcs_dict : dict
        Dictionary with the keys "bucket_name" and "path" that specify
        where the SWC files (or ZIP archives) are located in a GCS
        bucket.

    Returns
    -------
    deque[dict]
        List of dictionaries whose keys and values are the attribute
        names and values from an SWC file.

    Raises
    ------
    Exception
        If neither SWC files nor ZIP archives are found at the given
        GCS location.

    """
    # List filenames at the given GCS location
    bucket = storage.Client().bucket(gcs_dict["bucket_name"])
    swc_paths = util.list_gcs_filenames(bucket, gcs_dict["path"], ".swc")
    zip_paths = util.list_gcs_filenames(bucket, gcs_dict["path"], ".zip")

    # Dispatch to the appropriate reader (loose SWC files take precedence)
    if len(swc_paths) > 0:
        return self.read_from_gcs_swcs(bucket, swc_paths)
    if len(zip_paths) > 0:
        return self.read_from_gcs_zips(bucket, zip_paths)

    # Error
    raise Exception(f"GCS Pointer is invalid -{gcs_dict}-")
316+
317+
def read_from_gcs_swcs(self, bucket, swc_paths):
    """
    Reads SWC files stored directly (unzipped) in a GCS bucket.

    Parameters
    ----------
    bucket : google.cloud.storage.Bucket
        Bucket that the SWC files are read from.
    swc_paths : List[str]
        Paths to SWC files in "bucket".

    Returns
    -------
    deque[dict]
        List of dictionaries whose keys and values are the attribute
        names and values from an SWC file.

    Raises
    ------
    NotImplementedError
        Always. This reader is not implemented yet; the previous stub
        silently returned None, which callers would mistake for a
        valid (empty) result.

    """
    raise NotImplementedError(
        "Reading loose SWC files from GCS is not supported yet"
    )
319+
320+
def read_from_gcs_zips(self, bucket, zip_paths):
    """
    Reads SWC files from ZIP archives stored in a GCS bucket.

    Parameters
    ----------
    bucket : google.cloud.storage.Bucket
        Bucket that the ZIP archives are downloaded from.
    zip_paths : List[str]
        Paths to ZIP archives in "bucket".

    Returns
    -------
    deque[dict]
        List of dictionaries whose keys and values are the attribute
        names and values from an SWC file.

    """
    pbar = tqdm(total=len(zip_paths), desc="Read SWCs")
    swc_dicts = deque()
    try:
        with ProcessPoolExecutor() as executor:
            # Assign processes. NOTE: archives are downloaded serially
            # here because bucket/blob handles cannot be sent to worker
            # processes; only the raw bytes are shipped to workers.
            processes = list()
            for path in zip_paths:
                zip_content = bucket.blob(path).download_as_bytes()
                processes.append(
                    executor.submit(self.read_from_gcs_zip, zip_content)
                )

            # Store results
            for process in as_completed(processes):
                swc_dicts.extend(process.result())
                pbar.update(1)
    finally:
        # Always release the progress bar, even if a worker raised
        pbar.close()
    return swc_dicts
338+
339+
def read_from_gcs_zip(self, zip_content):
    """
    Reads SWC files stored in a ZIP archive downloaded from a GCS
    bucket.

    Parameters
    ----------
    zip_content : bytes
        Content of a ZIP archive.

    Returns
    -------
    deque[dict]
        List of dictionaries whose keys and values are the attribute
        names and values from an SWC file.

    """
    with ZipFile(BytesIO(zip_content)) as zip_file:
        with ThreadPoolExecutor() as executor:
            # Assign threads. Use the already-open archive's namelist()
            # rather than re-parsing "zip_content" into a second
            # ZipFile just to list its members.
            threads = list()
            for filename in zip_file.namelist():
                if self.confirm_read(filename):
                    threads.append(
                        executor.submit(
                            self.read_from_zipped_file, zip_file, filename
                        )
                    )

            # Process results
            swc_dicts = deque()
            for thread in as_completed(threads):
                result = thread.result()
                if result:
                    swc_dicts.append(result)
    return swc_dicts
376+
281377
def confirm_read(self, filename):
282378
"""
283379
Checks whether the swc_id corresponding to the given filename is

src/segmentation_skeleton_metrics/utils/util.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,84 @@ def update_txt(path, text):
167167
file.write(text + "\n")
168168

169169

170+
# -- GCS utils --
171+
def list_files_in_zip(zip_content):
    """
    Lists all files in a ZIP archive given its raw content.

    Parameters
    ----------
    zip_content : bytes
        Content of a ZIP archive (e.g. downloaded from a GCS bucket).

    Returns
    -------
    List[str]
        List of filenames in the ZIP archive.

    """
    with ZipFile(BytesIO(zip_content), "r") as zip_file:
        return zip_file.namelist()
188+
189+
190+
def list_gcs_filenames(bucket, prefix, extension):
    """
    Lists all files in a GCS bucket with the given extension.

    Parameters
    ----------
    bucket : google.cloud.storage.Bucket
        Bucket to be read from.
    prefix : str
        Path to directory in "bucket".
    extension : str
        File extension of filenames to be listed (e.g. ".swc").

    Returns
    -------
    List[str]
        Paths under "prefix" whose filename ends with "extension".

    """
    blobs = bucket.list_blobs(prefix=prefix)
    # endswith() avoids false positives such as "file.swc.bak", which a
    # substring test ("extension in blob.name") would wrongly include
    return [blob.name for blob in blobs if blob.name.endswith(extension)]
211+
212+
213+
def list_gcs_subdirectories(bucket_name, prefix):
    """
    Lists all direct subdirectories of a given prefix in a GCS bucket.

    Parameters
    ----------
    bucket_name : str
        Name of bucket to be read from.
    prefix : str
        Path to directory in "bucket_name".

    Returns
    -------
    List[str]
        List of direct subdirectories.

    """
    # Load blobs. The iterator must be fully consumed before
    # "blobs.prefixes" is populated by the GCS client; consuming it
    # with a bare loop avoids building a throwaway list of names.
    storage_client = storage.Client()
    blobs = storage_client.list_blobs(
        bucket_name, prefix=prefix, delimiter="/"
    )
    for _ in blobs:
        pass

    # Parse directory contents. Loop variable renamed so it no longer
    # shadows the "prefix" parameter.
    prefix_depth = len(prefix.split("/"))
    subdirs = list()
    for subprefix in blobs.prefixes:
        is_dir = subprefix.endswith("/")
        is_direct_subdir = len(subprefix.split("/")) - 1 == prefix_depth
        if is_dir and is_direct_subdir:
            subdirs.append(subprefix)
    return subdirs
246+
247+
170248
# --- Miscellaneous ---
171249
def get_segment_id(filename):
172250
"""

0 commit comments

Comments
 (0)