Skip to content

Commit 740fa3f

Browse files
author
anna-grim
committed
feat: label handler
1 parent 3732d37 commit 740fa3f

File tree

5 files changed

+206
-261
lines changed

5 files changed

+206
-261
lines changed

src/segmentation_skeleton_metrics/skeleton_metric.py

Lines changed: 22 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
88
"""
99

10-
from collections import deque
1110
from concurrent.futures import (
1211
as_completed,
1312
ProcessPoolExecutor,
@@ -112,11 +111,13 @@ def __init__(
112111
self.output_dir = output_dir
113112
self.preexisting_merges = preexisting_merges
114113

114+
# Label handler
115+
self.label_handler = gutil.LabelHandler(
116+
connections_path=connections_path, valid_labels=valid_labels
117+
)
118+
115119
# Load data
116-
assert isinstance(valid_labels, set) if valid_labels else True
117120
self.label_mask = pred_labels
118-
self.valid_labels = valid_labels
119-
self.init_label_map(connections_path)
120121
self.load_groundtruth(gt_pointer)
121122
self.load_fragments(fragments_pointer)
122123

@@ -126,31 +127,6 @@ def __init__(
126127
self.init_zip_writer()
127128

128129
# --- Load Data ---
129-
def init_label_map(self, path):
130-
"""
131-
Initializes a dictionary that maps a label to its equivalent label in
132-
the case where "connections_path" is provided.
133-
134-
Parameters
135-
----------
136-
path : str
137-
Path to a txt file containing pairs of segment ids of segments
138-
that were merged into a single segment.
139-
140-
Returns
141-
-------
142-
None
143-
144-
"""
145-
if path:
146-
assert self.valid_labels is not None, "Must provide valid labels!"
147-
self.label_map, self.inverse_label_map = util.init_label_map(
148-
path, self.valid_labels
149-
)
150-
else:
151-
self.label_map = None
152-
self.inverse_label_map = None
153-
154130
def load_groundtruth(self, swc_pointer):
155131
"""
156132
Initializes "self.graphs" by iterating over "paths" which corresponds
@@ -265,50 +241,10 @@ def get_patch_labels(self, key, nodes):
265241
node_to_label = dict()
266242
for i in nodes:
267243
voxel = self.to_local_voxels(key, i, bbox["min"])
268-
label = self.adjust_label(label_patch[voxel])
244+
label = self.label_handler.get(label_patch[voxel])
269245
node_to_label[i] = label
270246
return node_to_label
271247

272-
def adjust_label(self, label):
273-
"""
274-
Gets label of voxel in "self.label_mask".
275-
276-
Parameters
277-
----------
278-
i : int
279-
Node ID.
280-
voxel : numpy.ndarray
281-
Image coordinate of voxel to be read.
282-
283-
Returns
284-
-------
285-
int
286-
Label of voxel.
287-
288-
"""
289-
if self.label_map:
290-
label = self.get_equivalent_label(label)
291-
elif self.valid_labels:
292-
label = 0 if label not in self.valid_labels else label
293-
return label
294-
295-
def get_equivalent_label(self, label):
296-
"""
297-
Gets the equivalence class label corresponding to "label".
298-
299-
Parameters
300-
----------
301-
label : int
302-
Label to be checked.
303-
304-
Returns
305-
-------
306-
label
307-
Equivalence class label.
308-
309-
"""
310-
return self.label_map[label] if label in self.label_map else 0
311-
312248
def get_all_node_labels(self):
313249
"""
314250
Gets the a set of all unique labels from all graphs in "self.graphs".
@@ -324,7 +260,7 @@ def get_all_node_labels(self):
324260
325261
"""
326262
all_labels = set()
327-
inverse_bool = True if self.inverse_label_map else False
263+
inverse_bool = self.label_handler.use_mapping()
328264
for key in self.graphs:
329265
labels = self.get_node_labels(key, inverse_bool=inverse_bool)
330266
all_labels = all_labels.union(labels)
@@ -352,7 +288,7 @@ def get_node_labels(self, key, inverse_bool=False):
352288
if inverse_bool:
353289
output = set()
354290
for l in self.key_to_label_to_nodes[key].keys():
355-
output = output.union(self.inverse_label_map[l])
291+
output = output.union(self.label_handler.inverse_mapping[l])
356292
return output
357293
else:
358294
return set(self.key_to_label_to_nodes[key].keys())
@@ -404,7 +340,7 @@ def run(self):
404340
self.detect_splits()
405341
self.quantify_splits()
406342

