Skip to content

Commit cc78829

Browse files
author
anna-grim
committed
refactor: improved tensorstore reader
1 parent 7deae17 commit cc78829

File tree

3 files changed

+86
-18
lines changed

3 files changed

+86
-18
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ dynamic = ["version"]
1616

1717
dependencies = [
1818
"boto3",
19+
'cloud-volume[crackle,seg-codecs]',
1920
"google-cloud-storage",
2021
"networkx",
2122
"numpy",

src/segmentation_skeleton_metrics/skeleton_metric.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,7 @@ def compute_split_metrics(self, key, n_split_edges):
364364
gt_rl = self.graphs[key].run_length
365365
split_rate = rl / n_splits if n_splits > 0 else np.nan
366366

367-
# Update metrics
367+
# Record metrics
368368
self.metrics.at[key, "# Splits"] = n_splits
369369
self.metrics.at[key, "Split Rate"] = split_rate
370370
self.metrics.at[key, "% Split Edges"] = round(p_split, 2)
@@ -404,7 +404,7 @@ def count_merges(self, key, kdtree):
404404
"""
405405
Counts the number of label merges for a given graph key based on
406406
whether the fragment graph corresponding to a label has a node that is
407-
more that 200ums away from the nearest point in
407+
more that 40ums away from the nearest point in
408408
"kdtree".
409409
410410
Parameters
@@ -414,20 +414,19 @@ def count_merges(self, key, kdtree):
414414
kdtree : scipy.spatial.KDTree
415415
A KD-tree built from voxels in graph corresponding to "key".
416416
"""
417-
# Iterate over fragments that intersect with GT skeleton
417+
# Iterate over fragments that intersect with GT graphs
418418
for label in self.get_node_labels(key):
419419
nodes = self.graphs[key].nodes_with_label(label)
420-
if len(nodes) > 70:
420+
if len(nodes) > 60:
421421
for label in self.label_handler.get_class(label):
422422
if label in self.fragment_ids:
423-
self.is_fragment_merge(key, label, kdtree)
423+
self.check_fragment_for_merges(key, label, kdtree)
424424

425-
def is_fragment_merge(self, key, label, kdtree):
425+
def check_fragment_for_merges(self, key, label, kdtree):
426426
"""
427-
Determines whether fragment corresponding to "label" is falsely merged
428-
to graph corresponding to "key". A fragment is said to be merged if
429-
there is a node in the fragment more than 200ums away from the nearest
430-
point in "kdtree".
427+
Checks whether the fragment corresponding to "label" has a merge
428+
mistake. A fragment has a merge mistake if it has a leaf node more
429+
than 40μm away from the ground-truth graph corresponding to "key".
431430
432431
Parameters
433432
----------
@@ -501,7 +500,7 @@ def find_merge_site(self, key, kdtree, fragment_graph, source, visited):
501500
gutil.write_graph(
502501
self.gt_graphs[key], self.merge_writer
503502
)
504-
return
503+
return None
505504

506505
def is_valid_merge(self, graph, kdtree, root):
507506
n_hits = 0

src/segmentation_skeleton_metrics/utils/img_util.py

Lines changed: 75 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
"""
1111

1212
from abc import ABC, abstractmethod
13+
from cloudvolume import CloudVolume
1314
from tifffile import imread
1415

1516
import io
@@ -18,6 +19,8 @@
1819
import tensorstore as ts
1920
import zipfile
2021

22+
from segmentation_skeleton_metrics.utils import util
23+
2124

2225
class ImageReader(ABC):
2326
"""
@@ -119,20 +122,43 @@ def __init__(self, img_path, driver):
119122
driver : str
120123
Storage driver needed to read the image.
121124
"""
122-
self.driver = driver
125+
self.driver = self.get_driver(img_path)
123126
super().__init__(img_path)
124127

128+
def get_driver(self, img_path):
129+
"""
130+
Gets the storage driver needed to read the image.
131+
132+
Returns
133+
-------
134+
str
135+
Storage driver needed to read the image.
136+
"""
137+
if ".zarr" in img_path:
138+
return "zarr"
139+
elif ".n5" in img_path:
140+
return "n5"
141+
elif is_neuroglancer_precomputed(img_path):
142+
return "neuroglancer_precomputed"
143+
else:
144+
raise ValueError(f"Unsupported image format: {img_path}")
145+
125146
def _load_image(self):
126147
"""
127148
Loads image using the TensorStore library.
128149
"""
150+
# Extract metadata
151+
bucket_name, path = util.parse_cloud_path(self.img_path)
152+
storage_driver = get_storage_driver(self.img_path)
153+
154+
# Load image
129155
self.img = ts.open(
130156
{
131157
"driver": self.driver,
132158
"kvstore": {
133-
"driver": "gcs",
134-
"bucket": "allen-nd-goog",
135-
"path": self.img_path,
159+
"driver": storage_driver,
160+
"bucket": bucket_name,
161+
"path": path,
136162
},
137163
"context": {
138164
"cache_pool": {"total_bytes_limit": 1000000000},
@@ -142,8 +168,6 @@ def _load_image(self):
142168
"recheck_cached_data": "open",
143169
}
144170
).result()
145-
if self.driver == "neuroglancer_precomputed":
146-
return self.img[ts.d["channel"][0]]
147171

148172
def read(self, voxel, shape):
149173
"""
@@ -230,7 +254,51 @@ def _load_zipped_image(self):
230254
self.img = imread(io.BytesIO(f.read()))
231255

232256

233-
# --- Miscellaneous ---
257+
# --- Helpers ---
258+
def get_storage_driver(img_path):
259+
"""
260+
Gets the storage driver needed to read the image.
261+
262+
Parameters
263+
----------
264+
img_path : str
265+
Image path to be checked.
266+
267+
Returns
268+
-------
269+
str
270+
Storage driver needed to read the image.
271+
"""
272+
if util.is_s3_path(img_path):
273+
return "s3"
274+
elif util.is_gcs_path(img_path):
275+
return "gcs"
276+
else:
277+
raise ValueError(f"Unsupported path type: {img_path}")
278+
279+
280+
def is_neuroglancer_precomputed(path):
281+
"""
282+
Checks if the path points to a neuroglancer precomputed dataset.
283+
284+
Parameters
285+
----------
286+
path : str
287+
Path to be checked.
288+
289+
Returns
290+
-------
291+
bool
292+
Indication of whether the path points to a neuroglancer precomputed
293+
dataset.
294+
"""
295+
try:
296+
vol = CloudVolume(path)
297+
return all(k in vol.info for k in ["data_type", "scales", "type"])
298+
except Exception:
299+
return False
300+
301+
234302
def to_physical(voxel, anisotropy):
235303
"""
236304
Converts a voxel coordinate to a physical coordinate by applying the

0 commit comments

Comments
 (0)