Skip to content

Commit 248dbda

Browse files
anna-grim and anna-grim
authored
feat: merge detection on cloud (#53)
Co-authored-by: anna-grim <[email protected]>
1 parent b437561 commit 248dbda

File tree

3 files changed

+234
-115
lines changed

3 files changed

+234
-115
lines changed

src/segmentation_skeleton_metrics/skeleton_metric.py

Lines changed: 24 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,7 @@
1818

1919
from segmentation_skeleton_metrics import graph_utils as gutils
2020
from segmentation_skeleton_metrics import split_detection, swc_utils, utils
21-
from segmentation_skeleton_metrics.swc_utils import (
22-
get_xyz_coords,
23-
save,
24-
to_graph,
25-
)
21+
from segmentation_skeleton_metrics.swc_utils import save, to_graph
2622

2723
INTERSECTION_THRESHOLD = 8
2824
MERGE_DIST_THRESHOLD = 40
@@ -54,8 +50,7 @@ def __init__(
5450
equivalent_ids=None,
5551
ignore_boundary_mistakes=False,
5652
output_dir=None,
57-
pred_on_cloud=False,
58-
valid_size_threshold=40.0,
53+
valid_size_threshold=40,
5954
write_to_swc=False,
6055
):
6156
"""
@@ -66,15 +61,16 @@ def __init__(
6661
----------
6762
pred_labels : numpy.ndarray or tensorstore.TensorStore
6863
Predicted segmentation mask.
69-
pred_swc_paths : list[str]
70-
List of paths to swc files where each file corresponds to a
71-
neuron in the prediction.
7264
target_swc_paths : list[str]
7365
List of paths to swc files where each file corresponds to a
7466
neuron in the ground truth.
7567
anisotropy : list[float], optional
7668
Image to real-world coordinates scaling factors applied to swc
7769
files. The default is [1.0, 1.0, 1.0]
70+
pred_swc_paths : list[str] or dict
71+
If swc files are on local machine, list of paths to swc files where
72+
each file corresponds to a neuron in the prediction. If swc files
73+
are on cloud, then dict with keys "bucket_name" and "path".
7874
black_holes_xyz_id : list, optional
7975
...
8076
black_hole_radius : float, optional
@@ -87,11 +83,10 @@ def __init__(
8783
output_dir : str, optional
8884
Path to directory that each mistake site is written to. The default
8985
is None.
90-
pred_on_cloud : bool, optional
91-
Indication of whether predicted swc files in "pred_swc_paths" are
92-
on the cloud in a GCS bucket. The default is False.
93-
valid_size_threshold : float, optional
94-
...
86+
valid_size_threshold : int, optional
87+
Threshold on the number of nodes contained in an swc file. Only swc
88+
files with more than "valid_size_threshold" nodes are stored in
89+
"self.valid_labels". The default is 40.
9590
write_to_swc : bool, optional
9691
Indication of whether to write mistake sites to an swc file. The
9792
default is False.
@@ -112,8 +107,9 @@ def __init__(
112107

113108
# Build Graphs
114109
self.label_mask = pred_labels
115-
self.pred_swc_paths = pred_swc_paths
116-
self.init_valid_labels(valid_size_threshold)
110+
self.valid_labels = swc_utils.parse(
111+
pred_swc_paths, valid_size_threshold, anisotropy=anisotropy
112+
)
117113

118114
self.target_graphs = self.init_graphs(target_swc_paths, anisotropy)
119115
self.labeled_target_graphs = self.init_labeled_target_graphs()
@@ -124,13 +120,6 @@ def __init__(
124120
self.rm_spurious_intersections()
125121

126122
# -- Initialize and Label Graphs --
127-
def init_valid_labels(self, valid_size_threshold):
128-
self.valid_labels = set()
129-
for path in self.pred_swc_paths:
130-
contents = swc_utils.read(path)
131-
if len(contents) > valid_size_threshold:
132-
self.valid_labels.add(int(utils.get_swc_id(path)))
133-
134123
def init_graphs(self, paths, anisotropy):
135124
"""
136125
Initializes "self.target_graphs" by iterating over "paths" which
@@ -236,12 +225,13 @@ def get_label(self, img_coord, return_node=False):
236225
Label of voxel at "img_coord".
237226
238227
"""
239-
label = self.__read_label(img_coord)
228+
# Read label
240229
if self.in_black_hole(img_coord):
241230
label = -1
242-
return self.finalize_label(label, return_node)
231+
else:
232+
label = self.__read_label(img_coord)
243233

244-
def finalize_label(self, label, return_node):
234+
# Validate label
245235
if return_node:
246236
return return_node, self.is_valid(label)
247237
else:
@@ -286,7 +276,7 @@ def is_valid(self, label):
286276
287277
"""
288278
if self.valid_labels:
289-
if label not in self.valid_labels:
279+
if label not in self.valid_labels.keys():
290280
return 0
291281
return label
292282

@@ -305,7 +295,7 @@ def rm_spurious_intersections(self):
305295
# Compute label intersect target_graphs
306296
hit_target_ids = dict()
307297
multi_hits = set()
308-
for xyz in self.get_pred_xyz(label):
298+
for xyz in self.get_pred_coords(label):
309299
hat_xyz, d = self.get_projection(xyz)
310300
if d < 5:
311301
hits = list(self.xyz_to_id_node[hat_xyz].keys())
@@ -329,12 +319,11 @@ def rm_spurious_intersections(self):
329319
elif label in self.id_to_label_nodes[target_id]:
330320
self.zero_nodes(target_id, label)
331321

332-
def get_pred_xyz(self, label):
333-
for path in self.pred_swc_paths:
334-
swc_id = utils.get_swc_id(path)
335-
if str(label) == swc_id:
336-
return get_xyz_coords(path, anisotropy=self.anisotropy)
337-
return []
322+
def get_pred_coords(self, label):
323+
if label in self.valid_labels.keys():
324+
return self.valid_labels[label]
325+
else:
326+
return []
338327

339328
# -- Final Constructor Routines --
340329
def init_kdtree(self):

src/segmentation_skeleton_metrics/swc_utils.py

Lines changed: 120 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,107 @@
77
88
"""
99

10+
from concurrent.futures import ThreadPoolExecutor, as_completed
11+
from io import BytesIO
12+
from zipfile import ZipFile
13+
1014
import networkx as nx
1115
import numpy as np
16+
from google.cloud import storage
1217

1318
from segmentation_skeleton_metrics import utils
1419

1520

16-
def read(path, cloud_read=False):
17-
return read_from_cloud(path) if cloud_read else read_from_local(path)
21+
def parse(swc_paths, min_size, anisotropy=[1.0, 1.0, 1.0]):
22+
"""
23+
Reads swc files and extracts the xyz coordinates.
24+
25+
Parameters
26+
---------
27+
swc_paths : list or dict
28+
If swc files are on local machine, list of paths to swc files where
29+
each file corresponds to a neuron in the prediction. If swc files are
30+
on cloud, then dict with keys "bucket_name" and "path".
31+
min_size : int
32+
Threshold on the number of nodes contained in an swc file. Only swc
33+
files with more than "min_size" nodes are stored in "valid_labels".
34+
anisotropy : list[float]
35+
Image to World scaling factors applied to xyz coordinates to account
36+
for anisotropy of the microscope.
1837
38+
Returns
39+
-------
40+
dict
41+
Dictionary that maps the swc id of each valid file to the xyz coordinates read from that file.
42+
"""
43+
if type(swc_paths) == list:
44+
return parse_local_paths(swc_paths, min_size, anisotropy)
45+
elif type(swc_paths) == dict:
46+
return parse_cloud_paths(swc_paths, min_size, anisotropy)
47+
else:
48+
return None
49+
50+
51+
def parse_local_paths(pred_swc_paths, min_size, anisotropy):
52+
valid_labels = dict()
53+
for path in pred_swc_paths:
54+
contents = read_from_local(path)
55+
if len(contents) > min_size:
56+
swc_id = int(utils.get_swc_id(path))
57+
valid_labels[swc_id] = get_coords(contents, anisotropy)
58+
return valid_labels
59+
60+
61+
def parse_cloud_paths(cloud_dict, min_size, anisotropy):
62+
# Initializations
63+
bucket = storage.Client().bucket(cloud_dict["bucket_name"])
64+
zip_paths = utils.list_gcs_filenames(bucket, cloud_dict["path"], ".zip")
65+
chunk_size = int(len(zip_paths) * 0.02)
66+
67+
# Parse
68+
cnt = 1
69+
valid_labels = dict()
70+
print("Downloading predicted swc files from cloud...")
71+
print("# zip files:", len(zip_paths))
72+
for i, path in enumerate(zip_paths):
73+
valid_labels.update(download(bucket, path, min_size, anisotropy))
74+
if i > cnt * chunk_size:
75+
utils.progress_bar(i + 1, len(zip_paths))
76+
cnt += 1
77+
78+
# Report Results
79+
print("\n#Valid Labels:", len(valid_labels))
80+
print("")
81+
return valid_labels
82+
83+
84+
def download(bucket, zip_path, min_size, anisotropy):
85+
zip_content = bucket.blob(zip_path).download_as_bytes()
86+
with ZipFile(BytesIO(zip_content)) as zip_file:
87+
with ThreadPoolExecutor() as executor:
88+
# Assign threads
89+
threads = []
90+
for path in utils.list_files_in_gcs_zip(zip_content):
91+
threads.append(
92+
executor.submit(
93+
parse_gcs_zip, zip_file, path, min_size, anisotropy
94+
)
95+
)
1996

20-
def read_from_cloud(path):
21-
pass
97+
# Process results
98+
valid_labels = dict()
99+
for thread in as_completed(threads):
100+
valid_labels.update(thread.result())
101+
return valid_labels
102+
103+
104+
def parse_gcs_zip(zip_file, path, min_size, anisotropy):
105+
contents = read_from_cloud(zip_file, path)
106+
if len(contents) > min_size:
107+
swc_id = int(utils.get_swc_id(path))
108+
return {swc_id: get_coords(contents, anisotropy)}
109+
else:
110+
return dict()
22111

23112

24113
def read_from_local(path):
@@ -40,38 +129,44 @@ def read_from_local(path):
40129
return file.readlines()
41130

42131

43-
def get_xyz_coords(path, anisotropy=[1.0, 1.0, 1.0]):
132+
def read_from_cloud(zip_file, path):
133+
"""
134+
Reads the content of an swc file from a zip file in a GCS bucket.
135+
136+
"""
137+
with zip_file.open(path) as text_file:
138+
return text_file.read().decode("utf-8").splitlines()
139+
140+
141+
def get_coords(contents, anisotropy):
44142
"""
45143
Gets the xyz coords from the swc file at "path".
46144
47145
Parameters
48146
----------
49147
contents : list[str]
50148
Lines of an swc file to be parsed.
51-
anisotropy : list[float], optional
52-
Scaling factors applied to xyz coordinates to account for anisotropy
53-
of the microscope. The default is [1.0, 1.0, 1.0].
149+
anisotropy : list[float]
150+
Image to World scaling factors applied to xyz coordinates to account
151+
for anisotropy of the microscope.
54152
55153
Returns
56154
-------
57155
numpy.ndarray
58156
xyz coords from an swc file.
59157
60158
"""
61-
xyz_list = []
62-
with open(path, "r") as f:
63-
offset = [0, 0, 0]
64-
for line in f.readlines():
65-
if line.startswith("# OFFSET"):
66-
parts = line.split()
67-
offset = read_xyz(parts[2:5])
68-
if not line.startswith("#"):
69-
parts = line.split()
70-
xyz = read_xyz(
71-
parts[2:5], anisotropy=anisotropy, offset=offset
72-
)
73-
xyz_list.append(xyz)
74-
return np.array(xyz_list)
159+
coords_list = []
160+
offset = [0, 0, 0]
161+
for line in contents:
162+
if line.startswith("# OFFSET"):
163+
parts = line.split()
164+
offset = read_xyz(parts[2:5])
165+
if not line.startswith("#"):
166+
parts = line.split()
167+
coord = read_xyz(parts[2:5], anisotropy=anisotropy, offset=offset)
168+
coords_list.append(coord)
169+
return np.array(coords_list)
75170

76171

77172
def read_xyz(xyz, anisotropy=[1.0, 1.0, 1.0], offset=[0, 0, 0]):
@@ -81,8 +176,8 @@ def read_xyz(xyz, anisotropy=[1.0, 1.0, 1.0], offset=[0, 0, 0]):
81176
82177
Parameters
83178
----------
84-
xyz : str
85-
(x,y,z) coordinates.
179+
xyz : str
180+
(x,y,z) coordinate to be parsed.
86181
anisotropy : list[float], optional
87182
Image to real-world coordinates scaling factors applied to "xyz". The
88183
default is [1.0, 1.0, 1.0].
@@ -178,7 +273,7 @@ def to_graph(path, anisotropy=[1.0, 1.0, 1.0]):
178273
"""
179274
graph = nx.Graph(swc_id=utils.get_swc_id(path))
180275
offset = [0, 0, 0]
181-
for line in read(path):
276+
for line in read_from_local(path):
182277
if line.startswith("# OFFSET"):
183278
parts = line.split()
184279
offset = read_xyz(parts[2:5])

0 commit comments

Comments
 (0)