2727from segmentation_skeleton_metrics .utils import (
2828 graph_util as gutil ,
2929 img_util ,
30- util
30+ util ,
3131)
3232
3333
@@ -46,12 +46,19 @@ class SkeletonMetric:
4646 (7) Expected Run Length (ERL)
4747 (8) Normalized ERL
4848
49+ Class attributes
50+ ----------------
51+ merge_dist : float
52+ ...
53+ min_label_cnt : int
54+ ...
55+
4956 """
5057
5158 def __init__ (
5259 self ,
5360 gt_pointer ,
54- pred_labels ,
61+ label_mask ,
5562 anisotropy = (1.0 , 1.0 , 1.0 ),
5663 connections_path = None ,
5764 fragments_pointer = None ,
@@ -70,7 +77,7 @@ def __init__(
7077 Pointer to ground truth SWC files, see "swc_util.Reader" for
7178 documentation. These SWC files are assumed to be stored in voxel
7279 coordinates.
73- pred_labels : ArrayLike
80+ label_mask : ArrayLike
7481 Predicted segmentation mask.
7582 anisotropy : Tuple[float], optional
7683 Image to physical coordinate scaling factors applied to SWC files
@@ -79,7 +86,7 @@ def __init__(
7986 Path to a txt file containing pairs of segment IDs that represents
8087 fragments that were merged. The default is None.
8188 fragments_pointer : Any, optional
82- Pointer to SWC files corresponding to "pred_labels ", see
89+ Pointer to SWC files corresponding to "label_mask ", see
8390 "swc_util.Reader" for documentation. Notes: (1) "anisotropy" is
8491 applied to these SWC files and (2) these SWC files are required
8592 for counting merges. The default is None.
@@ -114,7 +121,7 @@ def __init__(
114121 )
115122
116123 # Load data
117- self .label_mask = pred_labels
124+ self .label_mask = label_mask
118125 self .load_groundtruth (gt_pointer )
119126 self .load_fragments (fragments_pointer )
120127
@@ -125,14 +132,12 @@ def __init__(
125132 # --- Load Data ---
126133 def load_groundtruth (self , swc_pointer ):
127134 """
128- Initializes "self.graphs" by iterating over "paths" which corresponds
129- to neurons in the ground truth.
135+ Loads ground truth graphs and initializes the "graphs" attribute.
130136
131137 Parameters
132138 ----------
133- paths : List[str]
134- List of paths to swc files which correspond to neurons in the
135- ground truth.
139+ swc_pointer : Any
140+ Pointer to ground truth SWC files.
136141
137142 Returns
138143 -------
@@ -153,6 +158,20 @@ def load_groundtruth(self, swc_pointer):
153158 self .label_graphs (key )
154159
155160 def load_fragments (self , swc_pointer ):
161+ """
162+ Loads fragments generated from the segmentation and initializes the
163+ "fragment_graphs" attribute.
164+
165+ Parameters
166+ ----------
167+ swc_pointer : Any
168+ Pointer to predicted SWC files if provided.
169+
170+ Returns
171+ -------
172+ None
173+
174+ """
156175 print ("\n (2) Load Fragments" )
157176 if swc_pointer :
158177 graph_builder = gutil .GraphBuilder (
@@ -166,20 +185,28 @@ def load_fragments(self, swc_pointer):
166185 self .fragment_graphs = None
167186
168187 def set_fragment_ids (self ):
188+ """
189+ Sets the "fragment_ids" attribute by extracting unique segment IDs
190+ from the "fragment_graphs" keys.
191+
192+ Returns
193+ -------
194+ None
195+
196+ """
169197 self .fragment_ids = set ()
170198 for key in self .fragment_graphs :
171199 self .fragment_ids .add (util .get_segment_id (key ))
172200
173- def label_graphs (self , key , batch_size = 128 ):
201+ def label_graphs (self , key ):
174202 """
175203 Iterates over nodes in "graph" and stores the corresponding label from
176- predicted segmentation mask (i.e. "self.label_mask") as a node-level
177- attribute called "label".
204+ "self.label_mask") as a node-level attribute called "labels".
178205
179206 Parameters
180207 ----------
181- graph : networkx.Graph
182- Graph that represents a neuron from the ground truth .
208+ key : str
209+ Unique identifier of graph to be labeled .
183210
184211 Returns
185212 -------
@@ -192,15 +219,15 @@ def label_graphs(self, key, batch_size=128):
192219 threads = list ()
193220 visited = set ()
194221 for i , j in nx .dfs_edges (self .graphs [key ]):
195- # Check for new batch
222+ # Check if starting new batch
196223 if len (batch ) == 0 :
197224 root = i
198225 batch .add (i )
199226 visited .add (i )
200227
201228 # Check whether to submit batch
202- is_node_far = self .graphs [key ].dist (root , j ) > batch_size
203- is_batch_full = len (batch ) >= batch_size
229+ is_node_far = self .graphs [key ].dist (root , j ) > 128
230+ is_batch_full = len (batch ) >= 128
204231 if is_node_far or is_batch_full :
205232 threads .append (
206233 executor .submit (self .get_patch_labels , key , batch )
@@ -214,17 +241,33 @@ def label_graphs(self, key, batch_size=128):
214241 if len (batch ) == 1 :
215242 root = j
216243
217- # Submit last thread
244+ # Submit last batch
218245 threads .append (executor .submit (self .get_patch_labels , key , batch ))
219246
220- # Process results
247+ # Store results
221248 self .graphs [key ].init_labels ()
222249 for thread in as_completed (threads ):
223250 node_to_label = thread .result ()
224251 for i , label in node_to_label .items ():
225252 self .graphs [key ].labels [i ] = label
226253
227254 def get_patch_labels (self , key , nodes ):
255+ """
256+ Gets the labels for a given set of nodes within a specified patch of
257+ the label mask.
258+
259+ Parameters
260+ ----------
261+ key : str
262+ Unique identifier of graph to be labeled.
263+ nodes : list
264+ A list of node IDs for which the labels are to be retrieved.
265+
266+ Returns
267+ -------
268+ dict
269+ A dictionary mapping node IDs to their respective labels.
270+ """
228271 bbox = self .graphs [key ].get_bbox (nodes )
229272 label_patch = self .label_mask .read_with_bbox (bbox )
230273 node_to_label = dict ()
@@ -234,18 +277,19 @@ def get_patch_labels(self, key, nodes):
234277 node_to_label [i ] = label
235278 return node_to_label
236279
280+ # --------- HERE
237281 def get_all_node_labels (self ):
238282 """
239- Gets the a set of all unique labels from all graphs in "self.graphs".
283+ Gets the a set of unique labels from all graphs in "self.graphs".
240284
241285 Parameters
242286 ----------
243287 None
244288
245289 Returns
246290 -------
247- set
248- Set containing all unique labels from all graphs.
291+ Set[int]
292+ Set containing unique labels from all graphs.
249293
250294 """
251295 all_labels = set ()
@@ -257,21 +301,22 @@ def get_all_node_labels(self):
257301
258302 def get_node_labels (self , key , inverse_bool = False ):
259303 """
260- Gets the set of labels of nodes in the graph corresponding to "key".
304+ Gets the set of labels for nodes in the graph corresponding to the
305+ given key.
261306
262307 Parameters
263308 ----------
264309 key : str
265- ID of graph in "self.graphs" .
310+ Unique identifier of graph from which to retrieve the node labels .
266311 inverse_bool : bool
267- Indication of whether to return original labels from
268- "self.labels_mask" in the case where labels were remapped. The
269- default is False.
312+ Indication of whether to return the labels ( from "labels_mask") or
313+ a remapping of these labels in the case when "connections_path" is
314+ provided. The default is False.
270315
271316 Returns
272317 -------
273- set
274- Labels contained in the graph corresponding to "key".
318+ Set[int]
319+ Labels corresponding to nodes in the graph identified by "key".
275320
276321 """
277322 if inverse_bool :
@@ -332,14 +377,13 @@ def run(self):
332377 self .quantify_merges ()
333378
334379 # Compute metrics
335- full_results , avg_results = self .compile_results ()
336- return full_results , avg_results
380+ return self .compile_results ()
337381
338382 # -- Split Detection --
339383 def detect_splits (self ):
340384 """
341- Detects splits in the predicted segmentation, then deletes node and
342- edges in "self.graphs" that correspond to a split .
385+ Detects split and omit edges in the labeled ground truth graphs, then
386+ removes omit nodes .
343387
344388 Parameters
345389 ----------
@@ -356,11 +400,7 @@ def detect_splits(self):
356400 processes = list ()
357401 for key , graph in self .graphs .items ():
358402 processes .append (
359- executor .submit (
360- split_detection .run ,
361- key ,
362- graph ,
363- )
403+ executor .submit (split_detection .run , key , graph )
364404 )
365405
366406 # Store results
@@ -373,8 +413,8 @@ def detect_splits(self):
373413
374414 def quantify_splits (self ):
375415 """
376- Counts the number of splits, number of omit edges, and percent of omit
377- edges for each graph in "self. graphs" .
416+ Counts the number of splits, number of omit edges, and omit edge ratio
417+ in the labeled ground truth graphs.
378418
379419 Parameters
380420 ----------
@@ -401,6 +441,8 @@ def quantify_splits(self):
401441 # -- Merge Detection --
402442 def detect_merges (self ):
403443 """
444+ --> HERE
445+
404446 Detects merges in the predicted segmentation, then deletes node and
405447 edges in "self.graphs" that correspond to a merge.
406448
@@ -419,7 +461,7 @@ def detect_merges(self):
419461 self .merged_percent = self .init_counter ()
420462 self .merged_labels = set ()
421463
422- # Count total merges
464+ # Detect merges by comparing fragment graphs to ground truth graphs
423465 if self .fragment_graphs :
424466 pbar = tqdm (total = len (self .graphs ), desc = "Merge Detection" )
425467 for key , graph in self .graphs .items ():
@@ -433,7 +475,7 @@ def detect_merges(self):
433475 for key in self .graphs :
434476 self .adjust_metrics (key )
435477
436- # Find graphs with common node labels
478+ # Detect merges by finding ground truth graphs with common node labels
437479 for (key_1 , key_2 ), label in self .find_label_intersections ():
438480 self .process_merge (key_1 , label , - 1 )
439481 self .process_merge (key_2 , label , - 1 )
@@ -454,9 +496,9 @@ def count_merges(self, key, kdtree):
454496 Parameters
455497 ----------
456498 key : str
457- ID of graph in "self.graphs" .
499+ Unique identifier of graph to detect merges .
458500 kdtree : scipy.spatial.KDTree
459- A KD-tree built from xyz coordinates in "self.graphs[ key] ".
501+ A KD-tree built from voxels in graph corresponding to " key".
460502
461503 Returns
462504 -------
@@ -480,11 +522,12 @@ def is_fragment_merge(self, key, label, kdtree):
480522 Parameters
481523 ----------
482524 key : str
483- ID of graph in "self.graphs" .
525+ Unique identifier of graph to detect merges .
484526 label : int
485- ID of fragment.
527+ Label contained in "labels" attribute in the graph corresponding
528+ to "key".
486529 kdtree : scipy.spatial.KDTree
487- A KD-tree built from xyz coordinates in "self.graphs[ key] ".
530+ A KD-tree built from voxels in graph corresponding to " key".
488531
489532 Returns
490533 -------
@@ -516,7 +559,8 @@ def adjust_metrics(self, key):
516559 Parameters
517560 ----------
518561 key : str
519- Identifier for the graph to adjust.
562+ Unique identifier of the graph to adjust attributes that are are
563+ used to compute various metrics.
520564
521565 Returns
522566 -------
@@ -552,7 +596,7 @@ def find_label_intersections(self):
552596
553597 Returns
554598 -------
555- set [tuple]
599+ Set [tuple]
556600 Set of tuples containing a tuple of graph ids and common label
557601 between the graphs.
558602
0 commit comments