|
15 | 15 | ProcessPoolExecutor, |
16 | 16 | ThreadPoolExecutor, |
17 | 17 | ) |
| 18 | +from copy import deepcopy |
18 | 19 | from scipy.spatial import distance, KDTree |
19 | 20 | from tqdm import tqdm |
20 | 21 | from zipfile import ZipFile |
21 | 22 |
|
22 | 23 | import networkx as nx |
23 | 24 | import numpy as np |
24 | 25 | import os |
| 26 | +import pandas as pd |
25 | 27 |
|
26 | 28 | from segmentation_skeleton_metrics import split_detection |
27 | 29 | from segmentation_skeleton_metrics.utils import ( |
@@ -162,6 +164,7 @@ def load_groundtruth(self, swc_pointer): |
162 | 164 | use_anisotropy=False, |
163 | 165 | ) |
164 | 166 | self.graphs = graph_builder.run(swc_pointer) |
| 167 | + self.gt_graphs = deepcopy(self.graphs) |
165 | 168 |
|
166 | 169 | # Label nodes |
167 | 170 | for key in tqdm(self.graphs, desc="Labeling Graphs"): |
@@ -370,21 +373,9 @@ def init_writers(self): |
370 | 373 |
|
371 | 374 | # Merged fragments writer |
372 | 375 | if self.save_merges or self.localize_merges: |
373 | | - # Initialize directory |
374 | | - merges_dir = os.path.join(self.output_dir, "merged_fragments") |
375 | | - util.mkdir(merges_dir, delete=True) |
376 | | - |
377 | | - # ZIP writer |
378 | | - self.merge_writer = dict() |
379 | | - for key in self.graphs.keys(): |
380 | | - zip_path = f"{merges_dir}/{key}.zip" |
381 | | - self.merge_writer[key] = ZipFile(zip_path, "w") |
382 | | - self.graphs[key].to_zipped_swc(self.merge_writer[key]) |
383 | | - |
384 | | - # Merge sites |
385 | | - if self.localize_merges: |
386 | | - sites_path = os.path.join(merges_dir, "estimated-merge-sites.txt") |
387 | | - self.site_txt_writer = open(sites_path, "w", encoding="utf-8") |
| 376 | + zip_path = os.path.join(self.output_dir, "merged_fragments.zip") |
| 377 | + self.merge_writer = ZipFile(zip_path, "a") |
| 378 | + self.merge_sites = list() |
388 | 379 |
|
389 | 380 | # -- Main Routine -- |
390 | 381 | def run(self): |
@@ -418,6 +409,13 @@ def run(self): |
418 | 409 | path = f"{self.output_dir}/{prefix}results.xls" |
419 | 410 | util.save_results(path, full_results) |
420 | 411 |
|
| 412 | + # Save merge sites (if applicable) |
| 413 | + if self.localize_merges: |
| 414 | + df = pd.DataFrame(self.merge_sites) |
| 415 | + df.to_csv( |
| 416 | + os.path.join(self.output_dir, "merge_sites.csv"), index=False |
| 417 | + ) |
| 418 | + |
421 | 419 | # Report results overview |
422 | 420 | path = os.path.join(self.output_dir, f"{prefix}results-overview.txt") |
423 | 421 | util.update_txt(path, "Average Results...") |
@@ -608,12 +606,17 @@ def is_fragment_merge(self, key, label, kdtree): |
608 | 606 |
|
609 | 607 | # Save merged fragment (if applicable) |
610 | 608 | if self.save_merges: |
611 | | - fragment_graph.to_zipped_swc(self.merge_writer[key]) |
| 609 | + fragment_graph.to_zipped_swc(self.merge_writer) |
| 610 | + if f"{key}.swc" not in self.merge_writer.namelist(): |
| 611 | + self.gt_graphs[key].to_zipped_swc(self.merge_writer) |
| 612 | + |
| 613 | + # Find approximate merge site |
612 | 614 | if self.localize_merges: |
613 | | - self.find_merge_site(key, fragment_graph, kdtree) |
| 615 | + self.find_merge_site(key, fragment_graph, kdtree) |
| 616 | + |
614 | 617 | break |
615 | 618 |
|
616 | | - # Save fragment (if applicable) |
| 619 | + # Save fragment (if applicable) |
617 | 620 | if self.save_fragments and min_dist < 3: |
618 | 621 | fragment_graph.to_zipped_swc(self.fragment_writer[key]) |
619 | 622 |
|
@@ -726,14 +729,24 @@ def find_merge_site(self, key, fragment_graph, kdtree): |
726 | 729 | voxel_j = fragment_graph.voxels[j] |
727 | 730 | gt_voxel = util.kdtree_query(kdtree, voxel_j) |
728 | 731 | if self.physical_dist(gt_voxel, voxel_j) < 2: |
| 732 | + # Save merge swc |
729 | 733 | hit = True |
730 | 734 | merge_cnt = np.sum(list(self.merge_cnt.values())) |
731 | 735 | filename = f"merge-{merge_cnt}.swc" |
732 | 736 | xyz = img_util.to_physical(voxel_j, self.anisotropy) |
733 | 737 | swc_util.to_zipped_point( |
734 | | - self.merge_writer[key], filename, xyz |
| 738 | + self.merge_writer, filename, xyz |
| 739 | + ) |
| 740 | + |
| 741 | + # Save merge in list |
| 742 | + segment_id = util.get_segment_id(fragment_graph.filename) |
| 743 | + self.merge_sites.append( |
| 744 | + { |
| 745 | + "Segment_ID": segment_id, |
| 746 | + "Voxel": voxel_j, |
| 747 | + "XYZ": xyz, |
| 748 | + } |
735 | 749 | ) |
736 | | - self.site_txt_writer.write(f"{tuple(xyz)}\n") |
737 | 750 | break |
738 | 751 |
|
739 | 752 | # Check whether to continue |
|
0 commit comments