Skip to content

Commit 7de2e00

Browse files
anna-grimanna-grim
andauthored
Refactor general updates (#148)
* refactor: improved metric defaults * refactor: improved NaN metric handling * refactor: improved tensorstore reader --------- Co-authored-by: anna-grim <[email protected]>
1 parent 036a3dd commit 7de2e00

File tree

3 files changed

+86
-19
lines changed

3 files changed

+86
-19
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ dynamic = ["version"]
1616

1717
dependencies = [
1818
"boto3",
19+
'cloud-volume[crackle,seg-codecs]',
1920
"google-cloud-storage",
2021
"networkx",
2122
"numpy",

src/segmentation_skeleton_metrics/skeleton_metric.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,7 @@ def compute_split_metrics(self, key, n_split_edges):
363363
gt_rl = self.graphs[key].run_length
364364
split_rate = rl / n_splits if n_splits > 0 else np.nan
365365

366-
# Update metrics
366+
# Record metrics
367367
self.metrics.at[key, "# Splits"] = n_splits
368368
self.metrics.at[key, "Split Rate"] = split_rate
369369
self.metrics.at[key, "% Split Edges"] = round(p_split, 2)
@@ -403,7 +403,7 @@ def count_merges(self, key, kdtree):
403403
"""
404404
Counts the number of label merges for a given graph key based on
405405
whether the fragment graph corresponding to a label has a node that is
406-
more that 200ums away from the nearest point in
406+
more that 40ums away from the nearest point in
407407
"kdtree".
408408
409409
Parameters
@@ -413,20 +413,19 @@ def count_merges(self, key, kdtree):
413413
kdtree : scipy.spatial.KDTree
414414
A KD-tree built from voxels in graph corresponding to "key".
415415
"""
416-
# Iterate over fragments that intersect with GT skeleton
416+
# Iterate over fragments that intersect with GT graphs
417417
for label in self.get_node_labels(key):
418418
nodes = self.graphs[key].nodes_with_label(label)
419-
if len(nodes) > 70:
419+
if len(nodes) > 60:
420420
for label in self.label_handler.get_class(label):
421421
if label in self.fragment_ids:
422-
self.is_fragment_merge(key, label, kdtree)
422+
self.check_fragment_for_merges(key, label, kdtree)
423423

424-
def is_fragment_merge(self, key, label, kdtree):
424+
def check_fragment_for_merges(self, key, label, kdtree):
425425
"""
426-
Determines whether fragment corresponding to "label" is falsely merged
427-
to graph corresponding to "key". A fragment is said to be merged if
428-
there is a node in the fragment more than 200ums away from the nearest
429-
point in "kdtree".
426+
Checks whether the fragment corresponding to "label" has a merge
427+
mistake. A fragment has a merge mistake if it has a leaf node more
428+
than 40μm away from the ground-truth graph corresponding to "key".
430429
431430
Parameters
432431
----------
@@ -500,7 +499,7 @@ def find_merge_site(self, key, kdtree, fragment_graph, source, visited):
500499
gutil.write_graph(
501500
self.gt_graphs[key], self.merge_writer
502501
)
503-
return
502+
return None
504503

505504
def is_valid_merge(self, graph, kdtree, root):
506505
n_hits = 0
@@ -735,4 +734,3 @@ def to_local_voxels(self, key, i, offset):
735734
voxel = np.array(self.graphs[key].voxels[i])
736735
offset = np.array(offset)
737736
return tuple(voxel - offset)
738-

src/segmentation_skeleton_metrics/utils/img_util.py

Lines changed: 75 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
"""
1010

1111
from abc import ABC, abstractmethod
12+
from cloudvolume import CloudVolume
1213
from tifffile import imread
1314

1415
import io
@@ -17,6 +18,8 @@
1718
import tensorstore as ts
1819
import zipfile
1920

21+
from segmentation_skeleton_metrics.utils import util
22+
2023

2124
class ImageReader(ABC):
2225
"""
@@ -118,20 +121,43 @@ def __init__(self, img_path, driver):
118121
driver : str
119122
Storage driver needed to read the image.
120123
"""
121-
self.driver = driver
124+
self.driver = self.get_driver(img_path)
122125
super().__init__(img_path)
123126

127+
def get_driver(self, img_path):
128+
"""
129+
Gets the storage driver needed to read the image.
130+
131+
Returns
132+
-------
133+
str
134+
Storage driver needed to read the image.
135+
"""
136+
if ".zarr" in img_path:
137+
return "zarr"
138+
elif ".n5" in img_path:
139+
return "n5"
140+
elif is_neuroglancer_precomputed(img_path):
141+
return "neuroglancer_precomputed"
142+
else:
143+
raise ValueError(f"Unsupported image format: {img_path}")
144+
124145
def _load_image(self):
125146
"""
126147
Loads image using the TensorStore library.
127148
"""
149+
# Extract metadata
150+
bucket_name, path = util.parse_cloud_path(self.img_path)
151+
storage_driver = get_storage_driver(self.img_path)
152+
153+
# Load image
128154
self.img = ts.open(
129155
{
130156
"driver": self.driver,
131157
"kvstore": {
132-
"driver": "gcs",
133-
"bucket": "allen-nd-goog",
134-
"path": self.img_path,
158+
"driver": storage_driver,
159+
"bucket": bucket_name,
160+
"path": path,
135161
},
136162
"context": {
137163
"cache_pool": {"total_bytes_limit": 1000000000},
@@ -141,8 +167,6 @@ def _load_image(self):
141167
"recheck_cached_data": "open",
142168
}
143169
).result()
144-
if self.driver == "neuroglancer_precomputed":
145-
return self.img[ts.d["channel"][0]]
146170

147171
def read(self, voxel, shape):
148172
"""
@@ -229,7 +253,51 @@ def _load_zipped_image(self):
229253
self.img = imread(io.BytesIO(f.read()))
230254

231255

232-
# --- Miscellaneous ---
256+
# --- Helpers ---
257+
def get_storage_driver(img_path):
258+
"""
259+
Gets the storage driver needed to read the image.
260+
261+
Parameters
262+
----------
263+
img_path : str
264+
Image path to be checked.
265+
266+
Returns
267+
-------
268+
str
269+
Storage driver needed to read the image.
270+
"""
271+
if util.is_s3_path(img_path):
272+
return "s3"
273+
elif util.is_gcs_path(img_path):
274+
return "gcs"
275+
else:
276+
raise ValueError(f"Unsupported path type: {img_path}")
277+
278+
279+
def is_neuroglancer_precomputed(path):
280+
"""
281+
Checks if the path points to a neuroglancer precomputed dataset.
282+
283+
Parameters
284+
----------
285+
path : str
286+
Path to be checked.
287+
288+
Returns
289+
-------
290+
bool
291+
Indication of whether the path points to a neuroglancer precomputed
292+
dataset.
293+
"""
294+
try:
295+
vol = CloudVolume(path)
296+
return all(k in vol.info for k in ["data_type", "scales", "type"])
297+
except Exception:
298+
return False
299+
300+
233301
def to_physical(voxel, anisotropy):
234302
"""
235303
Converts a voxel coordinate to a physical coordinate by applying the

0 commit comments

Comments
 (0)