Skip to content

Commit 2b30d70

Browse files
author
anna-grim
committed
refactor: report results in run routine
1 parent a11c2b6 commit 2b30d70

File tree

3 files changed

+35
-29
lines changed

3 files changed

+35
-29
lines changed

demo/demo.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -29,27 +29,15 @@ def evaluate():
2929
3030
"""
3131
# Initializations
32+
path = f"{output_dir}/results.xls"
3233
pred_labels = TiffReader(pred_labels_path)
3334
skeleton_metric = SkeletonMetric(
3435
groundtruth_pointer,
3536
pred_labels,
3637
fragments_pointer=fragments_pointer,
3738
output_dir=output_dir,
3839
)
39-
full_results, avg_results = skeleton_metric.run()
40-
41-
# Report results
42-
print(f"\nAveraged Results...")
43-
for key in avg_results.keys():
44-
print(f" {key}: {round(avg_results[key], 4)}")
45-
46-
print(f"\nTotal Results...")
47-
print("# splits:", skeleton_metric.count_total_splits())
48-
print("# merges:", skeleton_metric.count_total_merges())
49-
50-
# Save results
51-
path = f"{output_dir}/evaluation_results.xls"
52-
util.save_results(path, full_results)
40+
full_results, avg_results = skeleton_metric.run(path)
5341

5442

5543
if __name__ == "__main__":

demo/results.xls

5.5 KB
Binary file not shown.

src/segmentation_skeleton_metrics/skeleton_metric.py

Lines changed: 33 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,10 @@ def set_fragment_ids(self):
189189
Sets the "fragment_ids" attribute by extracting unique segment IDs
190190
from the "fragment_graphs" keys.
191191
192+
Parameters
193+
----------
194+
None
195+
192196
Returns
193197
-------
194198
None
@@ -253,20 +257,21 @@ def label_graphs(self, key):
253257

254258
def get_patch_labels(self, key, nodes):
255259
"""
256-
Gets the labels for a given set of nodes within a specified patch of
257-
the label mask.
260+
Gets the segment labels for a given set of nodes within a specified
261+
patch of the label mask.
258262
259263
Parameters
260264
----------
261265
key : str
262266
Unique identifier of graph to be labeled.
263-
nodes : list
264-
A list of node IDs for which the labels are to be retrieved.
267+
nodes : List[int]
268+
Node IDs for which the labels are to be retrieved.
265269
266270
Returns
267271
-------
268272
dict
269-
A dictionary mapping node IDs to their respective labels.
273+
A dictionary that maps node IDs to their respective labels.
274+
270275
"""
271276
bbox = self.graphs[key].get_bbox(nodes)
272277
label_patch = self.label_mask.read_with_bbox(bbox)
@@ -277,10 +282,9 @@ def get_patch_labels(self, key, nodes):
277282
node_to_label[i] = label
278283
return node_to_label
279284

280-
# --------- HERE
281285
def get_all_node_labels(self):
282286
"""
283-
Gets the a set of unique labels from all graphs in "self.graphs".
287+
Gets the set of unique node labels from all graphs in "self.graphs".
284288
285289
Parameters
286290
----------
@@ -289,7 +293,7 @@ def get_all_node_labels(self):
289293
Returns
290294
-------
291295
Set[int]
292-
Set containing unique labels from all graphs.
296+
Set of unique node labels from all graphs.
293297
294298
"""
295299
all_labels = set()
@@ -301,7 +305,7 @@ def get_all_node_labels(self):
301305

302306
def get_node_labels(self, key, inverse_bool=False):
303307
"""
304-
Gets the set of labels for nodes in the graph corresponding to the
308+
Gets the set of unique node labels from the graph corresponding to the
305309
given key.
306310
307311
Parameters
@@ -352,13 +356,14 @@ def init_zip_writer(self):
352356
self.graphs[key].to_zipped_swc(self.zip_writer[key])
353357

354358
# -- Main Routine --
355-
def run(self):
359+
def run(self, path=None):
356360
"""
357361
Computes skeleton-based metrics.
358362
359363
Parameters
360364
----------
361-
None
365+
path : str, optional
366+
Path where the results will be saved. The default is None.
362367
363368
Returns
364369
-------
@@ -368,16 +373,29 @@ def run(self):
368373
"""
369374
print("\n(3) Evaluation")
370375

371-
# Split evaluation
376+
# Split detection
372377
self.detect_splits()
373378
self.quantify_splits()
374379

375-
# Merge evaluation
380+
# Merge detection
376381
self.detect_merges()
377382
self.quantify_merges()
378383

379-
# Compute metrics
380-
return self.compile_results()
384+
# Report results
385+
full_results, avg_results = self.compile_results()
386+
print(f"\nAverage Results...")
387+
for key in avg_results.keys():
388+
print(f" {key}: {round(avg_results[key], 4)}")
389+
390+
print(f"\nTotal Results...")
391+
print("# splits:", self.count_total_splits())
392+
print("# merges:", self.count_total_merges())
393+
394+
# Save results (if applicable)
395+
if path:
396+
util.save_results(path, full_results)
397+
398+
return full_results, avg_results
381399

382400
# -- Split Detection --
383401
def detect_splits(self):

0 commit comments

Comments
 (0)