Skip to content

Commit 92569da

Browse files
anna-grimanna-grim
andauthored
refactor: update (#96)
Co-authored-by: anna-grim <[email protected]>
1 parent c7ee125 commit 92569da

File tree

5 files changed

+33
-30
lines changed

5 files changed

+33
-30
lines changed

demo/evaluation_results.xls

0 Bytes
Binary file not shown.

src/segmentation_skeleton_metrics/skeleton_metric.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
swc_util,
2727
util
2828
)
29-
from segmentation_skeleton_metrics.utils.graph_util import to_array
3029

3130
MERGE_DIST_THRESHOLD = 100
3231
MIN_CNT = 40
@@ -201,13 +200,15 @@ def set_node_labels(self, key):
201200
# Assign threads
202201
threads = []
203202
for i in self.graphs[key].nodes:
204-
voxel = tuple(self.graphs[key].nodes[i]["xyz"])
203+
voxel = tuple(self.graphs[key].nodes[i]["voxel"])
205204
threads.append(executor.submit(self.get_label, i, voxel))
206205

207206
# Store label
207+
pbar = tqdm(total=len(threads), desc="threads finished")
208208
for thread in as_completed(threads):
209209
i, label = thread.result()
210210
self.graphs[key].nodes[i].update({"label": label})
211+
pbar.update(1)
211212

212213
def get_label(self, i, voxel):
213214
"""
@@ -227,7 +228,7 @@ def get_label(self, i, voxel):
227228
228229
"""
229230
# Read label
230-
if type(self.label_mask) is ts.TensorStore:
231+
if isinstance(self.label_mask, ts.TensorStore):
231232
label = int(self.label_mask[voxel].read().result())
232233
else:
233234
label = self.label_mask[voxel]
@@ -578,16 +579,16 @@ def is_fragment_merge(self, key, label, kdtree):
578579
None
579580
580581
"""
581-
for voxel in to_array(self.fragment_graphs[label])[::2]:
582+
for voxel in gutil.to_array(self.fragment_graphs[label])[::2]:
582583
if kdtree.query(voxel)[0] > MERGE_DIST_THRESHOLD:
583-
# Check whether to take inverse of label
584+
# Check whether to get inverse of label
584585
if self.inverse_label_map:
585586
equivalent_label = self.label_map[label]
586587
else:
587588
equivalent_label = label
588589

589590
# Record merge mistake
590-
xyz = util.to_world(voxel)
591+
xyz = util.to_physical(voxel)
591592
self.merge_cnt[key] += 1
592593
self.merged_labels.add((key, equivalent_label, tuple(xyz)))
593594
if self.save_projections:

src/segmentation_skeleton_metrics/utils/graph_util.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -153,8 +153,8 @@ def compute_run_length(graph):
153153
"""
154154
path_length = 0
155155
for i, j in nx.dfs_edges(graph):
156-
xyz_1 = util.to_physical(graph.nodes[i]["xyz"], ANISOTROPY)
157-
xyz_2 = util.to_physical(graph.nodes[j]["xyz"], ANISOTROPY)
156+
xyz_1 = util.to_physical(graph.nodes[i]["voxel"], ANISOTROPY)
157+
xyz_2 = util.to_physical(graph.nodes[j]["voxel"], ANISOTROPY)
158158
path_length += get_dist(xyz_1, xyz_2)
159159
return path_length
160160

