Skip to content

Commit 75a5ed2

Browse files
Update scalable seg
1 parent 09c4d40 commit 75a5ed2

File tree

2 files changed

+53
-17
lines changed

2 files changed

+53
-17
lines changed

synapse_net/inference/scalable_segmentation.py

Lines changed: 48 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import os
22
import tempfile
3-
from typing import Dict, Optional
3+
from typing import Dict, List, Optional
44

55
import elf.parallel as parallel
66
import numpy as np
@@ -9,8 +9,9 @@
99
from elf.io import open_file
1010
from elf.wrapper import ThresholdWrapper, SimpleTransformationWrapper
1111
from elf.wrapper.base import MultiTransformationWrapper
12-
from synapse_net.inference.util import get_prediction
12+
from elf.wrapper.resized_volume import ResizedVolume
1313
from numpy.typing import ArrayLike
14+
from synapse_net.inference.util import get_prediction
1415

1516

1617
class SelectChannel(SimpleTransformationWrapper):
@@ -37,7 +38,7 @@ def ndim(self):
3738
return self._volume.ndim - 1
3839

3940

40-
def _run_segmentation(pred, output, seeds, chunks, seed_threshold, min_size, verbose):
41+
def _run_segmentation(pred, output, seeds, chunks, seed_threshold, min_size, verbose, original_shape):
4142
# Create wrappers for selecting the foreground and the boundary channel.
4243
foreground = SelectChannel(pred, 0)
4344
boundaries = SelectChannel(pred, 1)
@@ -51,6 +52,13 @@ def _run_segmentation(pred, output, seeds, chunks, seed_threshold, min_size, ver
5152

5253
# Run watershed to extend back from the seeds to the boundaries.
5354
mask = ThresholdWrapper(foreground, 0.5)
55+
56+
# Resize if necessary.
57+
if original_shape is not None:
58+
boundaries = ResizedVolume(boundaries, original_shape, order=1)
59+
seeds = ResizedVolume(seeds, original_shape, order=0)
60+
mask = ResizedVolume(mask, original_shape, order=0)
61+
5462
parallel.seeded_watershed(
5563
boundaries, seeds=seeds, out=output, verbose=verbose, mask=mask, block_shape=chunks, halo=3 * (16,)
5664
)
@@ -60,44 +68,67 @@ def _run_segmentation(pred, output, seeds, chunks, seed_threshold, min_size, ver
6068
parallel.size_filter(output, output, min_size=min_size, verbose=verbose, block_shape=chunks)
6169

6270

63-
# TODO support resizing via the wrapper
6471
def scalable_segmentation(
6572
input_: ArrayLike,
6673
output: ArrayLike,
6774
model: torch.nn.Module,
6875
tiling: Optional[Dict[str, Dict[str, int]]] = None,
76+
scale: Optional[List[float]] = None,
6977
seed_threshold: float = 0.5,
7078
min_size: int = 500,
79+
prediction: Optional[ArrayLike] = None,
7180
verbose: bool = True,
7281
) -> None:
73-
"""Lorem ipsum
82+
"""Run segmentation based on a prediction with foreground and boundary channel.
83+
84+
This function first subtracts the boundary prediction from the foreground prediction,
85+
then applies a threshold, connected components, and a watershed to fit the components
86+
back to the foreground. All processing steps are implemented in a scalable fashion,
87+
so that the function runs for large input volumes.
7488
7589
Args:
76-
input_:
77-
output:
78-
model: The model.
90+
input_: The input data.
91+
output: The array for storing the output segmentation.
92+
Can be a numpy array, a zarr array, or similar.
93+
model: The model for prediction.
7994
tiling: The tiling configuration for the prediction.
95+
scale: The scale factor to use for rescaling the input volume before prediction.
96+
seed_threshold: The threshold applied before computing connected components.
8097
min_size: The minimum size of a vesicle to be considered.
98+
prediction: The array for storing the prediction.
99+
If given, this can be a numpy array, a zarr array, or similar
100+
If not given will be stored in a temporary n5 array.
81101
verbose: Whether to print timing information.
82102
"""
83103
assert model.out_channels == 2
84104

85105
# Create a temporary directory for storing the predictions.
106+
chunks = (128,) * 3
86107
with tempfile.TemporaryDirectory() as tmp_dir:
87-
tmp_pred = os.path.join(tmp_dir, "prediction.n5")
88-
f = open_file(tmp_pred, mode="a")
89108

90-
# Create the dataset for storing the prediction.
91-
chunks = (128,) * 3
92-
pred_shape = (2,) + input_.shape
93-
pred_chunks = (1,) + chunks
94-
pred = f.create_dataset("pred", shape=pred_shape, dtype="float32", chunks=pred_chunks)
109+
if scale is None or np.allclose(scale, 1.0, atol=1e-3):
110+
original_shape = None
111+
else:
112+
original_shape = input_.shape
113+
new_shape = tuple(int(sh * sc) for sh, sc in zip(input_.shape, scale))
114+
input_ = ResizedVolume(input_, shape=new_shape, order=1)
115+
116+
if prediction is None:
117+
# Create the dataset for storing the prediction.
118+
tmp_pred = os.path.join(tmp_dir, "prediction.n5")
119+
f = open_file(tmp_pred, mode="a")
120+
pred_shape = (2,) + input_.shape
121+
pred_chunks = (1,) + chunks
122+
prediction = f.create_dataset("pred", shape=pred_shape, dtype="float32", chunks=pred_chunks)
123+
else:
124+
assert prediction.shape[0] == 2
125+
assert prediction.shape[1:] == input_.shape
95126

96127
# Create temporary storage for the seeds.
97128
tmp_seeds = os.path.join(tmp_dir, "seeds.n5")
98129
f = open_file(tmp_seeds, mode="a")
99130
seeds = f.create_dataset("seeds", shape=input_.shape, dtype="uint64", chunks=chunks)
100131

101132
# Run prediction and segmentation.
102-
get_prediction(input_, prediction=pred, tiling=tiling, model=model, verbose=verbose)
103-
_run_segmentation(pred, output, seeds, chunks, seed_threshold, min_size, verbose)
133+
get_prediction(input_, prediction=prediction, tiling=tiling, model=model, verbose=verbose)
134+
_run_segmentation(prediction, output, seeds, chunks, seed_threshold, min_size, verbose, original_shape)

synapse_net/tools/cli.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,11 @@ def segmentation_cli():
152152
"--verbose", "-v", action="store_true",
153153
help="Whether to print verbose information about the segmentation progress."
154154
)
155+
# TODO scalable seg
156+
parser.add_argument(
157+
"--", action="store_true",
158+
help=""
159+
)
155160
args = parser.parse_args()
156161

157162
if args.checkpoint is None:

0 commit comments

Comments
 (0)