|
6 | 6 | import torch_em |
7 | 7 | from ..imod.to_imod import export_helper, write_segmentation_to_imod_as_points, write_segmentation_to_imod |
8 | 8 | from ..inference.inference import _get_model_registry, get_model, get_model_training_resolution, run_segmentation |
| 9 | +from ..inference.scalable_segmentation import scalable_segmentation |
9 | 10 | from ..inference.util import inference_helper, parse_tiling |
10 | 11 |
|
11 | 12 |
|
@@ -152,10 +153,9 @@ def segmentation_cli(): |
152 | 153 | "--verbose", "-v", action="store_true", |
153 | 154 | help="Whether to print verbose information about the segmentation progress." |
154 | 155 | ) |
155 | | - # TODO scalable seg |
156 | 156 | parser.add_argument( |
157 | | - "--", action="store_true", |
158 | | - help="" |
| 157 | + "--scalable", action="store_true", help="Use the scalable segmentation implementation. " |
| 158 | + "Currently this only works for vesicles, mitochondria, or active zones." |
159 | 159 | ) |
160 | 160 | args = parser.parse_args() |
161 | 161 |
|
@@ -186,11 +186,26 @@ def segmentation_cli(): |
186 | 186 | model_resolution = None |
187 | 187 | scale = (2 if is_2d else 3) * (args.scale,) |
188 | 188 |
|
189 | | - segmentation_function = partial( |
190 | | - run_segmentation, model=model, model_type=args.model, verbose=args.verbose, tiling=tiling, |
191 | | - ) |
| 189 | + if args.scalable: |
| 190 | + if not args.model.startswith(("vesicle", "mito", "active")): |
| 191 | + raise ValueError( |
| 192 | + "The scalable segmentation implementation is currently only supported for " |
| 193 | + f"vesicles, mitochondria, or active zones, not for {args.model}." |
| 194 | + ) |
| 195 | + segmentation_function = partial( |
| 196 | + scalable_segmentation, model=model, tiling=tiling, verbose=args.verbose |
| 197 | + ) |
| 198 | + allocate_output = True |
| 199 | + |
| 200 | + else: |
| 201 | + segmentation_function = partial( |
| 202 | + run_segmentation, model=model, model_type=args.model, verbose=args.verbose, tiling=tiling, |
| 203 | + ) |
| 204 | + allocate_output = False |
| 205 | + |
192 | 206 | inference_helper( |
193 | 207 | args.input_path, args.output_path, segmentation_function, |
194 | 208 | mask_input_path=args.mask_path, force=args.force, data_ext=args.data_ext, |
195 | 209 | output_key=args.segmentation_key, model_resolution=model_resolution, scale=scale, |
| 210 | + allocate_output=allocate_output |
196 | 211 | ) |
0 commit comments