Skip to content

Commit c2e72fd

Browse files
anna-grimanna-grim
andauthored
Feat save merge sites (#104)
* refactor: improved new features * improved new feature * refactor: merge site save --------- Co-authored-by: anna-grim <[email protected]>
1 parent cb65641 commit c2e72fd

File tree

2 files changed

+33
-22
lines changed

2 files changed

+33
-22
lines changed

src/segmentation_skeleton_metrics/skeleton_metric.py

Lines changed: 33 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,15 @@
1515
ProcessPoolExecutor,
1616
ThreadPoolExecutor,
1717
)
18+
from copy import deepcopy
1819
from scipy.spatial import distance, KDTree
1920
from tqdm import tqdm
2021
from zipfile import ZipFile
2122

2223
import networkx as nx
2324
import numpy as np
2425
import os
26+
import pandas as pd
2527

2628
from segmentation_skeleton_metrics import split_detection
2729
from segmentation_skeleton_metrics.utils import (
@@ -162,6 +164,7 @@ def load_groundtruth(self, swc_pointer):
162164
use_anisotropy=False,
163165
)
164166
self.graphs = graph_builder.run(swc_pointer)
167+
self.gt_graphs = deepcopy(self.graphs)
165168

166169
# Label nodes
167170
for key in tqdm(self.graphs, desc="Labeling Graphs"):
@@ -370,21 +373,9 @@ def init_writers(self):
370373

371374
# Merged fragments writer
372375
if self.save_merges or self.localize_merges:
373-
# Initialize directory
374-
merges_dir = os.path.join(self.output_dir, "merged_fragments")
375-
util.mkdir(merges_dir, delete=True)
376-
377-
# ZIP writer
378-
self.merge_writer = dict()
379-
for key in self.graphs.keys():
380-
zip_path = f"{merges_dir}/{key}.zip"
381-
self.merge_writer[key] = ZipFile(zip_path, "w")
382-
self.graphs[key].to_zipped_swc(self.merge_writer[key])
383-
384-
# Merge sites
385-
if self.localize_merges:
386-
sites_path = os.path.join(merges_dir, "estimated-merge-sites.txt")
387-
self.site_txt_writer = open(sites_path, "w", encoding="utf-8")
376+
zip_path = os.path.join(self.output_dir, "merged_fragments.zip")
377+
self.merge_writer = ZipFile(zip_path, "a")
378+
self.merge_sites = list()
388379

389380
# -- Main Routine --
390381
def run(self):
@@ -418,6 +409,13 @@ def run(self):
418409
path = f"{self.output_dir}/{prefix}results.xls"
419410
util.save_results(path, full_results)
420411

412+
# Save merge sites (if applicable)
413+
if self.localize_merges:
414+
df = pd.DataFrame(self.merge_sites)
415+
df.to_csv(
416+
os.path.join(self.output_dir, "merge_sites.csv"), index=False
417+
)
418+
421419
# Report results overview
422420
path = os.path.join(self.output_dir, f"{prefix}results-overview.txt")
423421
util.update_txt(path, "Average Results...")
@@ -608,12 +606,17 @@ def is_fragment_merge(self, key, label, kdtree):
608606

609607
# Save merged fragment (if applicable)
610608
if self.save_merges:
611-
fragment_graph.to_zipped_swc(self.merge_writer[key])
609+
fragment_graph.to_zipped_swc(self.merge_writer)
610+
if f"{key}.swc" not in self.merge_writer.namelist():
611+
self.gt_graphs[key].to_zipped_swc(self.merge_writer)
612+
613+
# Find approximate merge site
612614
if self.localize_merges:
613-
self.find_merge_site(key, fragment_graph, kdtree)
615+
self.find_merge_site(key, fragment_graph, kdtree)
616+
614617
break
615618

616-
# Save fragment (if applicable)
619+
# Save fragment (if applicable)
617620
if self.save_fragments and min_dist < 3:
618621
fragment_graph.to_zipped_swc(self.fragment_writer[key])
619622

@@ -726,14 +729,24 @@ def find_merge_site(self, key, fragment_graph, kdtree):
726729
voxel_j = fragment_graph.voxels[j]
727730
gt_voxel = util.kdtree_query(kdtree, voxel_j)
728731
if self.physical_dist(gt_voxel, voxel_j) < 2:
732+
# Save merge swc
729733
hit = True
730734
merge_cnt = np.sum(list(self.merge_cnt.values()))
731735
filename = f"merge-{merge_cnt}.swc"
732736
xyz = img_util.to_physical(voxel_j, self.anisotropy)
733737
swc_util.to_zipped_point(
734-
self.merge_writer[key], filename, xyz
738+
self.merge_writer, filename, xyz
739+
)
740+
741+
# Save merge in list
742+
segment_id = util.get_segment_id(fragment_graph.filename)
743+
self.merge_sites.append(
744+
{
745+
"Segment_ID": segment_id,
746+
"Voxel": voxel_j,
747+
"XYZ": xyz,
748+
}
735749
)
736-
self.site_txt_writer.write(f"{tuple(xyz)}\n")
737750
break
738751

739752
# Check whether to continue

src/segmentation_skeleton_metrics/utils/util.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,8 @@
1010
1111
"""
1212

13-
from io import BytesIO
1413
from random import sample
1514
from xlwt import Workbook
16-
from zipfile import ZipFile
1715

1816
import os
1917
import shutil

0 commit comments

Comments
 (0)