Skip to content

Commit 46c0800

Browse files
Expose scalable segmentation to the CLI
1 parent 75a5ed2 commit 46c0800

File tree

3 files changed

+31
-7
lines changed

3 files changed

+31
-7
lines changed

synapse_net/inference/scalable_segmentation.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ def scalable_segmentation(
7878
min_size: int = 500,
7979
prediction: Optional[ArrayLike] = None,
8080
verbose: bool = True,
81+
mask: Optional[ArrayLike] = None,
8182
) -> None:
8283
"""Run segmentation based on a prediction with foreground and boundary channel.
8384
@@ -100,6 +101,8 @@ def scalable_segmentation(
100101
If not given will be stored in a temporary n5 array.
101102
verbose: Whether to print timing information.
102103
"""
104+
if mask is not None:
105+
raise NotImplementedError
103106
assert model.out_channels == 2
104107

105108
# Create a temporary directory for storing the predictions.

synapse_net/inference/util.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,7 @@ def inference_helper(
344344
output_key: Optional[str] = None,
345345
model_resolution: Optional[Tuple[float, float, float]] = None,
346346
scale: Optional[Tuple[float, float, float]] = None,
347+
allocate_output: bool = False,
347348
) -> None:
348349
"""Helper function to run segmentation for mrc files.
349350
@@ -366,6 +367,7 @@ def inference_helper(
366367
model_resolution: The resolution / voxel size to which the inputs should be scaled for prediction.
367368
If given, the scaling factor will automatically be determined based on the voxel_size of the input data.
368369
scale: Fixed factor for scaling the model inputs. Cannot be passed together with 'model_resolution'.
370+
allocate_output: Whether to allocate the output for the segmentation function.
369371
"""
370372
if (scale is not None) and (model_resolution is not None):
371373
raise ValueError("You must not provide both 'scale' and 'model_resolution' arguments.")
@@ -431,7 +433,11 @@ def inference_helper(
431433
this_scale = _derive_scale(img_path, model_resolution)
432434

433435
# Run the segmentation.
434-
segmentation = segmentation_function(input_volume, mask=mask, scale=this_scale)
436+
if allocate_output:
437+
segmentation = np.zeros(input_volume.shape, dtype="uint32")
438+
segmentation_function(input_volume, output=segmentation, mask=mask, scale=this_scale)
439+
else:
440+
segmentation = segmentation_function(input_volume, mask=mask, scale=this_scale)
435441

436442
# Write the result to tif or h5.
437443
os.makedirs(os.path.split(output_path)[0], exist_ok=True)

synapse_net/tools/cli.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import torch_em
77
from ..imod.to_imod import export_helper, write_segmentation_to_imod_as_points, write_segmentation_to_imod
88
from ..inference.inference import _get_model_registry, get_model, get_model_training_resolution, run_segmentation
9+
from ..inference.scalable_segmentation import scalable_segmentation
910
from ..inference.util import inference_helper, parse_tiling
1011

1112

@@ -152,10 +153,9 @@ def segmentation_cli():
152153
"--verbose", "-v", action="store_true",
153154
help="Whether to print verbose information about the segmentation progress."
154155
)
155-
# TODO scalable seg
156156
parser.add_argument(
157-
"--", action="store_true",
158-
help=""
157+
"--scalable", action="store_true", help="Use the scalable segmentation implementation. "
158+
"Currently this only works for vesicles, mitochondria, or active zones."
159159
)
160160
args = parser.parse_args()
161161

@@ -186,11 +186,26 @@ def segmentation_cli():
186186
model_resolution = None
187187
scale = (2 if is_2d else 3) * (args.scale,)
188188

189-
segmentation_function = partial(
190-
run_segmentation, model=model, model_type=args.model, verbose=args.verbose, tiling=tiling,
191-
)
189+
if args.scalable:
190+
if not args.model.startswith(("vesicle", "mito", "active")):
191+
raise ValueError(
192+
"The scalable segmentation implementation is currently only supported for "
193+
f"vesicles, mitochondria, or active zones, not for {args.model}."
194+
)
195+
segmentation_function = partial(
196+
scalable_segmentation, model=model, tiling=tiling, verbose=args.verbose
197+
)
198+
allocate_output = True
199+
200+
else:
201+
segmentation_function = partial(
202+
run_segmentation, model=model, model_type=args.model, verbose=args.verbose, tiling=tiling,
203+
)
204+
allocate_output = False
205+
192206
inference_helper(
193207
args.input_path, args.output_path, segmentation_function,
194208
mask_input_path=args.mask_path, force=args.force, data_ext=args.data_ext,
195209
output_key=args.segmentation_key, model_resolution=model_resolution, scale=scale,
210+
allocate_output=allocate_output
196211
)

0 commit comments

Comments
 (0)