Skip to content

Commit 8ba623e

Browse files
Implement scalable segmentatio WIP
1 parent 43eff47 commit 8ba623e

File tree

2 files changed

+118
-10
lines changed

2 files changed

+118
-10
lines changed
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
import os
2+
import tempfile
3+
from typing import Dict, Optional
4+
5+
import elf.parallel as parallel
6+
import numpy as np
7+
import torch
8+
9+
from elf.io import open_file
10+
from elf.wrapper import ThresholdWrapper, SimpleTransformationWrapper
11+
from elf.wrapper.base import MultiTransformationWrapper
12+
from synapse_net.inference.util import get_prediction
13+
from numpy.typing import ArrayLike
14+
15+
16+
class SelectChannel(SimpleTransformationWrapper):
17+
"""Wrapper to select a chanel from an array-like dataset object.
18+
19+
Args:
20+
volume: The array-like input dataset.
21+
channel: The channel that will be selected.
22+
"""
23+
def __init__(self, volume: np.typing.ArrayLike, channel: int):
24+
self.channel = channel
25+
super().__init__(volume, lambda x: x[self.channel], with_channels=True)
26+
27+
@property
28+
def shape(self):
29+
return self._volume.shape[1:]
30+
31+
@property
32+
def chunks(self):
33+
return self._volume.chunks[1:]
34+
35+
@property
36+
def ndim(self):
37+
return self._volume.ndim - 1
38+
39+
40+
# TODO support resizing via the wrapper
41+
def scalable_segmentation(
42+
input_: ArrayLike,
43+
output: ArrayLike,
44+
model: torch.nn.Module,
45+
tiling: Optional[Dict[str, Dict[str, int]]] = None,
46+
seed_threshold: float = 0.5,
47+
min_size: int = 500,
48+
verbose: bool = True,
49+
) -> None:
50+
"""Lorem ipsum
51+
52+
Args:
53+
input_:
54+
output:
55+
model: The model.
56+
tiling: The tiling configuration for the prediction.
57+
min_size: The minimum size of a vesicle to be considered.
58+
verbose: Whether to print timing information.
59+
"""
60+
assert model.out_channels == 2
61+
62+
# Create a temporary directory for storing the predictions.
63+
with tempfile.TemporaryDirectory() as tmp_dir:
64+
tmp_pred = os.path.join(tmp_dir, "prediction.n5")
65+
f = open_file(tmp_pred, mode="a")
66+
67+
# Create the dataset for storing the prediction.
68+
chunks = (128,) * 3
69+
pred_shape = (2,) + input_.shape
70+
pred_chunks = (1,) + chunks
71+
pred = f.create_dataset("pred", shape=pred_shape, dtype="float32", chunks=pred_chunks)
72+
73+
# Run the prediction.
74+
get_prediction(input_, prediction=pred, tiling=tiling, model=model, verbose=verbose)
75+
76+
# Create wrappers for selecting the foreground and the boundary channel.
77+
foreground = SelectChannel(pred, 0)
78+
boundaries = SelectChannel(pred, 1)
79+
80+
# Create temporary storage for the seeds.
81+
tmp_seeds = os.path.join(tmp_dir, "seeds.n5")
82+
f = open_file(tmp_seeds, mode="a")
83+
seeds = f.create_dataset("seeds", shape=input_.shape, dtype="uint64", chunks=chunks)
84+
85+
# Create wrappers for subtracting and thresholding boundary subtracted from the foreground.
86+
# And then compute the seeds based on this.
87+
seed_input = ThresholdWrapper(
88+
MultiTransformationWrapper(np.subtract, foreground, boundaries), seed_threshold
89+
)
90+
parallel.label(seed_input, seeds, verbose=verbose)
91+
92+
# Run watershed to extend back from the seeds to the boundaries.
93+
parallel.seeded_watershed(
94+
boundaries, seeds=seeds, out=output, verbose=verbose, block_shape=chunks, halo=3 * (16,)
95+
)
96+
97+
# Run the size filter.
98+
if min_size > 0:
99+
parallel.size_filter(output, output, min_size=min_size, verbose=verbose)

