Skip to content

Commit c2c51a6

Browse files
anna-grimanna-grim
andauthored
Cloud merge detection (#54)
* feat: merge detection on cloud * minor upds --------- Co-authored-by: anna-grim <[email protected]>
1 parent 933cd2d commit c2c51a6

File tree

1 file changed

+17
-18
lines changed

1 file changed

+17
-18
lines changed

src/segmentation_skeleton_metrics/skeleton_metric.py

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from segmentation_skeleton_metrics import split_detection, swc_utils, utils
2121
from segmentation_skeleton_metrics.swc_utils import save, to_graph
2222

23-
INTERSECTION_THRESHOLD = 8
23+
INTERSECTION_THRESHOLD = 10
2424
MERGE_DIST_THRESHOLD = 40
2525

2626

@@ -50,8 +50,8 @@ def __init__(
5050
equivalent_ids=None,
5151
ignore_boundary_mistakes=False,
5252
output_dir=None,
53-
valid_size_threshold=40,
54-
write_to_swc=False,
53+
valid_size_threshold=25,
54+
save_swc=False,
5555
):
5656
"""
5757
Constructs skeleton metric object that evaluates the quality of a
@@ -87,7 +87,7 @@ def __init__(
8787
Threshold on the number of nodes contained in an swc file. Only swc
8888
files with more than "valid_size_threshold" nodes are stored in
8989
"self.valid_labels". The default is 40.
90-
write_to_swc : bool, optional
90+
save_swc : bool, optional
9191
Indication of whether to write mistake sites to an swc file. The
9292
default is False.
9393
@@ -100,7 +100,7 @@ def __init__(
100100
self.anisotropy = anisotropy
101101
self.ignore_boundary_mistakes = ignore_boundary_mistakes
102102
self.output_dir = output_dir
103-
self.write_to_swc = write_to_swc
103+
self.save = save_swc
104104

105105
self.init_black_holes(black_holes_xyz_id)
106106
self.black_hole_radius = black_hole_radius
@@ -372,11 +372,13 @@ def compute_metrics(self):
372372
"""
373373
# Split evaluation
374374
print("Detecting Splits...")
375+
self.saved_site_cnt = 0
375376
self.detect_splits()
376377
self.quantify_splits()
377378

378379
# Merge evaluation
379380
print("Detecting Merges...")
381+
self.saved_site_cnt = 0
380382
self.detect_merges()
381383
self.quantify_merges()
382384

@@ -509,18 +511,18 @@ def detect_merges(self):
509511
merge = (frozenset((target_id_1, target_id_2)), label)
510512
if merge not in detected_merges:
511513
detected_merges.add(merge)
512-
site, d = self.locate(target_id_1, target_id_2, label)
513-
if d < MERGE_DIST_THRESHOLD:
514-
self.merge_cnts[target_id_1] += 1
515-
self.merge_cnts[target_id_2] += 1
516-
if self.write_to_swc:
514+
if self.save:
515+
site, d = self.locate(target_id_1, target_id_2, label)
516+
if d < MERGE_DIST_THRESHOLD:
517517
self.save_swc(site[0], site[1], "merge")
518518

519519
# Update graph
520520
for (target_ids, label) in detected_merges:
521521
target_id_1, target_id_2 = tuple(target_ids)
522522
self.process_merge(target_id_1, label)
523523
self.process_merge(target_id_2, label)
524+
self.merge_cnts[target_id_1] += 1
525+
self.merge_cnts[target_id_2] += 1
524526

525527
# Report Runtime
526528
t, unit = utils.time_writer(time() - t0)
@@ -541,6 +543,8 @@ def locate(self, target_id_1, target_id_2, label):
541543
if utils.dist(xyz_1, xyz_2) < min_dist:
542544
min_dist = utils.dist(xyz_1, xyz_2)
543545
xyz_pair = [xyz_1, xyz_2]
546+
if min_dist < MERGE_DIST_THRESHOLD:
547+
return xyz_pair, min_dist
544548
return xyz_pair, min_dist
545549

546550
def near_bdd(self, xyz):
@@ -787,16 +791,11 @@ def list_metrics(self):
787791
return metrics
788792

789793
def save_swc(self, xyz_1, xyz_2, mistake_type):
794+
self.saved_site_cnt += 1
790795
xyz_1 = utils.to_world(xyz_1, self.anisotropy)
791796
xyz_2 = utils.to_world(xyz_2, self.anisotropy)
792-
if mistake_type == "split":
793-
color = "0.0 1.0 0.0"
794-
cnt = 1 + np.sum(list(self.split_cnts.values())) // 2
795-
else:
796-
color = "0.0 0.0 1.0"
797-
cnt = 1 + np.sum(list(self.merge_cnts.values())) // 2
798-
799-
path = f"{self.output_dir}/{mistake_type}-{cnt}.swc"
797+
color = "0.0 1.0 0.0" if mistake_type == "split" else "0.0 0.0 1.0"
798+
path = f"{self.output_dir}/{mistake_type}-{self.saved_site_cnt}.swc"
800799
save(path, xyz_1, xyz_2, color=color)
801800

802801

0 commit comments

Comments
 (0)