@@ -167,16 +167,16 @@ def to_array(graph):
167167
Parameters
168168
----------
169169
graph : networkx.Graph
170-
Graph that contains nodes with "xyz" attributes.
170+
Graph that contains nodes with "voxel" attributes.
171171
172172
Returns
173173
-------
174174
numpy.ndarray
175175
Array where each row represents the 3D coordinates of a node.
176176
177177
"""
178-
xyz_coords = nx.get_node_attributes(graph, "xyz")
179-
return np.array([xyz_coords[i] for i in graph.nodes])
178+
voxels = nx.get_node_attributes(graph, "voxel")
179+
return np.array([voxels[i] for i in graph.nodes])
180180

181181

182182
def sample_leaf(graph):

src/segmentation_skeleton_metrics/utils/swc_util.py

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,7 @@ def __init__(self, anisotropy=(1.0, 1.0, 1.0), min_size=0):
6060
(1.0, 1.0, 1.0).
6161
min_size : int, optional
6262
Threshold on the number of nodes in swc file. Only swc files with
63-
more than "min_size" nodes are stored in "xyz_coords". The default
64-
is 0.
65-
63+
more than "min_size" nodes are returned.
6664
Returns
6765
-------
6866
None
@@ -157,8 +155,10 @@ def load_from_local_paths(self, swc_paths):
157155

158156
# Store results
159157
graph_dict = dict()
158+
pbar = tqdm(total=len(filenames), desc="Label Graph")
160159
for thread in as_completed(threads):
161160
graph_dict.update(thread.result())
161+
pbar.update(1)
162162
return graph_dict
163163

164164
def load_from_local_path(self, path):
@@ -336,7 +336,7 @@ def process_content(self, content, filename):
336336
return {name: graph}
337337
return dict()
338338

339-
def read_xyz(self, xyz_str, offset):
339+
def read_voxel(self, xyz_str, offset):
340340
"""
341341
Reads the coordinates from a string, then transforms them to image
342342
coordinates (if applicable).
@@ -378,13 +378,13 @@ def get_graph(self, content):
378378
for line in content:
379379
if line.startswith("# OFFSET"):
380380
parts = line.split()
381-
offset = self.read_xyz(parts[2:5])
381+
offset = self.read_voxel(parts[2:5])
382382
if not line.startswith("#"):
383383
parts = line.split()
384384
child = int(parts[0])
385385
parent = int(parts[-1])
386-
xyz = self.read_xyz(parts[2:5], offset=offset)
387-
graph.add_node(child, xyz=xyz)
386+
voxel = self.read_voxel(parts[2:5], offset=offset)
387+
graph.add_node(child, voxel=voxel)
388388
if parent != -1:
389389
graph.add_edge(parent, child)
390390

@@ -432,30 +432,29 @@ def save(path, xyz_1, xyz_2, color=None):
432432
f.write(make_entry(2, 1, xyz_2))
433433

434434

435-
def make_entry(node_id, parent_id, xyz):
435+
def make_entry(node, parent, xyz):
436436
"""
437-
Makes an entry to be written in an swc file.
437+
Makes an entry to be written in an SWC file.
438438
439439
Parameters
440440
----------
441441
graph : networkx.Graph
442442
Graph that "node_id" and "parent_id" belong to.
443-
node_id : int
444-
Node that entry corresponds to.
443+
node : int
444+
Node ID that entry corresponds to.
445445
parent_id : int
446-
Parent of node "node_id".
447-
xyz : ...
448-
xyz coordinate of "node_id".
446+
Parent ID of the given node.
447+
voxel : ...
448+
Voxel coordinate of the given node.
449449
450450
Returns
451451
-------
452452
entry : str
453453
Entry to be written in an swc file.
454454
455455
"""
456-
x, y, z = tuple(util.to_world(xyz))
457-
entry = f"{node_id} 2 {x} {y} {z} 3 {parent_id}"
458-
return entry
456+
x, y, z = tuple(util.to_physical(voxel, (0.748, 0.748, 1.0)))
457+
return f"{node} 2 {x} {y} {z} 3 {parent}"
459458

460459

461460
def to_zipped_swc(zip_writer, graph, color=None):
@@ -488,7 +487,7 @@ def to_zipped_swc(zip_writer, graph, color=None):
488487
r = 5 if color else 3
489488
for i, j in nx.dfs_edges(graph):
490489
# Special Case: Root
491-
x, y, z = tuple(util.to_world(graph.nodes[i]["xyz"]))
490+
x, y, z = tuple(util.to_world(graph.nodes[i]["voxel"]))
492491
if n_entries == 0:
493492
parent = -1
494493
node_to_idx[i] = 1

src/segmentation_skeleton_metrics/utils/util.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -461,4 +461,7 @@ def load_valid_labels(path):
461461
reconstruction process.
462462
463463
"""
464-
return set(map(int, read_txt(path)))
464+
valid_labels = set()
465+
for label_str in read_txt(path):
466+
valid_labels.add(int(label_str.split(".")[0]))
467+
return valid_labels

0 commit comments

Comments
 (0)