synapse_net/inference/util.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
# import xarray
1919

2020
from elf.io import open_file
21+
from numpy.typing import ArrayLike
2122
from scipy.ndimage import binary_closing
2223
from skimage.measure import regionprops
2324
from skimage.morphology import remove_small_holes
@@ -100,15 +101,16 @@ def rescale_output(self, output, is_segmentation):
100101

101102

102103
def get_prediction(
103-
input_volume: np.ndarray, # [z, y, x]
104+
input_volume: ArrayLike, # [z, y, x]
104105
tiling: Optional[Dict[str, Dict[str, int]]], # {"tile": {"z": int, ...}, "halo": {"z": int, ...}}
105106
model_path: Optional[str] = None,
106107
model: Optional[torch.nn.Module] = None,
107108
verbose: bool = True,
108109
with_channels: bool = False,
109110
channels_to_standardize: Optional[List[int]] = None,
110-
mask: Optional[np.ndarray] = None,
111-
) -> np.ndarray:
111+
mask: Optional[ArrayLike] = None,
112+
prediction: Optional[ArrayLike] = None,
113+
) -> ArrayLike:
112114
"""Run prediction on a given volume.
113115
114116
This function will automatically choose the correct prediction implementation,
@@ -124,6 +126,8 @@ def get_prediction(
124126
channels_to_standardize: List of channels to standardize. Defaults to None.
125127
mask: Optional binary mask. If given, the prediction will only be run in
126128
the foreground region of the mask.
129+
prediction: An array like object for writing the prediction.
130+
If not given, the prediction will be computed in moemory.
127131
128132
Returns:
129133
The predicted volume.
@@ -174,21 +178,23 @@ def get_prediction(
174178
for dim in tiling["tile"]:
175179
updated_tiling["tile"][dim] = tiling["tile"][dim] - 2 * tiling["halo"][dim]
176180
# print(f"updated_tiling {updated_tiling}")
177-
pred = get_prediction_torch_em(
178-
input_volume, updated_tiling, model_path, model, verbose, with_channels, mask=mask
181+
prediction = get_prediction_torch_em(
182+
input_volume, updated_tiling, model_path, model, verbose, with_channels,
183+
mask=mask, prediction=prediction,
179184
)
180185

181-
return pred
186+
return prediction
182187

183188

184189
def get_prediction_torch_em(
185-
input_volume: np.ndarray, # [z, y, x]
190+
input_volume: ArrayLike, # [z, y, x]
186191
tiling: Dict[str, Dict[str, int]], # {"tile": {"z": int, ...}, "halo": {"z": int, ...}}
187192
model_path: Optional[str] = None,
188193
model: Optional[torch.nn.Module] = None,
189194
verbose: bool = True,
190195
with_channels: bool = False,
191-
mask: Optional[np.ndarray] = None,
196+
mask: Optional[ArrayLike] = None,
197+
prediction: Optional[ArrayLike] = None,
192198
) -> np.ndarray:
193199
"""Run prediction using torch-em on a given volume.
194200
@@ -201,6 +207,8 @@ def get_prediction_torch_em(
201207
with_channels: Whether to predict with channels.
202208
mask: Optional binary mask. If given, the prediction will only be run in
203209
the foreground region of the mask.
210+
prediction: An array like object for writing the prediction.
211+
If not given, the prediction will be computed in moemory.
204212
205213
Returns:
206214
The predicted volume.
@@ -234,14 +242,15 @@ def get_prediction_torch_em(
234242
print("Run prediction with mask.")
235243
mask = mask.astype("bool")
236244

237-
pred = predict_with_halo(
245+
prediction = predict_with_halo(
238246
input_volume, model, gpu_ids=[device],
239247
block_shape=block_shape, halo=halo,
240248
preprocess=None, with_channels=with_channels, mask=mask,
249+
output=prediction,
241250
)
242251
if verbose:
243252
print("Prediction time in", time.time() - t0, "s")
244-
return pred
253+
return prediction
245254

246255

247256
def _get_file_paths(input_path, ext=".mrc"):

0 commit comments

Comments
 (0)