Commit 9c252ed

Implement scalable segmentation (#134)
Implement scalable segmentation function
1 parent 43eff47 commit 9c252ed

3 files changed: 208 additions & 26 deletions
synapse_net/inference/scalable_segmentation.py

Lines changed: 137 additions & 0 deletions
@@ -0,0 +1,137 @@
import os
import tempfile
from typing import Dict, List, Optional

import elf.parallel as parallel
import numpy as np
import torch

from elf.io import open_file
from elf.wrapper import ThresholdWrapper, SimpleTransformationWrapper
from elf.wrapper.base import MultiTransformationWrapper
from elf.wrapper.resized_volume import ResizedVolume
from numpy.typing import ArrayLike
from synapse_net.inference.util import get_prediction


class SelectChannel(SimpleTransformationWrapper):
    """Wrapper to select a channel from an array-like dataset object.

    Args:
        volume: The array-like input dataset.
        channel: The channel that will be selected.
    """
    def __init__(self, volume: np.typing.ArrayLike, channel: int):
        self.channel = channel
        super().__init__(volume, lambda x: x[self.channel], with_channels=True)

    @property
    def shape(self):
        return self._volume.shape[1:]

    @property
    def chunks(self):
        return self._volume.chunks[1:]

    @property
    def ndim(self):
        return self._volume.ndim - 1


def _run_segmentation(pred, output, seeds, chunks, seed_threshold, min_size, verbose, original_shape):
    # Create wrappers for selecting the foreground and the boundary channel.
    foreground = SelectChannel(pred, 0)
    boundaries = SelectChannel(pred, 1)

    # Create a wrapper that subtracts the boundary channel from the foreground channel
    # and thresholds the result. The seeds are then computed from it via connected components.
    seed_input = ThresholdWrapper(
        MultiTransformationWrapper(np.subtract, foreground, boundaries), seed_threshold
    )
    parallel.label(seed_input, seeds, verbose=verbose, block_shape=chunks)

    # Run a seeded watershed to extend the seeds back to the boundaries.
    mask = ThresholdWrapper(foreground, 0.5)

    # Resize if necessary.
    if original_shape is not None:
        boundaries = ResizedVolume(boundaries, original_shape, order=1)
        seeds = ResizedVolume(seeds, original_shape, order=0)
        mask = ResizedVolume(mask, original_shape, order=0)

    parallel.seeded_watershed(
        boundaries, seeds=seeds, out=output, verbose=verbose, mask=mask, block_shape=chunks, halo=3 * (16,)
    )

    # Run the size filter.
    if min_size > 0:
        parallel.size_filter(output, output, min_size=min_size, verbose=verbose, block_shape=chunks)


def scalable_segmentation(
    input_: ArrayLike,
    output: ArrayLike,
    model: torch.nn.Module,
    tiling: Optional[Dict[str, Dict[str, int]]] = None,
    scale: Optional[List[float]] = None,
    seed_threshold: float = 0.5,
    min_size: int = 500,
    prediction: Optional[ArrayLike] = None,
    verbose: bool = True,
    mask: Optional[ArrayLike] = None,
) -> None:
    """Run segmentation based on a prediction with a foreground and a boundary channel.

    This function first subtracts the boundary prediction from the foreground prediction,
    then applies a threshold, connected components, and a watershed to fit the components
    back to the foreground. All processing steps are implemented in a scalable fashion,
    so that the function also runs for large input volumes.

    Args:
        input_: The input data.
        output: The array for storing the output segmentation.
            Can be a numpy array, a zarr array, or similar.
        model: The model for prediction.
        tiling: The tiling configuration for the prediction.
        scale: The scale factor to use for rescaling the input volume before prediction.
        seed_threshold: The threshold applied before computing connected components.
        min_size: The minimum size of a vesicle to be considered.
        prediction: The array for storing the prediction.
            If given, this can be a numpy array, a zarr array, or similar.
            If not given, the prediction will be stored in a temporary n5 array.
        verbose: Whether to print timing information.
        mask: An optional mask for restricting the segmentation. Not yet implemented.
    """
    if mask is not None:
        raise NotImplementedError
    assert model.out_channels == 2

    # Create a temporary directory for storing the predictions.
    chunks = (128,) * 3
    with tempfile.TemporaryDirectory() as tmp_dir:

        if scale is None or np.allclose(scale, 1.0, atol=1e-3):
            original_shape = None
        else:
            original_shape = input_.shape
            new_shape = tuple(int(sh * sc) for sh, sc in zip(input_.shape, scale))
            input_ = ResizedVolume(input_, shape=new_shape, order=1)

        if prediction is None:
            # Create the dataset for storing the prediction.
            tmp_pred = os.path.join(tmp_dir, "prediction.n5")
            f = open_file(tmp_pred, mode="a")
            pred_shape = (2,) + input_.shape
            pred_chunks = (1,) + chunks
            prediction = f.create_dataset("pred", shape=pred_shape, dtype="float32", chunks=pred_chunks)
        else:
            assert prediction.shape[0] == 2
            assert prediction.shape[1:] == input_.shape

        # Create temporary storage for the seeds.
        tmp_seeds = os.path.join(tmp_dir, "seeds.n5")
        f = open_file(tmp_seeds, mode="a")
        seeds = f.create_dataset("seeds", shape=input_.shape, dtype="uint64", chunks=chunks)

        # Run prediction and segmentation.
        get_prediction(input_, prediction=prediction, tiling=tiling, model=model, verbose=verbose)
        _run_segmentation(prediction, output, seeds, chunks, seed_threshold, min_size, verbose, original_shape)
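For orientation, here is a minimal usage sketch of the new function (not part of the commit). The dummy UNet3d model, the zarr container path, and the dataset names are assumptions for illustration; any trained two-channel (foreground + boundary) model and any array-like input/output should work the same way.

# Usage sketch (illustrative only, not part of this commit).
# Assumptions: torch_em's UNet3d as a stand-in for a trained foreground/boundary model,
# a zarr container at "example.zarr", and dataset names "raw" / "segmentation".
import numpy as np
import zarr
from torch_em.model import UNet3d

from synapse_net.inference.scalable_segmentation import scalable_segmentation

model = UNet3d(in_channels=1, out_channels=2)  # untrained stand-in; satisfies model.out_channels == 2

f = zarr.open("example.zarr", mode="a")
f["raw"] = np.random.rand(256, 256, 256).astype("float32")  # toy volume; normally your tomogram
input_ = f["raw"]
output = f.create_dataset("segmentation", shape=input_.shape, dtype="uint32", chunks=(128, 128, 128))

# The prediction is written to a temporary n5 file, the segmentation block-wise into 'output'.
scalable_segmentation(input_, output, model, min_size=500, verbose=True)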

synapse_net/inference/util.py

Lines changed: 48 additions & 23 deletions
@@ -18,6 +18,7 @@
 # import xarray
 
 from elf.io import open_file
+from numpy.typing import ArrayLike
 from scipy.ndimage import binary_closing
 from skimage.measure import regionprops
 from skimage.morphology import remove_small_holes
@@ -99,16 +100,32 @@ def rescale_output(self, output, is_segmentation):
         return output
 
 
+def _preprocess(input_volume, with_channels, channels_to_standardize):
+    # We standardize the data for the whole volume beforehand.
+    # If we have channels, then the standardization is done independently per channel.
+    if with_channels:
+        input_volume = input_volume.astype(np.float32, copy=False)
+        # TODO Check that this is the correct axis.
+        if channels_to_standardize is None:  # assume all channels
+            channels_to_standardize = range(input_volume.shape[0])
+        for ch in channels_to_standardize:
+            input_volume[ch] = torch_em.transform.raw.standardize(input_volume[ch])
+    else:
+        input_volume = torch_em.transform.raw.standardize(input_volume)
+    return input_volume
+
+
 def get_prediction(
-    input_volume: np.ndarray,  # [z, y, x]
+    input_volume: ArrayLike,  # [z, y, x]
     tiling: Optional[Dict[str, Dict[str, int]]],  # {"tile": {"z": int, ...}, "halo": {"z": int, ...}}
     model_path: Optional[str] = None,
     model: Optional[torch.nn.Module] = None,
     verbose: bool = True,
     with_channels: bool = False,
     channels_to_standardize: Optional[List[int]] = None,
-    mask: Optional[np.ndarray] = None,
-) -> np.ndarray:
+    mask: Optional[ArrayLike] = None,
+    prediction: Optional[ArrayLike] = None,
+) -> ArrayLike:
     """Run prediction on a given volume.
 
     This function will automatically choose the correct prediction implementation,
@@ -124,6 +141,8 @@ def get_prediction(
         channels_to_standardize: List of channels to standardize. Defaults to None.
         mask: Optional binary mask. If given, the prediction will only be run in
             the foreground region of the mask.
+        prediction: An array-like object for writing the prediction.
+            If not given, the prediction will be computed in memory.
 
     Returns:
         The predicted volume.
@@ -140,17 +159,11 @@ def get_prediction(
     if tiling is None:
         tiling = get_default_tiling()
 
-    # We standardize the data for the whole volume beforehand.
-    # If we have channels then the standardization is done independently per channel.
-    if with_channels:
-        input_volume = input_volume.astype(np.float32, copy=False)
-        # TODO Check that this is the correct axis.
-        if channels_to_standardize is None:  # assume all channels
-            channels_to_standardize = range(input_volume.shape[0])
-        for ch in channels_to_standardize:
-            input_volume[ch] = torch_em.transform.raw.standardize(input_volume[ch])
-    else:
-        input_volume = torch_em.transform.raw.standardize(input_volume)
+    # Normalize the whole input volume if it is a numpy array.
+    # Otherwise we have a zarr array or similar as input and can't normalize it as a whole;
+    # normalization will then be applied later, per block.
+    if isinstance(input_volume, np.ndarray):
+        input_volume = _preprocess(input_volume, with_channels, channels_to_standardize)
 
     # Run prediction with the bioimage.io library.
     if is_bioimageio:
@@ -174,21 +187,23 @@ def get_prediction(
         for dim in tiling["tile"]:
            updated_tiling["tile"][dim] = tiling["tile"][dim] - 2 * tiling["halo"][dim]
        # print(f"updated_tiling {updated_tiling}")
-        pred = get_prediction_torch_em(
-            input_volume, updated_tiling, model_path, model, verbose, with_channels, mask=mask
+        prediction = get_prediction_torch_em(
+            input_volume, updated_tiling, model_path, model, verbose, with_channels,
+            mask=mask, prediction=prediction,
        )
 
-    return pred
+    return prediction
 
 
 def get_prediction_torch_em(
-    input_volume: np.ndarray,  # [z, y, x]
+    input_volume: ArrayLike,  # [z, y, x]
     tiling: Dict[str, Dict[str, int]],  # {"tile": {"z": int, ...}, "halo": {"z": int, ...}}
     model_path: Optional[str] = None,
     model: Optional[torch.nn.Module] = None,
     verbose: bool = True,
     with_channels: bool = False,
-    mask: Optional[np.ndarray] = None,
+    mask: Optional[ArrayLike] = None,
+    prediction: Optional[ArrayLike] = None,
 ) -> np.ndarray:
     """Run prediction using torch-em on a given volume.
 
@@ -201,6 +216,8 @@ def get_prediction_torch_em(
         with_channels: Whether to predict with channels.
         mask: Optional binary mask. If given, the prediction will only be run in
            the foreground region of the mask.
+        prediction: An array-like object for writing the prediction.
+            If not given, the prediction will be computed in memory.
 
     Returns:
         The predicted volume.
@@ -234,14 +251,16 @@ def get_prediction_torch_em(
            print("Run prediction with mask.")
        mask = mask.astype("bool")
 
-    pred = predict_with_halo(
+    preprocess = None if isinstance(input_volume, np.ndarray) else torch_em.transform.raw.standardize
+    prediction = predict_with_halo(
        input_volume, model, gpu_ids=[device],
        block_shape=block_shape, halo=halo,
-        preprocess=None, with_channels=with_channels, mask=mask,
+        preprocess=preprocess, with_channels=with_channels, mask=mask,
+        output=prediction,
    )
    if verbose:
        print("Prediction time in", time.time() - t0, "s")
-    return pred
+    return prediction
 
 
 def _get_file_paths(input_path, ext=".mrc"):
@@ -325,6 +344,7 @@ def inference_helper(
     output_key: Optional[str] = None,
     model_resolution: Optional[Tuple[float, float, float]] = None,
     scale: Optional[Tuple[float, float, float]] = None,
+    allocate_output: bool = False,
 ) -> None:
     """Helper function to run segmentation for mrc files.
 
@@ -347,6 +367,7 @@ def inference_helper(
         model_resolution: The resolution / voxel size to which the inputs should be scaled for prediction.
            If given, the scaling factor will automatically be determined based on the voxel_size of the input data.
         scale: Fixed factor for scaling the model inputs. Cannot be passed together with 'model_resolution'.
+        allocate_output: Whether to allocate the output array for the segmentation function.
     """
     if (scale is not None) and (model_resolution is not None):
         raise ValueError("You must not provide both 'scale' and 'model_resolution' arguments.")
@@ -412,7 +433,11 @@ def inference_helper(
            this_scale = _derive_scale(img_path, model_resolution)
 
        # Run the segmentation.
-        segmentation = segmentation_function(input_volume, mask=mask, scale=this_scale)
+        if allocate_output:
+            segmentation = np.zeros(input_volume.shape, dtype="uint32")
+            segmentation_function(input_volume, output=segmentation, mask=mask, scale=this_scale)
+        else:
+            segmentation = segmentation_function(input_volume, mask=mask, scale=this_scale)
 
        # Write the result to tif or h5.
        os.makedirs(os.path.split(output_path)[0], exist_ok=True)
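As a side note on the tiling parameter threaded through get_prediction and get_prediction_torch_em above: its layout follows the type comment {"tile": {"z": int, ...}, "halo": {"z": int, ...}}. A hedged example of such a configuration is shown below; the concrete sizes are illustrative, not defaults taken from this commit.

# Illustrative tiling configuration; the key layout matches the type comments above,
# but these particular tile and halo sizes are assumptions, not values from the code.
tiling = {
    "tile": {"x": 512, "y": 512, "z": 64},  # size of each prediction tile
    "halo": {"x": 64, "y": 64, "z": 8},     # overlap that is cropped away after prediction
}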

synapse_net/tools/cli.py

Lines changed: 23 additions & 3 deletions
@@ -6,6 +6,7 @@
 import torch_em
 from ..imod.to_imod import export_helper, write_segmentation_to_imod_as_points, write_segmentation_to_imod
 from ..inference.inference import _get_model_registry, get_model, get_model_training_resolution, run_segmentation
+from ..inference.scalable_segmentation import scalable_segmentation
 from ..inference.util import inference_helper, parse_tiling
 
 
@@ -152,6 +153,10 @@ def segmentation_cli():
         "--verbose", "-v", action="store_true",
         help="Whether to print verbose information about the segmentation progress."
     )
+    parser.add_argument(
+        "--scalable", action="store_true", help="Use the scalable segmentation implementation. "
+        "Currently this only works for vesicles, mitochondria, or active zones."
+    )
     args = parser.parse_args()
 
     if args.checkpoint is None:
@@ -181,11 +186,26 @@ def segmentation_cli():
         model_resolution = None
         scale = (2 if is_2d else 3) * (args.scale,)
 
-    segmentation_function = partial(
-        run_segmentation, model=model, model_type=args.model, verbose=args.verbose, tiling=tiling,
-    )
+    if args.scalable:
+        if not args.model.startswith(("vesicle", "mito", "active")):
+            raise ValueError(
+                "The scalable segmentation implementation is currently only supported for "
+                f"vesicles, mitochondria, or active zones, not for {args.model}."
+            )
+        segmentation_function = partial(
+            scalable_segmentation, model=model, tiling=tiling, verbose=args.verbose
+        )
+        allocate_output = True
+
+    else:
+        segmentation_function = partial(
+            run_segmentation, model=model, model_type=args.model, verbose=args.verbose, tiling=tiling,
+        )
+        allocate_output = False
+
     inference_helper(
         args.input_path, args.output_path, segmentation_function,
         mask_input_path=args.mask_path, force=args.force, data_ext=args.data_ext,
         output_key=args.segmentation_key, model_resolution=model_resolution, scale=scale,
+        allocate_output=allocate_output
     )
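To make the role of allocate_output explicit: run_segmentation allocates and returns its result, while scalable_segmentation writes into a pre-allocated output array and returns None. The sketch below condenses the two call paths that inference_helper now dispatches between; input_volume, mask, and this_scale are placeholders from the util.py hunk above.

# Condensed from the inference_helper change above; variable names are placeholders.
if allocate_output:
    # Scalable path: pre-allocate the output and let the function fill it in place.
    segmentation = np.zeros(input_volume.shape, dtype="uint32")
    segmentation_function(input_volume, output=segmentation, mask=mask, scale=this_scale)
else:
    # Default path: the segmentation function allocates and returns the result itself.
    segmentation = segmentation_function(input_volume, mask=mask, scale=this_scale)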
