Skip to content

Commit 453fd12

Browse files
anna-grimanna-grim
andauthored
refactor: improved new features (#102)
Co-authored-by: anna-grim <[email protected]>
1 parent ac679cc commit 453fd12

File tree

3 files changed

+79
-32
lines changed

3 files changed

+79
-32
lines changed

src/segmentation_skeleton_metrics/skeleton_metric.py

Lines changed: 43 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from segmentation_skeleton_metrics.utils import (
2828
graph_util as gutil,
2929
img_util,
30+
swc_util,
3031
util,
3132
)
3233

@@ -63,7 +64,7 @@ def __init__(
6364
anisotropy=(1.0, 1.0, 1.0),
6465
connections_path=None,
6566
fragments_pointer=None,
66-
localize_merge=False,
67+
localize_merges=False,
6768
preexisting_merges=None,
6869
save_merges=False,
6970
save_fragments=False,
@@ -94,7 +95,7 @@ def __init__(
9495
"swc_util.Reader" for documentation. Notes: (1) "anisotropy" is
9596
applied to these SWC files and (2) these SWC files are required
9697
for counting merges. The default is None.
97-
localize_merge : bool, optional
98+
localize_merges : bool, optional
9899
Indication of whether to search for the approximate location of a
99100
merge. The default is False.
100101
preexisting_merges : List[int], optional
@@ -119,7 +120,7 @@ def __init__(
119120
# Instance attributes
120121
self.anisotropy = anisotropy
121122
self.connections_path = connections_path
122-
self.localize_merge = localize_merge
123+
self.localize_merges = localize_merges
123124
self.output_dir = output_dir
124125
self.preexisting_merges = preexisting_merges
125126
self.save_merges = save_merges
@@ -136,12 +137,7 @@ def __init__(
136137
self.load_fragments(fragments_pointer)
137138

138139
# Initialize writers
139-
self.init_zip_writers()
140-
141-
# Initialize fragment projections directory
142-
if self.save_projections:
143-
self.projections_dir = os.path.join(output_dir, "projections")
144-
util.mkdir(self.projections_dir)
140+
self.init_writers()
145141

146142
# --- Load Data ---
147143
def load_groundtruth(self, swc_pointer):
@@ -345,7 +341,7 @@ def get_node_labels(self, key, inverse_bool=False):
345341
else:
346342
return self.graphs[key].get_labels()
347343

348-
def init_zip_writers(self):
344+
def init_writers(self):
349345
"""
350346
Initializes "self.merge_writer" attribute by setting up a directory for
351347
output files and creating ZIP files for each graph in "self.graphs".
@@ -359,31 +355,45 @@ def init_zip_writers(self):
359355
None
360356
361357
"""
362-
# Merged fragments zip writer
358+
# Fragments writer
359+
if self.save_fragments:
360+
# Initialize direction
361+
fragments_dir = os.path.join(self.output_dir, "fragments")
362+
util.mkdir(fragments_dir, delete=True)
363+
364+
# ZIP writer
365+
self.fragment_writer = dict()
366+
for key in self.graphs.keys():
367+
zip_path = f"{fragments_dir}/{key}.zip"
368+
self.fragment_writer[key] = ZipFile(zip_path, "w")
369+
self.graphs[key].to_zipped_swc(self.fragment_writer[key])
370+
371+
# Merged fragments writer
363372
if self.save_merges:
364373
# Initialize directory
365374
merges_dir = os.path.join(self.output_dir, "merged_fragments")
366-
util.mkdir(merged_fragments_dir)
375+
util.mkdir(merged_fragments_dir, delete=True)
367376

368-
# Initialize zip writer
377+
# ZIP writer
369378
self.merge_writer = dict()
370379
for key in self.graphs.keys():
371380
zip_path = f"{merged_fragments_dir}/{key}.zip"
372381
self.merge_writer[key] = ZipFile(zip_path, "w")
373382
self.graphs[key].to_zipped_swc(self.merge_writer[key])
374383

375-
# Fragments zip writer
376-
if self.save_fragments:
377-
# Initialize direction
378-
fragments_dir = os.path.join(self.output_dir, "fragments")
379-
util.mkdir(fragments_dir)
384+
# Merge sites
385+
if self.localize_merges:
386+
# Initialize directory
387+
merges_dir = os.path.join(self.output_dir, "merge-sites")
388+
util.mkdir(merges_dir, delete=True)
380389

381-
# Initialize zip writer
382-
self.fragment_writer = dict()
383-
for key in self.graphs.keys():
384-
zip_path = f"{fragments_dir}/{key}.zip"
385-
self.fragment_writer[key] = ZipFile(zip_path, "w")
386-
self.graphs[key].to_zipped_swc(self.fragment_writer[key])
390+
# ZIP writer
391+
zip_path = f"{merges_dir}/estimated-merge-sites.zip"
392+
self.site_zip_writer = ZipFile(zip_path, "w")
393+
394+
# Txt writer
395+
sites_path = os.path.join(merges_dir, "estimated-merge-sites.txt")
396+
self.site_txt_writer = open(sites_path, "w", encoding="utf-8")
387397

388398
# -- Main Routine --
389399
def run(self):
@@ -560,9 +570,6 @@ def count_merges(self, key, kdtree):
560570
for label in self.label_handler.get_class(label):
561571
if label in self.fragment_ids:
562572
self.is_fragment_merge(key, label, kdtree)
563-
if self.save_projections:
564-
fragment_graph = self.find_graph_from_label(label)[0]
565-
fragment_graph.to_zipped_swc(zip_writer)
566573

567574
def is_fragment_merge(self, key, label, kdtree):
568575
"""
@@ -611,12 +618,12 @@ def is_fragment_merge(self, key, label, kdtree):
611618
# Save merged fragment (if applicable)
612619
if self.save_merges:
613620
fragment_graph.to_zipped_swc(self.merge_writer[key])
614-
if self.localize_merge:
621+
if self.localize_merges:
615622
self.find_merge_site(key, fragment_graph, kdtree)
616623
break
617624

618625
# Save fragment (if applicable)
619-
if self.save_fragments and min_dist < 3:
626+
if self.save_fragments and min_dist < 3:
620627
fragment_graph.to_zipped_swc(self.fragment_writer[key])
621628

622629
def adjust_metrics(self, key):
@@ -729,7 +736,13 @@ def find_merge_site(self, key, fragment_graph, kdtree):
729736
gt_voxel = util.kdtree_query(kdtree, voxel_j)
730737
if self.physical_dist(gt_voxel, voxel_j) < 2:
731738
hit = True
732-
print("Approximate Site:", img_util.to_physical(voxel_j, self.anisotropy))
739+
merge_cnt = np.sum(list(self.merge_cnt.values()))
740+
filename = f"{merge_cnt}.swc"
741+
xyz = img_util.to_physical(voxel_j, self.anisotropy)
742+
swc_util.to_zipped_point(
743+
self.site_zip_writer, filename, xyz
744+
)
745+
self.site_txt_writer.write(f"{tuple(xyz)}\n")
733746
break
734747

735748
# Check whether to continue

src/segmentation_skeleton_metrics/split_detection.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,15 +46,15 @@ def run(process_id, graph):
4646
label_i = int(graph.labels[i])
4747
label_j = int(graph.labels[j])
4848
if is_split(label_i, label_j):
49-
graph.remove_edge(i, j) temp
49+
graph.remove_edge(i, j)
5050
split_cnt += 1
5151
elif label_j == 0:
5252
check_misalignment(graph, visited_edges, i, j)
5353
visited_edges.add(frozenset({i, j}))
5454

5555
# Finish
5656
split_percent = split_cnt / graph.graph["n_edges"]
57-
graph.remove_nodes_with_label(0) temp
57+
graph.remove_nodes_with_label(0)
5858
return process_id, graph, split_percent
5959

6060

src/segmentation_skeleton_metrics/utils/swc_util.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
ProcessPoolExecutor,
2929
ThreadPoolExecutor,
3030
)
31+
from io import StringIO
3132
from tqdm import tqdm
3233
from zipfile import ZipFile
3334

@@ -380,3 +381,36 @@ def read_voxel(self, xyz_str, offset):
380381
"""
381382
xyz = [float(xyz_str[i]) + offset[i] for i in range(3)]
382383
return img_util.to_voxels(xyz, self.anisotropy)
384+
385+
386+
# --- Write ---
387+
def to_zipped_point(zip_writer, filename, xyz):
388+
"""
389+
Writes a point to an SWC file format, which is then stored in a ZIP
390+
archive.
391+
392+
Parameters
393+
----------
394+
zip_writer : zipfile.ZipFile
395+
A ZipFile object that will store the generated SWC file.
396+
filename : str
397+
Filename of SWC file.
398+
xyz : ArrayLike
399+
Point to be written to SWC file.
400+
401+
Returns
402+
-------
403+
None
404+
405+
"""
406+
with StringIO() as text_buffer:
407+
# Preamble
408+
text_buffer.write("# COLOR [1.0 0.0 0.0]")
409+
text_buffer.write("# id, type, z, y, x, r, pid")
410+
411+
# Write entry
412+
x, y, z = tuple(xyz)
413+
text_buffer.write("\n" + f"1 2 {x} {y} {z} 15 -1")
414+
415+
# Finish
416+
zip_writer.writestr(filename, text_buffer.getvalue())

0 commit comments

Comments
 (0)