Skip to content

Commit 8a6cf74

Browse files
author
anna-grim
committed
refactor: dataloading
1 parent 1f8ec73 commit 8a6cf74

File tree

3 files changed

+56
-128
lines changed

3 files changed

+56
-128
lines changed

src/segmentation_skeleton_metrics/data_handling/graph_loading.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,33 +18,29 @@
1818
import numpy as np
1919

2020
from segmentation_skeleton_metrics.data_handling.skeleton_graph import SkeletonGraph
21-
from segmentation_skeleton_metrics.utils import graph_util as gutil, swc_util, util
21+
from segmentation_skeleton_metrics.utils import swc_util, util
2222

2323

2424
class DataLoader:
2525

2626
def __init__(
2727
self,
28+
label_handler,
2829
anisotropy=(1.0, 1.0, 1.0),
29-
connections_path=None,
3030
use_anisotropy=False,
31-
valid_labels=None,
3231
verbose=True
3332
):
3433
# Instance attributes
3534
self.anisotropy = anisotropy
35+
self.label_handler = label_handler
3636
self.use_anisotropy = use_anisotropy
37-
self.valid_labels = valid_labels
3837
self.verbose = verbose
3938

40-
# Label handler
41-
self.label_handler = LabelHandler(connections_path, valid_labels)
42-
4339
# --- Core Routines ---
4440
def load_groundtruth(self, swc_pointer, label_mask):
4541
"""
4642
Loads ground truth graphs.
47-
43+
4844
Parameters
4945
----------
5046
swc_pointer : str
@@ -60,7 +56,7 @@ def load_groundtruth(self, swc_pointer, label_mask):
6056
if self.verbose:
6157
print("\n(1) Load Ground Truth")
6258

63-
graph_loader = gutil.GraphLoader(
59+
graph_loader = GraphLoader(
6460
anisotropy=self.anisotropy,
6561
is_groundtruth=True,
6662
label_handler=self.label_handler,
@@ -94,7 +90,7 @@ def load_fragments(self, swc_pointer, gt_graphs):
9490

9591
# Load fragments
9692
selected_ids = self.get_all_node_labels(gt_graphs)
97-
graph_loader = gutil.GraphLoader(
93+
graph_loader = GraphLoader(
9894
anisotropy=self.anisotropy,
9995
is_groundtruth=False,
10096
selected_ids=selected_ids,

src/segmentation_skeleton_metrics/evaluate.py

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99
"""
1010

1111
from segmentation_skeleton_metrics.skeleton_metric import SkeletonMetric
12-
from segmentation_skeleton_metrics.data_handling.graph_loading import DataLoader
12+
from segmentation_skeleton_metrics.data_handling.graph_loading import (
13+
DataLoader, LabelHandler
14+
)
1315

1416

1517
def evaluate(
@@ -25,28 +27,48 @@ def evaluate(
2527
valid_labels=None,
2628
verbose=True
2729
):
30+
"""
31+
...
32+
33+
Parameters
34+
----------
35+
gt_pointer : str
36+
Pointer to ground truth SWC files, see "swc_util.Reader" for
37+
documentation. These SWC files are assumed to be stored in voxel
38+
coordinates.
39+
label_mask : ImageReader
40+
Predicted segmentation.
41+
anisotropy : Tuple[float], optional
42+
...
43+
connections_path : str, optional
44+
Path to a txt file containing pairs of segment IDs that represents
45+
fragments that were merged. Default is None.
46+
fragments_pointer : str, optional
47+
Pointer to SWC files corresponding to "label_mask", see
48+
"swc_util.Reader" for documentation. Notes: (1) "anisotropy" is
49+
applied to these SWC files and (2) these SWC files are required
50+
for counting merges. Default is None.
51+
"""
2852
# Load data
53+
label_handler = LabelHandler(connections_path, valid_labels)
2954
dataloader = DataLoader(
55+
label_handler,
3056
anisotropy=anisotropy,
31-
connections_path=connections_path,
3257
use_anisotropy=use_anisotropy,
33-
valid_labels=valid_labels,
3458
verbose=verbose
35-
)
59+
)
3660
gt_graphs = dataloader.load_groundtruth(gt_pointer, label_mask)
37-
fragments_graph = dataloader.load_fragments(fragments_pointer, gt_graphs)
61+
fragment_graphs = dataloader.load_fragments(fragments_pointer, gt_graphs)
3862

