Skip to content

Commit 04c48d6

Browse files
author
anna-grim
committed
refactor: simplified image reader
1 parent 985e694 commit 04c48d6

File tree

4 files changed

+125
-247
lines changed

4 files changed

+125
-247
lines changed

src/segmentation_skeleton_metrics/skeleton_metric.py

Lines changed: 94 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from segmentation_skeleton_metrics.utils import (
2828
graph_util as gutil,
2929
img_util,
30-
util
30+
util,
3131
)
3232

3333

@@ -46,12 +46,19 @@ class SkeletonMetric:
4646
(7) Expected Run Length (ERL)
4747
(8) Normalized ERL
4848
49+
Class attributes
50+
----------------
51+
merge_dist : float
52+
...
53+
min_label_cnt : int
54+
...
55+
4956
"""
5057

5158
def __init__(
5259
self,
5360
gt_pointer,
54-
pred_labels,
61+
label_mask,
5562
anisotropy=(1.0, 1.0, 1.0),
5663
connections_path=None,
5764
fragments_pointer=None,
@@ -70,7 +77,7 @@ def __init__(
7077
Pointer to ground truth SWC files, see "swc_util.Reader" for
7178
documentation. These SWC files are assumed to be stored in voxel
7279
coordinates.
73-
pred_labels : ArrayLike
80+
label_mask : ArrayLike
7481
Predicted segmentation mask.
7582
anisotropy : Tuple[float], optional
7683
Image to physical coordinate scaling factors applied to SWC files
@@ -79,7 +86,7 @@ def __init__(
7986
Path to a txt file containing pairs of segment IDs that represents
8087
fragments that were merged. The default is None.
8188
fragments_pointer : Any, optional
82-
Pointer to SWC files corresponding to "pred_labels", see
89+
Pointer to SWC files corresponding to "label_mask", see
8390
"swc_util.Reader" for documentation. Notes: (1) "anisotropy" is
8491
applied to these SWC files and (2) these SWC files are required
8592
for counting merges. The default is None.
@@ -114,7 +121,7 @@ def __init__(
114121
)
115122

116123
# Load data
117-
self.label_mask = pred_labels
124+
self.label_mask = label_mask
118125
self.load_groundtruth(gt_pointer)
119126
self.load_fragments(fragments_pointer)
120127

@@ -125,14 +132,12 @@ def __init__(
125132
# --- Load Data ---
126133
def load_groundtruth(self, swc_pointer):
127134
"""
128-
Initializes "self.graphs" by iterating over "paths" which corresponds
129-
to neurons in the ground truth.
135+
Loads ground truth graphs and initializes the "graphs" attribute.
130136
131137
Parameters
132138
----------
133-
paths : List[str]
134-
List of paths to swc files which correspond to neurons in the
135-
ground truth.
139+
swc_pointer : Any
140+
Pointer to ground truth SWC files.
136141
137142
Returns
138143
-------
@@ -153,6 +158,20 @@ def load_groundtruth(self, swc_pointer):
153158
self.label_graphs(key)
154159

