Commit 9907e9a

anna-grim authored
refactor: updated anisotropy handling (#95)
Co-authored-by: anna-grim <[email protected]>
1 parent 1879fe3 commit 9907e9a

File tree

8 files changed: +99 -73 lines changed

demo/demo.py

Lines changed: 0 additions & 2 deletions
@@ -1,5 +1,3 @@
-import os
-
 import numpy as np
 from tifffile import imread
 from xlwt import Workbook

demo/evaluation_results.xls

-4 KB
Binary file not shown.

demo/merged_ids-segmentation.txt

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+Label - xyz

src/segmentation_skeleton_metrics/skeleton_metric.py

Lines changed: 34 additions & 30 deletions
@@ -7,22 +7,26 @@
 
 """
 
-import os
-from collections import defaultdict
+
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from copy import deepcopy
+from scipy.spatial import KDTree
 from time import time
+from tqdm import tqdm
 from zipfile import ZipFile
 
 import networkx as nx
 import numpy as np
+import os
 import tensorstore as ts
-from scipy.spatial import KDTree
-from tqdm import tqdm
 
-from segmentation_skeleton_metrics import graph_utils as gutils
-from segmentation_skeleton_metrics import split_detection, swc_utils, utils
-from segmentation_skeleton_metrics.graph_utils import to_xyz_array
+from segmentation_skeleton_metrics import split_detection
+from segmentation_skeleton_metrics.utils import (
+    graph_util as gutil,
+    swc_util,
+    util
+)
+from segmentation_skeleton_metrics.utils.graph_util import to_array
 
 MERGE_DIST_THRESHOLD = 100
 MIN_CNT = 40
@@ -62,7 +66,7 @@ def __init__(
         Parameters
         ----------
         gt_pointer : dict/str/list[str]
-            Pointer to ground truth swcs, see "swc_utils.Reader" for further
+            Pointer to ground truth swcs, see "swc_util.Reader" for further
             documentation. Note these swc files are assumed to be stored in
             image coordinates.
         pred_labels : numpy.ndarray or tensorstore.TensorStore
@@ -75,12 +79,12 @@ def __init__(
             that were merged into a single segment. The default is None.
         fragments_pointer : dict/str/list[str], optional
             Pointer to fragments (i.e. swcs) corresponding to "pred_labels",
-            see "swc_utils.Reader" for further documentation. Note these swc
+            see "swc_util.Reader" for further documentation. Note these swc
             files may be stored in either world or image coordinates. If the
             swcs are stored in world coordinates, then provide the world to
             image coordinates anisotropy factor. Note the filename of each swc
             is assumed to "segment_id.swc" where segment_id cooresponds to the
-            segment id from "pred_labels". The default is None.
+            segment id from "pred_labels". The default is None.
         output_dir : str, optional
             Path to directory that mistake sites are written to. The default
             is None.
@@ -103,7 +107,7 @@ def __init__(
 
         """
         # Instance attributes
-        self.anisotropy = [1.0 / a for a in anisotropy]
+        self.anisotropy = anisotropy
         self.connections_path = connections_path
         self.output_dir = output_dir
         self.preexisting_merges = preexisting_merges
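
The key behavioral change is right above: SkeletonMetric now stores the caller's anisotropy factor unchanged instead of pre-inverting it, and leaves the direction of each conversion to the utility helpers (e.g. util.to_physical in graph_util.py further down). A hypothetical sketch of the two directions such a per-axis factor supports; the helper names and the multiply/divide convention here are illustrative assumptions, not the library's confirmed API:

# Illustrative only; not part of segmentation_skeleton_metrics.
ANISOTROPY = (0.748, 0.748, 1.0)  # per-axis scaling, physical units per voxel

def to_physical(voxel, anisotropy=ANISOTROPY):
    # voxel -> world: scale each coordinate by the anisotropy factor
    return tuple(v * a for v, a in zip(voxel, anisotropy))

def to_voxels(xyz, anisotropy=ANISOTROPY):
    # world -> voxel: divide instead; before this commit the constructor
    # effectively baked this inversion into self.anisotropy up front
    return tuple(int(round(x / a)) for x, a in zip(xyz, anisotropy))

print(to_physical((100, 200, 50)))     # (74.8, 149.6, 50.0)
print(to_voxels((74.8, 149.6, 50.0)))  # (100, 200, 50)
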
@@ -142,7 +146,7 @@ def init_label_map(self, path):
         """
         if path:
             assert self.valid_labels is not None, "Must provide valid labels!"
-            self.label_map, self.inverse_label_map = utils.init_label_map(
+            self.label_map, self.inverse_label_map = util.init_label_map(
                 path, self.valid_labels
             )
         else:
@@ -166,14 +170,14 @@ def init_graphs(self, paths):
 
         """
         # Read graphs
-        self.graphs = swc_utils.Reader().load(paths)
+        self.graphs = swc_util.Reader().load(paths)
         self.fragment_graphs = None
 
         # Label nodes
         self.key_to_label_to_nodes = dict()  # {id: {label: nodes}}
         for key in tqdm(self.graphs, desc="Labeling Graphs"):
             self.set_node_labels(key)
-            self.key_to_label_to_nodes[key] = gutils.init_label_to_nodes(
+            self.key_to_label_to_nodes[key] = gutil.init_label_to_nodes(
                 self.graphs[key]
             )

@@ -318,7 +322,7 @@ def load_fragments(self, fragments_pointer):
 
         """
         # Read fragments
-        reader = swc_utils.Reader(anisotropy=self.anisotropy, min_size=40)
+        reader = swc_util.Reader(anisotropy=self.anisotropy, min_size=40)
         fragment_graphs = reader.load(fragments_pointer)
         self.fragment_ids = set(fragment_graphs.keys())

@@ -349,13 +353,13 @@ def init_zip_writer(self):
         """
         # Initialize output directory
         output_dir = os.path.join(self.output_dir, "projections")
-        utils.mkdir(output_dir)
+        util.mkdir(output_dir)
 
         # Save intial graphs
         self.zip_writer = dict()
         for key in self.graphs.keys():
             self.zip_writer[key] = ZipFile(f"{output_dir}/{key}.zip", "w")
-            swc_utils.to_zipped_swc(
+            swc_util.to_zipped_swc(
                 self.zip_writer[key], self.graphs[key],
             )

@@ -419,7 +423,7 @@ def adjust_metrics(self, key):
 
         # Adjust metrics
         n_edges = subgraph.number_of_edges()
-        rls = gutils.compute_run_lengths(subgraph)
+        rls = gutil.compute_run_lengths(subgraph)
         self.graphs[key].graph["run_length"] -= np.sum(rls)
         self.graphs[key].graph["n_edges"] -= n_edges

@@ -448,13 +452,13 @@ def detect_splits(self):
             graph = split_detection.run(graph, self.graphs[key])
 
             # Update graph by removing omits (i.e. nodes labeled 0)
-            self.graphs[key] = gutils.delete_nodes(graph, 0)
-            self.key_to_label_to_nodes[key] = gutils.init_label_to_nodes(
+            self.graphs[key] = gutil.delete_nodes(graph, 0)
+            self.key_to_label_to_nodes[key] = gutil.init_label_to_nodes(
                 self.graphs[key]
             )
 
         # Report runtime
-        t, unit = utils.time_writer(time() - t0)
+        t, unit = util.time_writer(time() - t0)
         print(f"Runtime: {round(t, 2)} {unit}\n")
 
     def quantify_splits(self):
@@ -478,7 +482,7 @@ def quantify_splits(self):
             n_pred_edges = self.graphs[key].number_of_edges()
             n_target_edges = self.graphs[key].graph["n_edges"]
 
-            self.split_cnt[key] = gutils.count_splits(self.graphs[key])
+            self.split_cnt[key] = gutil.count_splits(self.graphs[key])
             self.omit_cnts[key] = n_target_edges - n_pred_edges
             self.omit_percent[key] = 1 - n_pred_edges / n_target_edges
 
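A worked example of the bookkeeping in quantify_splits, with illustrative numbers only:

n_target_edges = 1000                             # edges in the ground truth skeleton
n_pred_edges = 850                                # edges remaining after omits are removed
omit_cnt = n_target_edges - n_pred_edges          # 150
omit_percent = 1 - n_pred_edges / n_target_edges  # 0.15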

@@ -507,7 +511,7 @@ def detect_merges(self):
         if self.fragment_graphs:
             for key, graph in self.graphs.items():
                 if graph.number_of_nodes() > 0:
-                    kdtree = KDTree(gutils.to_xyz_array(graph))
+                    kdtree = KDTree(gutil.to_array(graph))
                     self.count_merges(key, kdtree)
 
         # Process merges
@@ -574,20 +578,20 @@ def is_fragment_merge(self, key, label, kdtree):
         None
 
         """
-        for xyz in to_xyz_array(self.fragment_graphs[label])[::5]:
-            if kdtree.query(xyz)[0] > MERGE_DIST_THRESHOLD:
+        for voxel in to_array(self.fragment_graphs[label])[::2]:
+            if kdtree.query(voxel)[0] > MERGE_DIST_THRESHOLD:
                 # Check whether to take inverse of label
                 if self.inverse_label_map:
                     equivalent_label = self.label_map[label]
                 else:
                     equivalent_label = label
 
                 # Record merge mistake
-                xyz = utils.to_world(xyz)
+                xyz = util.to_world(voxel)
                 self.merge_cnt[key] += 1
                 self.merged_labels.add((key, equivalent_label, tuple(xyz)))
                 if self.save_projections:
-                    swc_utils.to_zipped_swc(
+                    swc_util.to_zipped_swc(
                         self.zip_writer[key], self.fragment_graphs[label]
                     )
                 return
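
The check above walks every other node of a candidate fragment and asks how far it is from the nearest node of the ground truth graph; any sampled point farther than MERGE_DIST_THRESHOLD flags a merge mistake. A minimal sketch of that nearest-neighbor query with scipy's KDTree, using toy coordinates rather than the project's data:

import numpy as np
from scipy.spatial import KDTree

MERGE_DIST_THRESHOLD = 100

skeleton_xyz = np.array([[0, 0, 0], [10, 0, 0], [20, 0, 0]], dtype=float)
fragment_xyz = np.array([[5, 0, 0], [5, 50, 0], [5, 500, 0]], dtype=float)

kdtree = KDTree(skeleton_xyz)
for voxel in fragment_xyz[::2]:      # subsample nodes, as in is_fragment_merge
    dist, _ = kdtree.query(voxel)    # distance to the nearest skeleton node
    if dist > MERGE_DIST_THRESHOLD:
        print("potential merge at", voxel)
        break
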
@@ -871,7 +875,7 @@ def compute_erl(self):
         total_run_length = 0
         for key in self.graphs:
             run_length = self.get_run_length(key)
-            run_lengths = gutils.compute_run_lengths(self.graphs[key])
+            run_lengths = gutil.compute_run_lengths(self.graphs[key])
             total_run_length += run_length
             wgt = run_lengths / max(np.sum(run_lengths), 1)
 
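compute_erl weights each run length by its share of the skeleton's total length, as in the wgt line above. A minimal sketch of that expected-run-length style weighting; the class's actual method also aggregates across all ground truth graphs, which is omitted here:

import numpy as np

def expected_run_length(run_lengths):
    # weight each run length by its fraction of the total, then sum
    run_lengths = np.asarray(run_lengths, dtype=float)
    wgt = run_lengths / max(np.sum(run_lengths), 1)
    return float(np.sum(wgt * run_lengths))

print(expected_run_length([60.0, 30.0, 10.0]))  # 46.0 = 0.6*60 + 0.3*30 + 0.1*10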

@@ -924,7 +928,7 @@ def list_metrics(self):
         ]
         return metrics
 
-    # -- Utils --
+    # -- util --
     def init_counter(self):
         """
         Initializes a dictionary that is used to count some type of mistake
@@ -943,7 +947,7 @@ def init_counter(self):
         return {key: 0 for key in self.graphs}
 
 
-# -- utils --
+# -- util --
 def find_sites(graphs, get_labels):
     """
     Detects merges between ground truth graphs which are considered to be

src/segmentation_skeleton_metrics/split_detection.py

Lines changed: 6 additions & 7 deletions
@@ -11,8 +11,7 @@
 
 import networkx as nx
 
-from segmentation_skeleton_metrics import graph_utils as gutils
-from segmentation_skeleton_metrics import utils
+from segmentation_skeleton_metrics.utils import graph_util as gutil, util
 
 
 def run(target_graph, labeled_graph):
@@ -33,7 +32,7 @@ def run(target_graph, labeled_graph):
         Labeled graph with omit and split edges removed.
 
     """
-    r = gutils.sample_leaf(labeled_graph)
+    r = gutil.sample_leaf(labeled_graph)
     dfs_edges = list(nx.dfs_edges(labeled_graph, source=r))
     while len(dfs_edges) > 0:
         # Visit edge
@@ -94,7 +93,7 @@ def is_zero_misalignment(target_graph, labeled_graph, dfs_edges, nb, root)
         # Add nbs to queue
         nbs = target_graph.neighbors(j)
         for k in [k for k in nbs if k not in visited]:
-            if utils.check_edge(dfs_edges, (j, k)):
+            if util.check_edge(dfs_edges, (j, k)):
                 queue.append(k)
                 dfs_edges = remove_edge(dfs_edges, (j, k))
             elif k == nb:
@@ -103,7 +102,7 @@ def is_zero_misalignment(target_graph, labeled_graph, dfs_edges, nb, root)
     # Upd zero nodes
     if len(collision_labels) == 1:
         label = collision_labels.pop()
-        labeled_graph = gutils.upd_labels(labeled_graph, visited, label)
+        labeled_graph = gutil.upd_labels(labeled_graph, visited, label)
 
     return dfs_edges, labeled_graph

@@ -150,7 +149,7 @@ def is_nonzero_misalignment(target_graph, labeled_graph, dfs_edges, nb, root)
         visited.add(j)
         if label_j == origin_label and len(queue) == 0:
             # misalignment
-            labeled_graph = gutils.upd_labels(
+            labeled_graph = gutil.upd_labels(
                 labeled_graph, visited, origin_label
             )
             return dfs_edges, labeled_graph
@@ -171,7 +170,7 @@ def is_nonzero_misalignment(target_graph, labeled_graph, dfs_edges, nb, root)
     return dfs_edges, labeled_graph
 
 
-# -- utils --
+# -- util --
 def remove_edge(dfs_edges, edge):
     """
     Checks whether "edge" is in "dfs_edges" and removes it.
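
split_detection.run starts from a leaf of the labeled graph and walks it with a depth-first search, consuming edges from dfs_edges as it classifies them. A tiny sketch of that traversal pattern with networkx, on a toy path graph rather than a real skeleton:

import networkx as nx

graph = nx.path_graph(5)                                   # toy skeleton 0-1-2-3-4
leafs = [n for n in graph.nodes if graph.degree(n) == 1]   # degree-1 nodes
root = leafs[0]
dfs_edges = list(nx.dfs_edges(graph, source=root))
print(dfs_edges)                                           # [(0, 1), (1, 2), (2, 3), (3, 4)]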

src/segmentation_skeleton_metrics/graph_utils.py renamed to src/segmentation_skeleton_metrics/utils/graph_util.py

Lines changed: 6 additions & 4 deletions
@@ -14,7 +14,9 @@
 import numpy as np
 from scipy.spatial.distance import euclidean as get_dist
 
-from segmentation_skeleton_metrics import utils
+from segmentation_skeleton_metrics.utils import util
+
+ANISOTROPY = (0.748, 0.748, 1.0)
 
 
 # --- Update graph ---
@@ -151,14 +153,14 @@ def compute_run_length(graph):
     """
     path_length = 0
     for i, j in nx.dfs_edges(graph):
-        xyz_1 = utils.to_world(graph.nodes[i]["xyz"])
-        xyz_2 = utils.to_world(graph.nodes[j]["xyz"])
+        xyz_1 = util.to_physical(graph.nodes[i]["xyz"], ANISOTROPY)
+        xyz_2 = util.to_physical(graph.nodes[j]["xyz"], ANISOTROPY)
         path_length += get_dist(xyz_1, xyz_2)
     return path_length
 
 
 # -- miscellaneous --
-def to_xyz_array(graph):
+def to_array(graph):
     """
     Converts node coordinates from a graph into a NumPy array.
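compute_run_length now maps node coordinates into physical space with the hard-coded ANISOTROPY factor before summing edge lengths. A small worked example of that accumulation, assuming voxel coordinates stored under the "xyz" node attribute as in the snippet above; the local to_physical helper stands in for util.to_physical, whose exact behavior is assumed here to be a per-axis scaling:

import networkx as nx
from scipy.spatial.distance import euclidean as get_dist

ANISOTROPY = (0.748, 0.748, 1.0)

def to_physical(voxel, anisotropy):
    # assumed per-axis scaling from voxel to physical coordinates
    return [v * a for v, a in zip(voxel, anisotropy)]

# Toy skeleton: three nodes in a line, 100 voxels apart along x.
graph = nx.Graph()
graph.add_node(0, xyz=(0, 0, 0))
graph.add_node(1, xyz=(100, 0, 0))
graph.add_node(2, xyz=(200, 0, 0))
graph.add_edges_from([(0, 1), (1, 2)])

path_length = 0
for i, j in nx.dfs_edges(graph):
    xyz_1 = to_physical(graph.nodes[i]["xyz"], ANISOTROPY)
    xyz_2 = to_physical(graph.nodes[j]["xyz"], ANISOTROPY)
    path_length += get_dist(xyz_1, xyz_2)

print(path_length)  # 149.6: 200 voxels along x at 0.748 physical units per voxel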

0 commit comments