Skip to content

Commit be63a22

Browse files
anna-grimanna-grim
andauthored
Refactor merge detection (#154)
* refactor: dataloading * refactor: dataloading --------- Co-authored-by: anna-grim <[email protected]>
1 parent 0c2212c commit be63a22

File tree

3 files changed

+56
-127
lines changed

3 files changed

+56
-127
lines changed

src/segmentation_skeleton_metrics/data_handling/graph_loading.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,33 +18,29 @@
1818
import numpy as np
1919

2020
from segmentation_skeleton_metrics.data_handling.skeleton_graph import SkeletonGraph
21-
from segmentation_skeleton_metrics.utils import graph_util as gutil, swc_util, util
21+
from segmentation_skeleton_metrics.utils import swc_util, util
2222

2323

2424
class DataLoader:
2525

2626
def __init__(
2727
self,
28+
label_handler,
2829
anisotropy=(1.0, 1.0, 1.0),
29-
connections_path=None,
3030
use_anisotropy=False,
31-
valid_labels=None,
3231
verbose=True
3332
):
3433
# Instance attributes
3534
self.anisotropy = anisotropy
35+
self.label_handler = label_handler
3636
self.use_anisotropy = use_anisotropy
37-
self.valid_labels = valid_labels
3837
self.verbose = verbose
3938

40-
# Label handler
41-
self.label_handler = LabelHandler(connections_path, valid_labels)
42-
4339
# --- Core Routines ---
4440
def load_groundtruth(self, swc_pointer, label_mask):
4541
"""
4642
Loads ground truth graphs.
47-
43+
4844
Parameters
4945
----------
5046
swc_pointer : str
@@ -60,7 +56,7 @@ def load_groundtruth(self, swc_pointer, label_mask):
6056
if self.verbose:
6157
print("\n(1) Load Ground Truth")
6258

63-
graph_loader = gutil.GraphLoader(
59+
graph_loader = GraphLoader(
6460
anisotropy=self.anisotropy,
6561
is_groundtruth=True,
6662
label_handler=self.label_handler,
@@ -94,7 +90,7 @@ def load_fragments(self, swc_pointer, gt_graphs):
9490

9591
# Load fragments
9692
selected_ids = self.get_all_node_labels(gt_graphs)
97-
graph_loader = gutil.GraphLoader(
93+
graph_loader = GraphLoader(
9894
anisotropy=self.anisotropy,
9995
is_groundtruth=False,
10096
selected_ids=selected_ids,

src/segmentation_skeleton_metrics/evaluate.py

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99
"""
1010

1111
from segmentation_skeleton_metrics.skeleton_metric import SkeletonMetric
12-
from segmentation_skeleton_metrics.data_handling.graph_loading import DataLoader
12+
from segmentation_skeleton_metrics.data_handling.graph_loading import (
13+
DataLoader, LabelHandler
14+
)
1315

1416

1517
def evaluate(
@@ -25,28 +27,48 @@ def evaluate(
2527
valid_labels=None,
2628
verbose=True
2729
):
30+
"""
31+
...
32+
33+
Parameters
34+
----------
35+
gt_pointer : str
36+
Pointer to ground truth SWC files, see "swc_util.Reader" for
37+
documentation. These SWC files are assumed to be stored in voxel
38+
coordinates.
39+
label_mask : ImageReader
40+
Predicted segmentation.
41+
anisotropy : Tuple[float], optional
42+
...
43+
connections_path : str, optional
44+
Path to a txt file containing pairs of segment IDs that represents
45+
fragments that were merged. Default is None.
46+
fragments_pointer : str, optional
47+
Pointer to SWC files corresponding to "label_mask", see
48+
"swc_util.Reader" for documentation. Notes: (1) "anisotropy" is
49+
applied to these SWC files and (2) these SWC files are required
50+
for counting merges. Default is None.
51+
"""
2852
# Load data
53+
label_handler = LabelHandler(connections_path, valid_labels)
2954
dataloader = DataLoader(
55+
label_handler,
3056
anisotropy=anisotropy,
31-
connections_path=connections_path,
3257
use_anisotropy=use_anisotropy,
33-
valid_labels=valid_labels,
3458
verbose=verbose
35-
)
59+
)
3660
gt_graphs = dataloader.load_groundtruth(gt_pointer, label_mask)
37-
fragments_graph = dataloader.load_fragments(fragments_pointer, gt_graphs)
61+
fragment_graphs = dataloader.load_fragments(fragments_pointer, gt_graphs)
3862

3963
# Evaluator
4064
skeleton_metric = SkeletonMetric(
41-
gt_pointer,
42-
label_mask,
65+
gt_graphs,
66+
fragment_graphs,
67+
label_handler,
4368
anisotropy=anisotropy,
44-
connections_path=connections_path,
45-
fragments_pointer=fragments_pointer,
4669
output_dir=output_dir,
4770
save_merges=save_merges,
4871
save_fragments=save_fragments,
49-
use_anisotropy=use_anisotropy,
5072
)
5173
skeleton_metric.run()
5274

src/segmentation_skeleton_metrics/skeleton_metric.py

Lines changed: 18 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -53,16 +53,13 @@ class SkeletonMetric:
5353

5454
def __init__(
5555
self,
56-
gt_pointer,
57-
label_mask,
56+
gt_graphs,
57+
fragment_graphs,
58+
label_handler,
5859
output_dir,
5960
anisotropy=(1.0, 1.0, 1.0),
60-
connections_path=None,
61-
fragments_pointer=None,
6261
save_merges=False,
6362
save_fragments=False,
64-
use_anisotropy=False,
65-
valid_labels=None,
6663
verbose=True
6764
):
6865
"""
@@ -71,59 +68,41 @@ def __init__(
7168
7269
Parameters
7370
----------
74-
gt_pointer : Any
75-
Pointer to ground truth SWC files, see "swc_util.Reader" for
76-
documentation. These SWC files are assumed to be stored in voxel
77-
coordinates.
78-
label_mask : ImageReader
79-
Predicted segmentation.
80-
output_dir : str
71+
gt_graphs : SkeletonGraph
72+
...
73+
fragment_graphs ; SkeletonGraph
74+
...
75+
label_handler : LabelHandler
76+
...
77+
output_dir : str
8178
Path to directory wehere results are written.
8279
anisotropy : Tuple[float], optional
8380
Image to physical coordinate scaling factors applied to SWC files
8481
stored at "fragments_pointer". Default is (1.0, 1.0, 1.0).
85-
connections_path : str, optional
86-
Path to a txt file containing pairs of segment IDs that represents
87-
fragments that were merged. Default is None.
88-
fragments_pointer : Any, optional
89-
Pointer to SWC files corresponding to "label_mask", see
90-
"swc_util.Reader" for documentation. Notes: (1) "anisotropy" is
91-
applied to these SWC files and (2) these SWC files are required
92-
for counting merges. Default is None.
82+
output_dir : str
83+
...
9384
save_merges: bool, optional
9485
Indication of whether to save fragments with a merge mistake.
9586
Default is None.
9687
save_fragments : bool, optional
9788
Indication of whether to save fragments that project onto each
9889
ground truth skeleton. Default is False.
99-
valid_labels : Set[int], optional
100-
Segment IDs that can be assigned to nodes. This argument accounts
101-
for segments that were been removed due to some type of filtering.
102-
Default is None.
103-
use_anisotropy : bool, optional
104-
Indication of whether coordinates in fragment SWC files should be
105-
converted from physical to image coordinates using the given
106-
anisotropy. Default is False.
10790
verbose : bool, optional
10891
Indication of whether to printout updates. Default is True.
10992
"""
11093
# Instance attributes
11194
self.anisotropy = anisotropy
112-
self.connections_path = connections_path
11395
self.output_dir = output_dir
11496
self.save_merges = save_merges
11597
self.save_fragments = save_fragments
116-
self.use_anisotropy = use_anisotropy
11798
self.verbose = verbose
11899

119-
# Label handler
120-
self.label_handler = LabelHandler(
121-
connections_path=connections_path, valid_labels=valid_labels
122-
)
100+
# Core data structures
101+
self.graphs = gt_graphs
102+
self.fragment_graphs = fragment_graphs
103+
self.label_handler = label_handler
123104

124-
# Load data
125-
self.load_groundtruth(gt_pointer, label_mask)
126-
self.load_fragments(fragments_pointer)
105+
self.gt_graphs = deepcopy(self.graphs)
127106

128107
# Initialize metrics
129108
util.mkdir(output_dir)
@@ -151,74 +130,6 @@ def __init__(
151130
self.metrics["# Splits"] = 0
152131
self.metrics["SWC Name"] = self.metrics.index
153132

154-
# --- Load Data ---
155-
def load_groundtruth(self, swc_pointer, label_mask):
156-
"""
157-
Loads ground truth graphs and initializes the "graphs" attribute.
158-
159-
Parameters
160-
----------
161-
swc_pointer : Any
162-
Pointer to ground truth SWC files.
163-
label_mask : ImageReader
164-
Predicted segmentation mask.
165-
"""
166-
if self.verbose:
167-
print("\n(1) Load Ground Truth")
168-
169-
# Build graphs
170-
graph_loader = gutil.GraphLoader(
171-
anisotropy=self.anisotropy,
172-
is_groundtruth=True,
173-
label_handler=self.label_handler,
174-
label_mask=label_mask,
175-
use_anisotropy=False,
176-
)
177-
self.graphs = graph_loader.run(swc_pointer)
178-
179-
# Save initial graphs (if applicable)
180-
if self.save_merges:
181-
self.gt_graphs = deepcopy(self.graphs)
182-
183-
def load_fragments(self, swc_pointer):
184-
"""
185-
Loads fragments generated from the segmentation and initializes the
186-
"fragment_graphs" attribute.
187-
188-
Parameters
189-
----------
190-
swc_pointer : Any
191-
Pointer to predicted SWC files if provided.
192-
"""
193-
if self.verbose:
194-
print("\n(2) Load Fragments")
195-
196-
if swc_pointer:
197-
graph_loader = gutil.GraphLoader(
198-
anisotropy=self.anisotropy,
199-
is_groundtruth=False,
200-
selected_ids=self.get_all_node_labels(),
201-
use_anisotropy=self.use_anisotropy,
202-
)
203-
self.fragment_graphs = graph_loader.run(swc_pointer)
204-
else:
205-
self.fragment_graphs = None
206-
207-
def get_all_node_labels(self):
208-
"""
209-
Gets the set of unique node labels across all graphs in "self.graphs".
210-
211-
Returns
212-
-------
213-
Set[int]
214-
Unique node labels across all graphs.
215-
"""
216-
all_node_labels = set()
217-
for graph in self.graphs.values():
218-
node_labels = self.label_handler.get_node_labels(graph)
219-
all_node_labels = all_node_labels.union(node_labels)
220-
return all_node_labels
221-
222133
def init_writers(self):
223134
"""
224135
Initializes "self.merge_writer" attribute by setting up a directory for
@@ -259,7 +170,7 @@ def run(self):
259170
self.compute_erl()
260171

261172
# Save results
262-
prefix = "corrected-" if self.connections_path else ""
173+
prefix = "corrected-" if self.label_handler.use_mapping() else ""
263174
path = f"{self.output_dir}/{prefix}results.csv"
264175
if self.fragment_graphs is None:
265176
self.metrics = self.metrics.drop("# Merges", axis=1)

0 commit comments

Comments
 (0)