77
88"""
99
10+ from concurrent .futures import ThreadPoolExecutor , as_completed
11+ from io import BytesIO
12+ from zipfile import ZipFile
13+
1014import networkx as nx
1115import numpy as np
16+ from google .cloud import storage
1217
1318from segmentation_skeleton_metrics import utils
1419
1520
16- def read (path , cloud_read = False ):
17- return read_from_cloud (path ) if cloud_read else read_from_local (path )
21+ def parse (swc_paths , min_size , anisotropy = [1.0 , 1.0 , 1.0 ]):
22+ """
23+ Reads swc files and extracts the xyz coordinates.
24+
25+ Paramters
26+ ---------
27+ swc_paths : list or dict
28+ If swc files are on local machine, list of paths to swc files where
29+ each file corresponds to a neuron in the prediction. If swc files are
30+ on cloud, then dict with keys "bucket_name" and "path".
31+ min_size : int
32+ Threshold on the number of nodes contained in an swc file. Only swc
33+ files with more than "min_size" nodes are stored in "valid_labels".
34+ anisotropy : list[float]
35+ Image to World scaling factors applied to xyz coordinates to account
36+ for anisotropy of the microscope.
1837
38+ Returns
39+ -------
40+ dict
41+ ...
42+ """
43+ if type (swc_paths ) == list :
44+ return parse_local_paths (swc_paths , min_size , anisotropy )
45+ elif type (swc_paths ) == dict :
46+ return parse_cloud_paths (swc_paths , min_size , anisotropy )
47+ else :
48+ return None
49+
50+
51+ def parse_local_paths (pred_swc_paths , min_size , anisotropy ):
52+ valid_labels = dict ()
53+ for path in pred_swc_paths :
54+ contents = read_from_local (path )
55+ if len (contents ) > min_size :
56+ swc_id = int (utils .get_swc_id (path ))
57+ valid_labels [swc_id ] = get_coords (contents , anisotropy )
58+ return valid_labels
59+
60+
61+ def parse_cloud_paths (cloud_dict , min_size , anisotropy ):
62+ # Initializations
63+ bucket = storage .Client ().bucket (cloud_dict ["bucket_name" ])
64+ zip_paths = utils .list_gcs_filenames (bucket , cloud_dict ["path" ], ".zip" )
65+ chunk_size = int (len (zip_paths ) * 0.02 )
66+
67+ # Parse
68+ cnt = 1
69+ valid_labels = dict ()
70+ print ("Downloading predicted swc files from cloud..." )
71+ print ("# zip files:" , len (zip_paths ))
72+ for i , path in enumerate (zip_paths ):
73+ valid_labels .update (download (bucket , path , min_size , anisotropy ))
74+ if i > cnt * chunk_size :
75+ utils .progress_bar (i + 1 , len (zip_paths ))
76+ cnt += 1
77+
78+ # Report Results
79+ print ("\n #Valid Labels:" , len (valid_labels ))
80+ print ("" )
81+ return valid_labels
82+
83+
84+ def download (bucket , zip_path , min_size , anisotropy ):
85+ zip_content = bucket .blob (zip_path ).download_as_bytes ()
86+ with ZipFile (BytesIO (zip_content )) as zip_file :
87+ with ThreadPoolExecutor () as executor :
88+ # Assign threads
89+ threads = []
90+ for path in utils .list_files_in_gcs_zip (zip_content ):
91+ threads .append (
92+ executor .submit (
93+ parse_gcs_zip , zip_file , path , min_size , anisotropy
94+ )
95+ )
1996
20- def read_from_cloud (path ):
21- pass
97+ # Process results
98+ valid_labels = dict ()
99+ for thread in as_completed (threads ):
100+ valid_labels .update (thread .result ())
101+ return valid_labels
102+
103+
104+ def parse_gcs_zip (zip_file , path , min_size , anisotropy ):
105+ contents = read_from_cloud (zip_file , path )
106+ if len (contents ) > min_size :
107+ swc_id = int (utils .get_swc_id (path ))
108+ return {swc_id : get_coords (contents , anisotropy )}
109+ else :
110+ return dict ()
22111
23112
24113def read_from_local (path ):
@@ -40,38 +129,44 @@ def read_from_local(path):
40129 return file .readlines ()
41130
42131
43- def get_xyz_coords (path , anisotropy = [1.0 , 1.0 , 1.0 ]):
132+ def read_from_cloud (zip_file , path ):
133+ """
134+ Reads the content of an swc file from a zip file in a GCS bucket.
135+
136+ """
137+ with zip_file .open (path ) as text_file :
138+ return text_file .read ().decode ("utf-8" ).splitlines ()
139+
140+
141+ def get_coords (contents , anisotropy ):
44142 """
45143 Gets the xyz coords from the swc file at "path".
46144
47145 Parameters
48146 ----------
49147 path : str
50148 Path to swc file to be parsed.
51- anisotropy : list[float], optional
52- Scaling factors applied to xyz coordinates to account for anisotropy
53- of the microscope. The default is [1.0, 1.0, 1.0] .
149+ anisotropy : list[float]
150+ Image to World scaling factors applied to xyz coordinates to account
151+ for anisotropy of the microscope.
54152
55153 Returns
56154 -------
57155 numpy.ndarray
58156 xyz coords from an swc file.
59157
60158 """
61- xyz_list = []
62- with open (path , "r" ) as f :
63- offset = [0 , 0 , 0 ]
64- for line in f .readlines ():
65- if line .startswith ("# OFFSET" ):
66- parts = line .split ()
67- offset = read_xyz (parts [2 :5 ])
68- if not line .startswith ("#" ):
69- parts = line .split ()
70- xyz = read_xyz (
71- parts [2 :5 ], anisotropy = anisotropy , offset = offset
72- )
73- xyz_list .append (xyz )
74- return np .array (xyz_list )
159+ coords_list = []
160+ offset = [0 , 0 , 0 ]
161+ for line in contents :
162+ if line .startswith ("# OFFSET" ):
163+ parts = line .split ()
164+ offset = read_xyz (parts [2 :5 ])
165+ if not line .startswith ("#" ):
166+ parts = line .split ()
167+ coord = read_xyz (parts [2 :5 ], anisotropy = anisotropy , offset = offset )
168+ coords_list .append (coord )
169+ return np .array (coords_list )
75170
76171
77172def read_xyz (xyz , anisotropy = [1.0 , 1.0 , 1.0 ], offset = [0 , 0 , 0 ]):
@@ -81,8 +176,8 @@ def read_xyz(xyz, anisotropy=[1.0, 1.0, 1.0], offset=[0, 0, 0]):
81176
82177 Parameters
83178 ----------
84- xyz : str
85- (x,y,z) coordinates .
179+ coord : str
180+ xyz coordinate .
86181 anisotropy : list[float], optional
87182 Image to real-world coordinates scaling factors applied to "xyz". The
88183 default is [1.0, 1.0, 1.0].
@@ -178,7 +273,7 @@ def to_graph(path, anisotropy=[1.0, 1.0, 1.0]):
178273 """
179274 graph = nx .Graph (swc_id = utils .get_swc_id (path ))
180275 offset = [0 , 0 , 0 ]
181- for line in read (path ):
276+ for line in read_from_local (path ):
182277 if line .startswith ("# OFFSET" ):
183278 parts = line .split ()
184279 offset = read_xyz (parts [2 :5 ])
0 commit comments