@@ -165,8 +165,13 @@ def read_txt(path):
165165 List[str]
166166 Lines from the txt file.
167167 """
168- with open (path , "r" ) as f :
169- return f .read ().splitlines ()
168+ if is_s3_path (path ):
169+ return read_txt_from_s3 (path )
170+ elif is_gcs_path (path ):
171+ return read_txt_from_gcs (path )
172+ else :
173+ with open (path , "r" ) as f :
174+ return f .read ().splitlines ()
170175
171176
172177def update_txt (path , text , verbose = True ):
@@ -427,7 +432,7 @@ def list_s3_paths(bucket_name, prefix, extension=""):
427432 return filenames
428433
429434
430- def read_txt_from_s3 (bucket_name , path ):
435+ def read_txt_from_s3 (path ):
431436 """
432437 Reads a txt file stored in an S3 bucket.
433438
@@ -443,6 +448,7 @@ def read_txt_from_s3(bucket_name, path):
443448 str
444449 Contents of txt file.
445450 """
451+ bucket_name , path = parse_cloud_path (path )
446452 s3 = boto3 .client ("s3" , config = Config (signature_version = UNSIGNED ))
447453 obj = s3 .get_object (Bucket = bucket_name , Key = path )
448454 return obj ['Body' ].read ().decode ('utf-8' )
@@ -558,7 +564,7 @@ def load_valid_labels(path):
558564 Segment IDs that can be assigned to nodes.
559565 """
560566 valid_labels = set ()
561- for label_str in read_txt (path ):
567+ for label_str in read_txt (path ). splitlines () :
562568 valid_labels .add (int (label_str .split ("." )[0 ]))
563569 return valid_labels
564570
0 commit comments