407-
# Check whether to delete prexisting merges
343+
# Check for prexisting merges
408344
if self.preexisting_merges:
409345
for key in self.graphs:
410346
self.adjust_metrics(key)
@@ -467,7 +403,7 @@ def detect_splits(self):
467403
468404
"""
469405
t0 = time()
470-
pbar = tqdm(total=len(self.graphs), desc="Split Detection:")
406+
pbar = tqdm(total=len(self.graphs), desc="Split Detection")
471407
with ProcessPoolExecutor() as executor:
472408
# Assign processes
473409
processes = list()
@@ -543,7 +479,7 @@ def detect_merges(self):
543479

544480
# Count total merges
545481
if self.fragment_graphs:
546-
pbar = tqdm(total=len(self.graphs), desc="Count Merges:")
482+
pbar = tqdm(total=len(self.graphs), desc="Merge Detection")
547483
for key, graph in self.graphs.items():
548484
if graph.number_of_nodes() > 0:
549485
kdtree = KDTree(graph.graph["voxel"])
@@ -582,14 +518,7 @@ def count_merges(self, key, kdtree):
582518
"""
583519
for label in self.get_node_labels(key):
584520
if len(self.key_to_label_to_nodes[key][label]) > MIN_CNT:
585-
# Check whether to compute label inverse
586-
if self.inverse_label_map:
587-
labels = deepcopy(self.inverse_label_map[label])
588-
else:
589-
labels = [label]
590-
591-
# Check if fragment is a merge mistake
592-
for label in labels:
521+
for label in self.label_handler.get_class(label):
593522
if label in self.fragment_graphs:
594523
self.is_fragment_merge(key, label, kdtree)
595524

@@ -616,16 +545,13 @@ def is_fragment_merge(self, key, label, kdtree):
616545
"""
617546
for voxel in self.fragment_graphs[label].graph["voxel"]:
618547
if kdtree.query(voxel)[0] > MERGE_DIST_THRESHOLD:
619-
# Check whether to get inverse of label
620-
if self.inverse_label_map:
621-
equivalent_label = self.label_map[label]
622-
else:
623-
equivalent_label = label
624-
625-
# Record merge mistake
548+
# Log merge mistake
549+
equiv_label = self.label_handler.get(label)
626550
xyz = img_util.to_physical(voxel, self.anisotropy)
627551
self.merge_cnt[key] += 1
628-
self.merged_labels.add((key, equivalent_label, tuple(xyz)))
552+
self.merged_labels.add((key, equiv_label, tuple(xyz)))
553+
554+
# Save merged fragment (if applicable)
629555
if self.save_projections and label in self.fragment_graphs:
630556
swc_util.to_zipped_swc(
631557
self.zip_writer[key], self.fragment_graphs[label]
@@ -729,7 +655,7 @@ def save_merged_labels(self):
729655
with open(os.path.join(self.output_dir, filename), "w") as f:
730656
f.write(f" Label - xyz\n")
731657
for _, label, xyz in self.merged_labels:
732-
if self.connections_path:
658+
if self.label_handler.use_mapping():
733659
label = self.get_merged_label(label)
734660
f.write(f" {label} - {xyz}\n")
735661

@@ -749,11 +675,11 @@ def get_merged_label(self, label):
749675
-------
750676
str or list
751677
The first matching label found in "self.fragment_graphs.keys()" or
752-
the original associated labels from "inverse_label_map" if no
678+
the original associated labels from "inverse_label_map" if no
753679
matches are found.
754680
755681
"""
756-
for l in self.inverse_label_map[label]:
682+
for l in self.label_handler.get_class(label):
757683
if l in self.fragment_graphs.keys():
758684
return l
759685
return self.inverse_label_map[label]
@@ -852,8 +778,8 @@ def generate_avg_results(self):
852778

853779
def avg_result(self, stats):
854780
"""
855-
Averages the values computed across "self.graphs" for
856-
a given metric stored in "stats".
781+
Averages the values computed across "self.graphs" for a given metric
782+
stored in "stats".
857783
858784
Parameters
859785
----------

0 commit comments

Comments
 (0)