Skip to content

Commit e5b43bf

Browse files
anna-grimanna-grim
andauthored
refactor: improved txt reader (#173)
* refactor: improved txt reader * remove print --------- Co-authored-by: anna-grim <anna.grim@alleninstitute.org>
1 parent 5b2647c commit e5b43bf

File tree

2 files changed

+11
-5
lines changed

2 files changed

+11
-5
lines changed

src/segmentation_skeleton_metrics/data_handling/swc_loading.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -452,7 +452,7 @@ def read_from_s3(self, s3_path):
452452
# Parse SWC files
453453
swc_dicts = deque()
454454
for path in swc_paths:
455-
content = util.read_txt_from_s3(bucket_name, path).splitlines()
455+
content = util.read_txt(bucket_name, path).splitlines()
456456
filename = os.path.basename(path)
457457
result = self.parse(content, filename)
458458
if result:

src/segmentation_skeleton_metrics/utils/util.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -165,8 +165,13 @@ def read_txt(path):
165165
List[str]
166166
Lines from the txt file.
167167
"""
168-
with open(path, "r") as f:
169-
return f.read().splitlines()
168+
if is_s3_path(path):
169+
return read_txt_from_s3(path)
170+
elif is_gcs_path(path):
171+
return read_txt_from_gcs(path)
172+
else:
173+
with open(path, "r") as f:
174+
return f.read().splitlines()
170175

171176

172177
def update_txt(path, text, verbose=True):
@@ -427,7 +432,7 @@ def list_s3_paths(bucket_name, prefix, extension=""):
427432
return filenames
428433

429434

430-
def read_txt_from_s3(bucket_name, path):
435+
def read_txt_from_s3(path):
431436
"""
432437
Reads a txt file stored in an S3 bucket.
433438
@@ -443,6 +448,7 @@ def read_txt_from_s3(bucket_name, path):
443448
str
444449
Contents of txt file.
445450
"""
451+
bucket_name, path = parse_cloud_path(path)
446452
s3 = boto3.client("s3", config=Config(signature_version=UNSIGNED))
447453
obj = s3.get_object(Bucket=bucket_name, Key=path)
448454
return obj['Body'].read().decode('utf-8')
@@ -558,7 +564,7 @@ def load_valid_labels(path):
558564
Segment IDs that can be assigned to nodes.
559565
"""
560566
valid_labels = set()
561-
for label_str in read_txt(path):
567+
for label_str in read_txt(path).splitlines():
562568
valid_labels.add(int(label_str.split(".")[0]))
563569
return valid_labels
564570

0 commit comments

Comments
 (0)