155160
def load_fragments(self, swc_pointer):
161+
"""
162+
Loads fragments generated from the segmentation and initializes the
163+
"fragment_graphs" attribute.
164+
165+
Parameters
166+
----------
167+
swc_pointer : Any
168+
Pointer to predicted SWC files if provided.
169+
170+
Returns
171+
-------
172+
None
173+
174+
"""
156175
print("\n(2) Load Fragments")
157176
if swc_pointer:
158177
graph_builder = gutil.GraphBuilder(
@@ -166,20 +185,28 @@ def load_fragments(self, swc_pointer):
166185
self.fragment_graphs = None
167186

168187
def set_fragment_ids(self):
188+
"""
189+
Sets the "fragment_ids" attribute by extracting unique segment IDs
190+
from the "fragment_graphs" keys.
191+
192+
Returns
193+
-------
194+
None
195+
196+
"""
169197
self.fragment_ids = set()
170198
for key in self.fragment_graphs:
171199
self.fragment_ids.add(util.get_segment_id(key))
172200

173-
def label_graphs(self, key, batch_size=128):
201+
def label_graphs(self, key):
174202
"""
175203
Iterates over nodes in "graph" and stores the corresponding label from
176-
predicted segmentation mask (i.e. "self.label_mask") as a node-level
177-
attribute called "label".
204+
"self.label_mask") as a node-level attribute called "labels".
178205
179206
Parameters
180207
----------
181-
graph : networkx.Graph
182-
Graph that represents a neuron from the ground truth.
208+
key : str
209+
Unique identifier of graph to be labeled.
183210
184211
Returns
185212
-------
@@ -192,15 +219,15 @@ def label_graphs(self, key, batch_size=128):
192219
threads = list()
193220
visited = set()
194221
for i, j in nx.dfs_edges(self.graphs[key]):
195-
# Check for new batch
222+
# Check if starting new batch
196223
if len(batch) == 0:
197224
root = i
198225
batch.add(i)
199226
visited.add(i)
200227

201228
# Check whether to submit batch
202-
is_node_far = self.graphs[key].dist(root, j) > batch_size
203-
is_batch_full = len(batch) >= batch_size
229+
is_node_far = self.graphs[key].dist(root, j) > 128
230+
is_batch_full = len(batch) >= 128
204231
if is_node_far or is_batch_full:
205232
threads.append(
206233
executor.submit(self.get_patch_labels, key, batch)
@@ -214,17 +241,33 @@ def label_graphs(self, key, batch_size=128):
214241
if len(batch) == 1:
215242
root = j
216243

217-
# Submit last thread
244+
# Submit last batch
218245
threads.append(executor.submit(self.get_patch_labels, key, batch))
219246

220-
# Process results
247+
# Store results
221248
self.graphs[key].init_labels()
222249
for thread in as_completed(threads):
223250
node_to_label = thread.result()
224251
for i, label in node_to_label.items():
225252
self.graphs[key].labels[i] = label
226253

227254
def get_patch_labels(self, key, nodes):
255+
"""
256+
Gets the labels for a given set of nodes within a specified patch of
257+
the label mask.
258+
259+
Parameters
260+
----------
261+
key : str
262+
Unique identifier of graph to be labeled.
263+
nodes : list
264+
A list of node IDs for which the labels are to be retrieved.
265+
266+
Returns
267+
-------
268+
dict
269+
A dictionary mapping node IDs to their respective labels.
270+
"""
228271
bbox = self.graphs[key].get_bbox(nodes)
229272
label_patch = self.label_mask.read_with_bbox(bbox)
230273
node_to_label = dict()
@@ -234,18 +277,19 @@ def get_patch_labels(self, key, nodes):
234277
node_to_label[i] = label
235278
return node_to_label
236279

280+
# --------- HERE
237281
def get_all_node_labels(self):
238282
"""
239-
Gets the a set of all unique labels from all graphs in "self.graphs".
283+
Gets the a set of unique labels from all graphs in "self.graphs".
240284
241285
Parameters
242286
----------
243287
None
244288
245289
Returns
246290
-------
247-
set
248-
Set containing all unique labels from all graphs.
291+
Set[int]
292+
Set containing unique labels from all graphs.
249293
250294
"""
251295
all_labels = set()
@@ -257,21 +301,22 @@ def get_all_node_labels(self):
257301

258302
def get_node_labels(self, key, inverse_bool=False):
259303
"""
260-
Gets the set of labels of nodes in the graph corresponding to "key".
304+
Gets the set of labels for nodes in the graph corresponding to the
305+
given key.
261306
262307
Parameters
263308
----------
264309
key : str
265-
ID of graph in "self.graphs".
310+
Unique identifier of graph from which to retrieve the node labels.
266311
inverse_bool : bool
267-
Indication of whether to return original labels from
268-
"self.labels_mask" in the case where labels were remapped. The
269-
default is False.
312+
Indication of whether to return the labels (from "labels_mask") or
313+
a remapping of these labels in the case when "connections_path" is
314+
provided. The default is False.
270315
271316
Returns
272317
-------
273-
set
274-
Labels contained in the graph corresponding to "key".
318+
Set[int]
319+
Labels corresponding to nodes in the graph identified by "key".
275320
276321
"""
277322
if inverse_bool:
@@ -332,14 +377,13 @@ def run(self):
332377
self.quantify_merges()
333378

334379
# Compute metrics
335-
full_results, avg_results = self.compile_results()
336-
return full_results, avg_results
380+
return self.compile_results()
337381

338382
# -- Split Detection --
339383
def detect_splits(self):
340384
"""
341-
Detects splits in the predicted segmentation, then deletes node and
342-
edges in "self.graphs" that correspond to a split.
385+
Detects split and omit edges in the labeled ground truth graphs, then
386+
removes omit nodes.
343387
344388
Parameters
345389
----------
@@ -356,11 +400,7 @@ def detect_splits(self):
356400
processes = list()
357401
for key, graph in self.graphs.items():
358402
processes.append(
359-
executor.submit(
360-
split_detection.run,
361-
key,
362-
graph,
363-
)
403+
executor.submit(split_detection.run, key, graph)
364404
)
365405

366406
# Store results
@@ -373,8 +413,8 @@ def detect_splits(self):
373413

374414
def quantify_splits(self):
375415
"""
376-
Counts the number of splits, number of omit edges, and percent of omit
377-
edges for each graph in "self.graphs".
416+
Counts the number of splits, number of omit edges, and omit edge ratio
417+
in the labeled ground truth graphs.
378418
379419
Parameters
380420
----------
@@ -401,6 +441,8 @@ def quantify_splits(self):
401441
# -- Merge Detection --
402442
def detect_merges(self):
403443
"""
444+
--> HERE
445+
404446
Detects merges in the predicted segmentation, then deletes node and
405447
edges in "self.graphs" that correspond to a merge.
406448
@@ -419,7 +461,7 @@ def detect_merges(self):
419461
self.merged_percent = self.init_counter()
420462
self.merged_labels = set()
421463

422-
# Count total merges
464+
# Detect merges by comparing fragment graphs to ground truth graphs
423465
if self.fragment_graphs:
424466
pbar = tqdm(total=len(self.graphs), desc="Merge Detection")
425467
for key, graph in self.graphs.items():
@@ -433,7 +475,7 @@ def detect_merges(self):
433475
for key in self.graphs:
434476
self.adjust_metrics(key)
435477

436-
# Find graphs with common node labels
478+
# Detect merges by finding ground truth graphs with common node labels
437479
for (key_1, key_2), label in self.find_label_intersections():
438480
self.process_merge(key_1, label, -1)
439481
self.process_merge(key_2, label, -1)
@@ -454,9 +496,9 @@ def count_merges(self, key, kdtree):
454496
Parameters
455497
----------
456498
key : str
457-
ID of graph in "self.graphs".
499+
Unique identifier of graph to detect merges.
458500
kdtree : scipy.spatial.KDTree
459-
A KD-tree built from xyz coordinates in "self.graphs[key]".
501+
A KD-tree built from voxels in graph corresponding to "key".
460502
461503
Returns
462504
-------
@@ -480,11 +522,12 @@ def is_fragment_merge(self, key, label, kdtree):
480522
Parameters
481523
----------
482524
key : str
483-
ID of graph in "self.graphs".
525+
Unique identifier of graph to detect merges.
484526
label : int
485-
ID of fragment.
527+
Label contained in "labels" attribute in the graph corresponding
528+
to "key".
486529
kdtree : scipy.spatial.KDTree
487-
A KD-tree built from xyz coordinates in "self.graphs[key]".
530+
A KD-tree built from voxels in graph corresponding to "key".
488531
489532
Returns
490533
-------
@@ -516,7 +559,8 @@ def adjust_metrics(self, key):
516559
Parameters
517560
----------
518561
key : str
519-
Identifier for the graph to adjust.
562+
Unique identifier of the graph to adjust attributes that are are
563+
used to compute various metrics.
520564
521565
Returns
522566
-------
@@ -552,7 +596,7 @@ def find_label_intersections(self):
552596
553597
Returns
554598
-------
555-
set[tuple]
599+
Set[tuple]
556600
Set of tuples containing a tuple of graph ids and common label
557601
between the graphs.
558602

src/segmentation_skeleton_metrics/utils/graph_util.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44
@author: Anna Grim
55
66
7+
8+
Code for building a custom graph object called a SkeletonGraph and helper
9+
routines for working with graph.
10+
711
"""
812
from concurrent.futures import as_completed, ProcessPoolExecutor
913
from tqdm import tqdm

0 commit comments

Comments
 (0)