3963
# Evaluator
4064
skeleton_metric = SkeletonMetric(
41-
gt_pointer,
42-
label_mask,
65+
gt_graphs,
66+
fragment_graphs,
67+
label_handler,
4368
anisotropy=anisotropy,
44-
connections_path=connections_path,
45-
fragments_pointer=fragments_pointer,
4669
output_dir=output_dir,
4770
save_merges=save_merges,
4871
save_fragments=save_fragments,
49-
use_anisotropy=use_anisotropy,
5072
)
5173
skeleton_metric.run()
5274

src/segmentation_skeleton_metrics/skeleton_metric.py

Lines changed: 18 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
import os
2424
import pandas as pd
2525

26-
from segmentation_skeleton_metrics.data_handling.graph_loading import LabelHandler
2726
from segmentation_skeleton_metrics import split_detection
2827
from segmentation_skeleton_metrics.utils import (
2928
graph_util as gutil,
@@ -53,16 +52,13 @@ class SkeletonMetric:
5352

5453
def __init__(
5554
self,
56-
gt_pointer,
57-
label_mask,
55+
gt_graphs,
56+
fragment_graphs,
57+
label_handler,
5858
output_dir,
5959
anisotropy=(1.0, 1.0, 1.0),
60-
connections_path=None,
61-
fragments_pointer=None,
6260
save_merges=False,
6361
save_fragments=False,
64-
use_anisotropy=False,
65-
valid_labels=None,
6662
verbose=True
6763
):
6864
"""
@@ -71,59 +67,41 @@ def __init__(
7167
7268
Parameters
7369
----------
74-
gt_pointer : Any
75-
Pointer to ground truth SWC files, see "swc_util.Reader" for
76-
documentation. These SWC files are assumed to be stored in voxel
77-
coordinates.
78-
label_mask : ImageReader
79-
Predicted segmentation.
80-
output_dir : str
70+
gt_graphs : SkeletonGraph
71+
...
72+
fragment_graphs ; SkeletonGraph
73+
...
74+
label_handler : LabelHandler
75+
...
76+
output_dir : str
8177
Path to directory wehere results are written.
8278
anisotropy : Tuple[float], optional
8379
Image to physical coordinate scaling factors applied to SWC files
8480
stored at "fragments_pointer". Default is (1.0, 1.0, 1.0).
85-
connections_path : str, optional
86-
Path to a txt file containing pairs of segment IDs that represents
87-
fragments that were merged. Default is None.
88-
fragments_pointer : Any, optional
89-
Pointer to SWC files corresponding to "label_mask", see
90-
"swc_util.Reader" for documentation. Notes: (1) "anisotropy" is
91-
applied to these SWC files and (2) these SWC files are required
92-
for counting merges. Default is None.
81+
output_dir : str
82+
...
9383
save_merges: bool, optional
9484
Indication of whether to save fragments with a merge mistake.
9585
Default is None.
9686
save_fragments : bool, optional
9787
Indication of whether to save fragments that project onto each
9888
ground truth skeleton. Default is False.
99-
valid_labels : Set[int], optional
100-
Segment IDs that can be assigned to nodes. This argument accounts
101-
for segments that were been removed due to some type of filtering.
102-
Default is None.
103-
use_anisotropy : bool, optional
104-
Indication of whether coordinates in fragment SWC files should be
105-
converted from physical to image coordinates using the given
106-
anisotropy. Default is False.
10789
verbose : bool, optional
10890
Indication of whether to printout updates. Default is True.
10991
"""
11092
# Instance attributes
11193
self.anisotropy = anisotropy
112-
self.connections_path = connections_path
11394
self.output_dir = output_dir
11495
self.save_merges = save_merges
11596
self.save_fragments = save_fragments
116-
self.use_anisotropy = use_anisotropy
11797
self.verbose = verbose
11898

119-
# Label handler
120-
self.label_handler = LabelHandler(
121-
connections_path=connections_path, valid_labels=valid_labels
122-
)
99+
# Core data structures
100+
self.graphs = gt_graphs
101+
self.fragment_graphs = fragment_graphs
102+
self.label_handler = label_handler
123103

124-
# Load data
125-
self.load_groundtruth(gt_pointer, label_mask)
126-
self.load_fragments(fragments_pointer)
104+
self.gt_graphs = deepcopy(self.graphs)
127105

128106
# Initialize metrics
129107
util.mkdir(output_dir)
@@ -151,74 +129,6 @@ def __init__(
151129
self.metrics["# Splits"] = 0
152130
self.metrics["SWC Name"] = self.metrics.index
153131

154-
# --- Load Data ---
155-
def load_groundtruth(self, swc_pointer, label_mask):
156-
"""
157-
Loads ground truth graphs and initializes the "graphs" attribute.
158-
159-
Parameters
160-
----------
161-
swc_pointer : Any
162-
Pointer to ground truth SWC files.
163-
label_mask : ImageReader
164-
Predicted segmentation mask.
165-
"""
166-
if self.verbose:
167-
print("\n(1) Load Ground Truth")
168-
169-
# Build graphs
170-
graph_loader = gutil.GraphLoader(
171-
anisotropy=self.anisotropy,
172-
is_groundtruth=True,
173-
label_handler=self.label_handler,
174-
label_mask=label_mask,
175-
use_anisotropy=False,
176-
)
177-
self.graphs = graph_loader.run(swc_pointer)
178-
179-
# Save initial graphs (if applicable)
180-
if self.save_merges:
181-
self.gt_graphs = deepcopy(self.graphs)
182-
183-
def load_fragments(self, swc_pointer):
184-
"""
185-
Loads fragments generated from the segmentation and initializes the
186-
"fragment_graphs" attribute.
187-
188-
Parameters
189-
----------
190-
swc_pointer : Any
191-
Pointer to predicted SWC files if provided.
192-
"""
193-
if self.verbose:
194-
print("\n(2) Load Fragments")
195-
196-
if swc_pointer:
197-
graph_loader = gutil.GraphLoader(
198-
anisotropy=self.anisotropy,
199-
is_groundtruth=False,
200-
selected_ids=self.get_all_node_labels(),
201-
use_anisotropy=self.use_anisotropy,
202-
)
203-
self.fragment_graphs = graph_loader.run(swc_pointer)
204-
else:
205-
self.fragment_graphs = None
206-
207-
def get_all_node_labels(self):
208-
"""
209-
Gets the set of unique node labels across all graphs in "self.graphs".
210-
211-
Returns
212-
-------
213-
Set[int]
214-
Unique node labels across all graphs.
215-
"""
216-
all_node_labels = set()
217-
for graph in self.graphs.values():
218-
node_labels = self.label_handler.get_node_labels(graph)
219-
all_node_labels = all_node_labels.union(node_labels)
220-
return all_node_labels
221-
222132
def init_writers(self):
223133
"""
224134
Initializes "self.merge_writer" attribute by setting up a directory for
@@ -259,7 +169,7 @@ def run(self):
259169
self.compute_erl()
260170

261171
# Save results
262-
prefix = "corrected-" if self.connections_path else ""
172+
prefix = "corrected-" if self.label_handler.use_mapping() else ""
263173
path = f"{self.output_dir}/{prefix}results.csv"
264174
if self.fragment_graphs is None:
265175
self.metrics = self.metrics.drop("# Merges", axis=1)

0 commit comments

Comments
 (0)