
Commit 33dd804

anna-grim authored
feat: tools for reading from s3 (#139)
Co-authored-by: anna-grim <[email protected]>
1 parent db6302c commit 33dd804

2 files changed: +175 -31 lines changed


src/segmentation_skeleton_metrics/utils/swc_util.py

Lines changed: 71 additions & 22 deletions
@@ -86,41 +86,45 @@ def read(self, swc_pointer):
             - "filename": filename of SWC file
             - "swc_id": name of SWC file, minus the ".swc".
         """
-        # Dictionary with GCS specs
-        if isinstance(swc_pointer, dict):
-            return self.read_from_gcs(swc_pointer)
-
-        # List of paths to SWC files
+        # List of local paths to SWC files
         if isinstance(swc_pointer, list):
             return self.read_from_paths(swc_pointer)
 
         # Directory containing...
         if os.path.isdir(swc_pointer):
-            # ZIP archives with SWC files
+            # Local ZIP archives with SWC files
             paths = util.list_paths(swc_pointer, extension=".zip")
             if len(paths) > 0:
                 return self.read_from_zips(swc_pointer)
 
-            # SWC files
+            # Local SWC files
             paths = util.read_paths(swc_pointer, extension=".swc")
             if len(paths) > 0:
                 return self.read_from_paths(paths)
 
-            raise Exception("Directory is invalid!")
+            raise Exception("Directory is Invalid!")
 
         # Path to...
         if isinstance(swc_pointer, str):
-            # ZIP archive with SWC files
+            # Cloud GCS storage
+            if util.is_gcs_path(swc_pointer):
+                return self.read_from_gcs(swc_pointer)
+
+            # Cloud S3 storage
+            if util.is_s3_path(swc_pointer):
+                return self.read_from_s3(swc_pointer)
+
+            # Local ZIP archive with SWC files
             if swc_pointer.endswith(".zip"):
                 return self.read_from_zip(swc_pointer)
 
-            # Path to single SWC file
+            # Local path to single SWC file
            if swc_pointer.endswith(".swc"):
                 return self.read_from_path(swc_pointer)
 
-            raise Exception("Path is invalid!")
+            raise Exception("Path is Invalid!")
 
-        raise Exception("SWC Pointer is inValid!")
+        raise Exception("SWC Pointer is Invalid!")
 
     def read_from_path(self, path):
         """
@@ -268,15 +272,17 @@ def read_from_zipped_file(self, zipfile, path):
         filename = os.path.basename(path)
         return self.parse(content, filename)
 
-    def read_from_gcs(self, gcs_dict):
+    def read_from_gcs(self, gcs_path):
         """
         Reads SWC files stored in a GCS bucket.
 
         Parameters
         ----------
-        gcs_dict : dict
-            Dictionary with the keys "bucket_name" and "path" that specify
-            where the SWC files are located in a GCS bucket.
+        gcs_path : str
+            Path to a location in a GCS bucket where SWC files are stored.
+            The path must have the format "gs://{bucket_name}/{prefix}",
+            where "prefix" points to a directory containing SWC files or
+            ZIP archives of SWC files.
 
         Returns
         -------
@@ -285,17 +291,18 @@ def read_from_gcs(self, gcs_dict):
             names and values from an SWC file.
         """
         # List filenames
-        swc_paths = util.list_gcs_filenames(gcs_dict, ".swc")
-        zip_paths = util.list_gcs_filenames(gcs_dict, ".zip")
+        bucket_name, prefix = parse_cloud_path(gcs_path)
+        swc_paths = util.list_gcs_filenames(bucket_name, prefix, ".swc")
+        zip_paths = util.list_gcs_filenames(bucket_name, prefix, ".zip")
 
         # Call reader
         if len(swc_paths) > 0:
-            return self.read_from_gcs_swcs(gcs_dict["bucket_name"], swc_paths)
+            return self.read_from_gcs_swcs(bucket_name, swc_paths)
         if len(zip_paths) > 0:
-            return self.read_from_gcs_zips(gcs_dict["bucket_name"], zip_paths)
+            return self.read_from_gcs_zips(bucket_name, zip_paths)
 
         # Error
-        raise Exception(f"GCS Pointer is invalid -{gcs_dict}-")
+        raise Exception(f"GCS Pointer is invalid -{gcs_path}-")
 
     def read_from_gcs_swcs(self, bucket_name, swc_paths):
         """
@@ -419,6 +426,17 @@ def read_from_gcs_zip(self, bucket_name, path):
                 swc_dicts.append(result)
         return swc_dicts
 
+    def read_from_s3_swcs(self, s3_path):
+        """
+        Reads SWC files stored at "s3_path" in an S3 bucket.
+        """
+        # List filenames
+        bucket_name, prefix = parse_cloud_path(s3_path)
+        swc_paths = util.list_s3_paths(bucket_name, prefix, extension=".swc")
+
+        # Parse SWC files
+        swc_dicts = deque()
+        for path in swc_paths:
+            contents = util.read_txt_from_s3(bucket_name, path)
+            swc_dicts.append(self.parse(contents, os.path.basename(path)))
+        return swc_dicts
+
     def confirm_read(self, filename):
         """
         Checks whether the swc_id corresponding to the given filename is
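`read` dispatches S3 pointers to `self.read_from_s3`, but only `read_from_s3_swcs` appears in this diff. A plausible sketch of the missing dispatcher, mirroring `read_from_gcs` and assuming ZIP archives on S3 are not yet supported (hypothetical, not part of this commit):

    def read_from_s3(self, s3_path):
        # Hypothetical router: delegate to the SWC reader added above
        bucket_name, prefix = parse_cloud_path(s3_path)
        if len(util.list_s3_paths(bucket_name, prefix, extension=".swc")) > 0:
            return self.read_from_s3_swcs(s3_path)
        raise Exception(f"S3 Pointer is invalid -{s3_path}-")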
@@ -523,7 +541,38 @@ def read_voxel(self, xyz_str, offset):
         return img_util.to_voxels(xyz, self.anisotropy)
 
 
-# --- Write ---
+# --- Helpers ---
+def parse_cloud_path(path):
+    """
+    Parses a cloud storage path into its bucket name and key/prefix. Supports
+    paths of the form "{scheme}://bucket_name/prefix" or without a scheme.
+
+    Parameters
+    ----------
+    path : str
+        Path to be parsed.
+
+    Returns
+    -------
+    bucket_name : str
+        Name of the bucket.
+    prefix : str
+        Cloud prefix.
+    """
+    # Remove s3:// if present
+    if path.startswith("s3://"):
+        path = path[len("s3://"):]
+
+    # Remove gs:// if present
+    if path.startswith("gs://"):
+        path = path[len("gs://"):]
+
+    parts = path.split("/", 1)
+    bucket_name = parts[0]
+    prefix = parts[1] if len(parts) > 1 else ""
+    return bucket_name, prefix
+
+
 def to_zipped_point(zip_writer, filename, xyz):
     """
     Writes a point to an SWC file format, which is then stored in a ZIP
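Since `parse_cloud_path` strips both schemes, the GCS and S3 readers can share it. A few illustrative input/output pairs (hypothetical bucket names):

    parse_cloud_path("s3://my-bucket/exp1/swcs")  # ("my-bucket", "exp1/swcs")
    parse_cloud_path("gs://my-bucket/exp1/swcs")  # ("my-bucket", "exp1/swcs")
    parse_cloud_path("my-bucket/exp1/swcs")       # ("my-bucket", "exp1/swcs")
    parse_cloud_path("my-bucket")                 # ("my-bucket", "")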

src/segmentation_skeleton_metrics/utils/util.py

Lines changed: 104 additions & 9 deletions
@@ -16,6 +16,9 @@
 from xlwt import Workbook
 from zipfile import ZipFile
 
+import boto3
+from botocore import UNSIGNED
+from botocore.client import Config
 import os
 import pandas as pd
 import shutil
@@ -153,7 +156,24 @@ def update_txt(path, text):
         file.write(text + "\n")
 
 
-# -- GCS utils --
+# -- GCS Utils --
+def is_gcs_path(path):
+    """
+    Checks if the path is a GCS path.
+
+    Parameters
+    ----------
+    path : str
+        Path to be checked.
+
+    Returns
+    -------
+    bool
+        Indication of whether the path is a GCS path.
+    """
+    return path.startswith("gs://")
+
+
 def list_files_in_zip(zip_content):
     """
     Lists all files in a zip file stored in a GCS bucket.
@@ -172,14 +192,16 @@ def list_files_in_zip(zip_content):
     return zip_file.namelist()
 
 
-def list_gcs_filenames(gcs_dict, extension):
+def list_gcs_filenames(bucket_name, prefix, extension):
     """
     Lists all files in a GCS bucket with the given extension.
 
     Parameters
     ----------
-    gcs_dict : dict
-        ...
+    bucket_name : str
+        Name of bucket to be searched.
+    prefix : str
+        Path to location within bucket to be searched.
     extension : str
         File extension of filenames to be listed.
@@ -188,8 +210,8 @@ def list_gcs_filenames(gcs_dict, extension):
     List[str]
         Filenames stored at "cloud" path with the given extension.
     """
-    bucket = storage.Client().bucket(gcs_dict["bucket_name"])
-    blobs = bucket.list_blobs(prefix=gcs_dict["path"])
+    bucket = storage.Client().bucket(bucket_name)
+    blobs = bucket.list_blobs(prefix=prefix)
     return [blob.name for blob in blobs if extension in blob.name]
 
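One behavioral difference worth noting: this GCS filter keeps any blob whose name merely contains the extension, while the S3 helper added below matches with `str.endswith`, so the two backends can disagree on edge cases:

    ".swc" in "old.swc.bak"           # True  -> kept by list_gcs_filenames
    "old.swc.bak".endswith(".swc")    # False -> dropped by list_s3_paths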
@@ -227,16 +249,16 @@ def list_gcs_subdirectories(bucket_name, prefix):
     return subdirs
 
 
-def read_txt_from_gcs(bucket_name, filename):
+def read_txt_from_gcs(bucket_name, path):
     """
     Reads a txt file stored in a GCS bucket.
 
     Parameters
     ----------
     bucket_name : str
         Name of bucket to be read from.
-    filename : str
-        Name of txt file to be read.
+    path : str
+        Path to txt file to be read.
 
     Returns
     -------
@@ -277,6 +299,79 @@ def upload_directory_to_gcs(bucket_name, source_dir, destination_dir):
         blob.upload_from_filename(local_path)
 
 
+# --- S3 Utils ---
+def is_s3_path(path):
+    """
+    Checks if the given path is an S3 path.
+
+    Parameters
+    ----------
+    path : str
+        Path to be checked.
+
+    Returns
+    -------
+    bool
+        Indication of whether the path is an S3 path.
+    """
+    return path.startswith("s3://")
+
+
+def list_s3_paths(bucket_name, prefix, extension=""):
+    """
+    Lists all object keys in a public S3 bucket under a given prefix,
+    optionally filtered by file extension.
+
+    Parameters
+    ----------
+    bucket_name : str
+        Name of the S3 bucket.
+    prefix : str
+        The S3 "directory" prefix to search under.
+    extension : str, optional
+        File extension to filter by. Default is an empty string, which
+        returns all files.
+
+    Returns
+    -------
+    List[str]
+        List of S3 object keys that match the prefix and extension filter.
+    """
+    # Create an anonymous client for public buckets
+    s3 = boto3.client("s3", config=Config(signature_version=UNSIGNED))
+    response = s3.list_objects_v2(Bucket=bucket_name, Prefix=prefix)
+
+    # Filter the objects under the prefix by extension
+    filenames = list()
+    if "Contents" in response:
+        for obj in response["Contents"]:
+            filename = obj["Key"]
+            if filename.endswith(extension):
+                filenames.append(filename)
+    return filenames
+
+
+def read_txt_from_s3(bucket_name, path):
+    """
+    Reads a txt file stored in an S3 bucket.
+
+    Parameters
+    ----------
+    bucket_name : str
+        Name of bucket to be read from.
+    path : str
+        Path to txt file to be read.
+
+    Returns
+    -------
+    List[str]
+        Contents of the txt file as a list of lines.
+    """
+    s3 = boto3.client("s3", config=Config(signature_version=UNSIGNED))
+    obj = s3.get_object(Bucket=bucket_name, Key=path)
+    return obj["Body"].read().decode("utf-8").splitlines()
+
+
 # --- Miscellaneous ---
 def get_segment_id(filename):
     """
