diff --git a/synapse_net/inference/util.py b/synapse_net/inference/util.py index 454df943..36269b35 100644 --- a/synapse_net/inference/util.py +++ b/synapse_net/inference/util.py @@ -54,6 +54,7 @@ def __init__(self, scale, verbose): self.scale = scale def scale_input(self, input_volume, is_segmentation=False): + t0 = time.time() if self.scale is None: return input_volume @@ -73,10 +74,11 @@ def scale_input(self, input_volume, is_segmentation=False): input_volume = rescale(input_volume, self.scale, preserve_range=True).astype(input_volume.dtype) if self.verbose: - print("Rescaled volume from", self._original_shape, "to", input_volume.shape) + print("Rescaled volume from", self._original_shape, "to", input_volume.shape, "in", time.time() - t0, "s") return input_volume def rescale_output(self, output, is_segmentation): + t0 = time.time() if self.scale is None: return output @@ -91,6 +93,9 @@ def rescale_output(self, output, is_segmentation): else: output = resize(output, out_shape, preserve_range=True).astype(output.dtype) + if self.verbose: + print("Resized prediction back to original shape", output.shape, "in", time.time() - t0, "s") + return output @@ -463,7 +468,6 @@ def get_default_tiling(is_2d: bool = False) -> Dict[str, Dict[str, int]]: tiling = {"tile": tile, "halo": halo} print(f"Determined tile size for MPS: {tiling}") - # I am not sure what is reasonable on a cpu. For now choosing very small tiling. # (This will not work well on a CPU in any case.) else: diff --git a/synapse_net/tools/cli.py b/synapse_net/tools/cli.py index 609bb0ee..60900001 100644 --- a/synapse_net/tools/cli.py +++ b/synapse_net/tools/cli.py @@ -146,6 +146,10 @@ def segmentation_cli(): "By default, the scaling factor will be derived from the voxel size of the input data. " "If this parameter is given it will over-ride the default behavior. " ) + parser.add_argument( + "--verbose", "-v", action="store_true", + help="Whether to print verbose information about the segmentation progress." + ) args = parser.parse_args() if args.checkpoint is None: @@ -169,7 +173,7 @@ def segmentation_cli(): scale = (2 if is_2d else 3) * (args.scale,) segmentation_function = partial( - run_segmentation, model=model, model_type=args.model, verbose=False, tiling=tiling, + run_segmentation, model=model, model_type=args.model, verbose=args.verbose, tiling=tiling, ) inference_helper( args.input_path, args.output_path, segmentation_function,