diff --git a/flamingo_tools/segmentation/unet_prediction.py b/flamingo_tools/segmentation/unet_prediction.py
index 38f66e6..0f5aabf 100644
--- a/flamingo_tools/segmentation/unet_prediction.py
+++ b/flamingo_tools/segmentation/unet_prediction.py
@@ -67,6 +67,7 @@ def prediction_impl(
     slurm_task_id=0,
     mean=None,
     std=None,
+    mask=None,
 ):
     """@private
     """
@@ -79,18 +80,20 @@ def prediction_impl(
     input_ = read_image_data(input_path, input_key)
     chunks = getattr(input_, "chunks", (64, 64, 64))
 
-    mask_path = os.path.join(output_folder, "mask.zarr")
-
-    if os.path.exists(mask_path):
-        image_mask = z5py.File(mask_path, "r")["mask"]
-        # resize mask
-        image_shape = input_.shape
-        mask_shape = image_mask.shape
-        if image_shape != mask_shape:
-            image_mask = ResizedVolume(image_mask, image_shape, order=0)
+    if output_folder is None:
+        image_mask = mask
     else:
-        image_mask = None
+        mask_path = os.path.join(output_folder, "mask.zarr")
+        if os.path.exists(mask_path):
+            image_mask = z5py.File(mask_path, "r")["mask"]
+            # Resize the mask to match the image shape.
+            image_shape = input_.shape
+            mask_shape = image_mask.shape
+            if image_shape != mask_shape:
+                image_mask = ResizedVolume(image_mask, image_shape, order=0)
+        else:
+            image_mask = mask
 
     if scale is None or np.isclose(scale, 1):
         original_shape = None
@@ -162,22 +165,36 @@ def postprocess(x):
     else:
         slurm_iteration = list(range(n_blocks))
 
-    output_path = os.path.join(output_folder, "predictions.zarr")
-    with open_file(output_path, "a") as f:
-        output = f.require_dataset(
-            "prediction",
-            shape=output_shape,
-            chunks=output_chunks,
-            compression="gzip",
-            dtype="float32",
-        )
-
-    predict_with_halo(
-        input_, model,
-        gpu_ids=gpu_ids, block_shape=block_shape, halo=halo,
-        output=output, preprocess=preprocess, postprocess=postprocess,
-        mask=image_mask,
-        iter_list=slurm_iteration,
-    )
-
-    return original_shape
+    if output_folder is None:
+        output = np.zeros(output_shape, dtype=np.float32)
+        predict_with_halo(
+            input_, model,
+            gpu_ids=gpu_ids, block_shape=block_shape, halo=halo,
+            output=output, preprocess=preprocess, postprocess=postprocess,
+            mask=image_mask,
+            iter_list=slurm_iteration,
+        )
+    else:
+        output_path = os.path.join(output_folder, "predictions.zarr")
+        with open_file(output_path, "a") as f:
+            output = f.require_dataset(
+                "prediction",
+                shape=output_shape,
+                chunks=output_chunks,
+                compression="gzip",
+                dtype="float32",
+            )
+
+            predict_with_halo(
+                input_, model,
+                gpu_ids=gpu_ids, block_shape=block_shape, halo=halo,
+                output=output, preprocess=preprocess, postprocess=postprocess,
+                mask=image_mask,
+                iter_list=slurm_iteration,
+            )
+
+    if output_folder is None:
+        return original_shape, output
+    else:
+        return original_shape, None
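For review context, a minimal sketch of the two calling modes of `prediction_impl` introduced above; the paths, key, and model file are placeholders, not part of the patch:

```python
from flamingo_tools.segmentation.unet_prediction import prediction_impl

# In-memory mode: no output folder; the prediction comes back as a numpy array.
original_shape, prediction = prediction_impl(
    "data/volume.n5", "raw", None, "model.pt", scale=None,
    block_shape=(64, 256, 256), halo=(16, 64, 64),
)

# On-disk mode: the prediction is streamed to <output_folder>/predictions.zarr
# and the second return value is None.
original_shape, _ = prediction_impl(
    "data/volume.n5", "raw", "results/", "model.pt", scale=None,
    block_shape=(64, 256, 256), halo=(16, 64, 64),
)
```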
""" - mask_path = os.path.join(output_folder, "mask.zarr") - f = z5py.File(mask_path, "a") - # set parameters for the exclusion of chunks within mask generation if seg_class == "sgn": upper_percentile = 95 @@ -214,10 +233,6 @@ def find_mask(input_path: str, input_key: Optional[str], output_folder: str, seg min_intensity = 200 print("Calculating mask with default values.") - mask_key = "mask" - if mask_key in f: - return - raw = read_image_data(input_path, input_key) chunks = getattr(raw, "chunks", (64, 64, 64)) @@ -225,7 +240,17 @@ def find_mask(input_path: str, input_key: Optional[str], output_folder: str, seg blocking = nt.blocking([0, 0, 0], raw.shape, block_shape) n_blocks = blocking.numberOfBlocks - ds_mask = f.create_dataset(mask_key, shape=raw.shape, compression="gzip", dtype="uint8", chunks=block_shape) + if output_folder is None: + ds_mask = np.zeros(raw.shape, dtype=np.uint64) + + else: + mask_path = os.path.join(output_folder, "mask.zarr") + f = z5py.File(mask_path, "a") + mask_key = "mask" + if mask_key in f: + return + + ds_mask = f.create_dataset(mask_key, shape=raw.shape, compression="gzip", dtype="uint8", chunks=block_shape) # TODO more sophisticated criterion?! def find_mask_block(block_id): @@ -240,15 +265,20 @@ def find_mask_block(block_id): with futures.ThreadPoolExecutor(n_threads) as tp: list(tqdm(tp.map(find_mask_block, range(n_blocks)), total=n_blocks)) + if output_folder is None: + return ds_mask + else: + return None + def distance_watershed_implementation( input_path: str, - output_folder: str, - min_size: int, + output_folder: Optional[str] = None, + min_size: int = 1000, center_distance_threshold: float = 0.4, boundary_distance_threshold: Optional[float] = None, fg_threshold: float = 0.5, - original_shape: Optional[Tuple[int, int, int]] = None, + original_shape: Optional[Tuple[int, int, int]] = None ) -> None: """Parallel implementation of the distance-prediction based watershed. @@ -262,7 +292,10 @@ def distance_watershed_implementation( fg_threshold: The threshold applied to the foreground prediction for deriving the watershed mask. original_shape: The original shape to resize the segmentation to. """ - input_ = open_file(input_path, "r")["prediction"] + if isinstance(input_path, str): + input_ = open_file(input_path, "r")["prediction"] + else: + input_ = input_path # Limit the number of cores for parallelization. n_threads = min(16, mp.cpu_count()) @@ -280,13 +313,17 @@ def distance_watershed_implementation( # center_distances = SimpleTransformationWrapper(center_distances, transformation=smoothing) # boundary_distances = SimpleTransformationWrapper(boundary_distances, transformation=smoothing) - # Allocate an zarr array for the seeds. - block_shape = center_distances.chunks - seed_path = os.path.join(output_folder, "seeds.zarr") - seed_file = open_file(os.path.join(seed_path), "a") - seeds = seed_file.require_dataset( - "seeds", shape=center_distances.shape, chunks=block_shape, compression="gzip", dtype="uint64" - ) + # Allocate the (zarr) array for the seeds. + if output_folder is None: + block_shape = (20, 128, 128) + seeds = np.zeros(center_distances.shape, dtype=np.uint64) + else: + block_shape = center_distances.chunks + seed_path = os.path.join(output_folder, "seeds.zarr") + seed_file = open_file(os.path.join(seed_path), "a") + seeds = seed_file.require_dataset( + "seeds", shape=center_distances.shape, chunks=block_shape, compression="gzip", dtype="uint64" + ) # Compute the seed inputs: # First, threshold the center distances. 
@@ -301,12 +344,15 @@ def distance_watershed_implementation(
         data=seed_inputs, out=seeds, block_shape=block_shape, mask=mask, verbose=True, n_threads=n_threads
     )
 
-    # Allocate the zarr array for the segmentation.
-    seg_path = os.path.join(output_folder, "segmentation.zarr" if original_shape is None else "seg_downscaled.zarr")
-    seg_file = open_file(seg_path, "a")
-    seg = seg_file.create_dataset(
-        "segmentation", shape=seeds.shape, chunks=block_shape, compression="gzip", dtype="uint64"
-    )
+    # Allocate the (zarr) array for the segmentation.
+    if output_folder is None:
+        seg = np.zeros(seeds.shape, dtype=np.uint64)
+    else:
+        seg_path = os.path.join(output_folder, "segmentation.zarr" if original_shape is None else "seg_downscaled.zarr")
+        seg_file = open_file(seg_path, "a")
+        seg = seg_file.create_dataset(
+            "segmentation", shape=seeds.shape, chunks=block_shape, compression="gzip", dtype="uint64"
+        )
 
     # Compute the segmentation with a seeded watershed
     halo = (2, 8, 8)
@@ -341,6 +387,11 @@ def write_block(block_id):
     with futures.ThreadPoolExecutor(n_threads) as tp:
         tp.map(write_block, range(blocking.numberOfBlocks))
 
+    if output_folder is None:
+        return seg
+    else:
+        return None
+
 
 def calc_mean_and_std(input_path: str, input_key: str, output_folder: str) -> None:
     """Calculate mean and standard deviation of the input volume.
@@ -372,7 +423,7 @@ def calc_mean_and_std(input_path: str, input_key: str, output_folder: str) -> None:
 def run_unet_prediction(
     input_path: str,
     input_key: Optional[str],
-    output_folder: str,
+    output_folder: Optional[str],
     model_path: str,
     min_size: int,
     scale: Optional[float] = None,
@@ -403,22 +454,35 @@ def run_unet_prediction(
         fg_threshold: The threshold applied to the foreground prediction for deriving the watershed mask.
-        seg_class: Specifier for exclusion criterias for mask generation.
+        seg_class: Specifier for the exclusion criteria used for mask generation.
+
+    Returns:
+        The segmentation as a numpy array if no output folder is given, otherwise None.
     """
-    os.makedirs(output_folder, exist_ok=True)
+    if output_folder is not None:
+        os.makedirs(output_folder, exist_ok=True)
 
     if use_mask:
-        find_mask(input_path, input_key, output_folder, seg_class=seg_class)
+        mask = find_mask(input_path, input_key, output_folder=output_folder, seg_class=seg_class)
+    else:
+        mask = None
 
-    original_shape = prediction_impl(
-        input_path, input_key, output_folder, model_path, scale, block_shape, halo
+    original_shape, prediction = prediction_impl(
+        input_path=input_path, input_key=input_key, output_folder=output_folder, model_path=model_path,
+        scale=scale, block_shape=block_shape, halo=halo, mask=mask,
     )
 
-    pmap_out = os.path.join(output_folder, "predictions.zarr")
-    distance_watershed_implementation(
+    if output_folder is None:
+        pmap_out = prediction
+    else:
+        pmap_out = os.path.join(output_folder, "predictions.zarr")
+
+    segmentation = distance_watershed_implementation(
         pmap_out, output_folder, min_size=min_size, original_shape=original_shape,
         center_distance_threshold=center_distance_threshold,
         boundary_distance_threshold=boundary_distance_threshold,
         fg_threshold=fg_threshold,
     )
+    return segmentation
+
 
 #
 # ---Workflow for parallel prediction using slurm---
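Putting the pieces together, a hypothetical end-to-end call of the new in-memory workflow; paths and model file are placeholders:

```python
from flamingo_tools.segmentation.unet_prediction import run_unet_prediction

# With output_folder=None the full pipeline (masking, prediction, watershed)
# runs in memory and returns the segmentation as a numpy array; with a folder
# it writes the zarr outputs instead and returns None.
segmentation = run_unet_prediction(
    "data/volume.n5", "raw", output_folder=None, model_path="model.pt",
    min_size=1000, block_shape=(64, 256, 256), halo=(16, 64, 64),
)
```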
""" import argparse +import json +import time +import os +import imageio.v3 as imageio import torch import z5py @@ -19,7 +23,20 @@ def main(): parser.add_argument("-m", "--model", required=True) parser.add_argument("-k", "--input_key", default=None) parser.add_argument("-s", "--scale", default=None, type=float, help="Downscale the image by the given factor.") - parser.add_argument("-b", "--block_shape", default=None, type=int, nargs=3) + parser.add_argument("-b", "--block_shape", default=None, type=str) + parser.add_argument("--halo", default=None, type=str) + parser.add_argument("--memory", action="store_true", help="Perform prediction in memory and save output as tif.") + parser.add_argument("--time", action="store_true", help="Time prediction process.") + parser.add_argument("--seg_class", default=None, type=str, + help="Segmentation class to load parameters for masking input.") + parser.add_argument("--center_distance_threshold", default=0.4, type=float, + help="The threshold applied to the distance center predictions to derive seeds.") + parser.add_argument("--boundary_distance_threshold", default=None, type=float, + help="The threshold applied to the boundary predictions to derive seeds. \ + By default this is set to 'None', \ + in which case the boundary distances are not used for the seeds.") + parser.add_argument("--fg_threshold", default=0.5, type=float, + help="The threshold applied to the foreground prediction for deriving the watershed mask.") args = parser.parse_args() @@ -36,21 +53,57 @@ def main(): if args.block_shape is None: block_shape = (64, 256, 256) if have_cuda else (64, 64, 64) else: - block_shape = tuple(args.block_shape) - halo = (16, 64, 64) if have_cuda else (8, 32, 32) + block_shape = tuple(json.loads(args.block_shape)) + else: if args.block_shape is None: chunks = z5py.File(args.input, "r")[args.input_key].chunks block_shape = tuple([2 * ch for ch in chunks]) if have_cuda else tuple(chunks) else: - block_shape = tuple(args.block_shape) + block_shape = json.loads(args.block_shape) + + if args.halo is None: halo = (16, 64, 64) if have_cuda else (8, 32, 32) + else: + halo = tuple(json.loads(args.halo)) + + if args.time: + start = time.perf_counter() + + if args.memory: + segmentation = run_unet_prediction( + args.input, args.input_key, output_folder=None, model_path=args.model, + scale=scale, min_size=min_size, + block_shape=block_shape, halo=halo, + seg_class=args.seg_class, + center_distance_threshold = args.center_distance_threshold, + boundary_distance_threshold = args.boundary_distance_threshold, + fg_threshold = args.fg_threshold, + ) + + abs_path = os.path.abspath(args.input) + basename = ".".join(os.path.basename(abs_path).split(".")[:-1]) + output_path = os.path.join(args.output_folder, basename + "_seg.tif") + imageio.imwrite(output_path, segmentation, compression="zlib") + timer_output = os.path.join(args.output_folder, basename + "_timer.json") + + else: + run_unet_prediction( + args.input, args.input_key, output_folder=args.output_folder, model_path=args.model, + scale=scale, min_size=min_size, + block_shape=block_shape, halo=halo, + seg_class=args.seg_class, + center_distance_threshold = args.center_distance_threshold, + boundary_distance_threshold = args.boundary_distance_threshold, + fg_threshold = args.fg_threshold, + ) + timer_output = os.path.join(args.output_folder, "timer.json") - run_unet_prediction( - args.input, args.input_key, args.output_folder, args.model, - scale=scale, min_size=min_size, - block_shape=block_shape, halo=halo, - ) + if 
@@ -36,21 +54,55 @@ def main():
         if args.block_shape is None:
             block_shape = (64, 256, 256) if have_cuda else (64, 64, 64)
         else:
-            block_shape = tuple(args.block_shape)
-        halo = (16, 64, 64) if have_cuda else (8, 32, 32)
+            block_shape = tuple(json.loads(args.block_shape))
     else:
         if args.block_shape is None:
             chunks = z5py.File(args.input, "r")[args.input_key].chunks
             block_shape = tuple([2 * ch for ch in chunks]) if have_cuda else tuple(chunks)
         else:
-            block_shape = tuple(args.block_shape)
+            block_shape = tuple(json.loads(args.block_shape))
+
+    if args.halo is None:
         halo = (16, 64, 64) if have_cuda else (8, 32, 32)
+    else:
+        halo = tuple(json.loads(args.halo))
+
+    if args.time:
+        start = time.perf_counter()
+
+    if args.memory:
+        segmentation = run_unet_prediction(
+            args.input, args.input_key, output_folder=None, model_path=args.model,
+            scale=scale, min_size=min_size,
+            block_shape=block_shape, halo=halo,
+            seg_class=args.seg_class,
+            center_distance_threshold=args.center_distance_threshold,
+            boundary_distance_threshold=args.boundary_distance_threshold,
+            fg_threshold=args.fg_threshold,
+        )
+
+        abs_path = os.path.abspath(args.input)
+        basename = ".".join(os.path.basename(abs_path).split(".")[:-1])
+        output_path = os.path.join(args.output_folder, basename + "_seg.tif")
+        imageio.imwrite(output_path, segmentation, compression="zlib")
+        timer_output = os.path.join(args.output_folder, basename + "_timer.json")
+    else:
+        run_unet_prediction(
+            args.input, args.input_key, output_folder=args.output_folder, model_path=args.model,
+            scale=scale, min_size=min_size,
+            block_shape=block_shape, halo=halo,
+            seg_class=args.seg_class,
+            center_distance_threshold=args.center_distance_threshold,
+            boundary_distance_threshold=args.boundary_distance_threshold,
+            fg_threshold=args.fg_threshold,
+        )
+        timer_output = os.path.join(args.output_folder, "timer.json")
 
-    run_unet_prediction(
-        args.input, args.input_key, args.output_folder, args.model,
-        scale=scale, min_size=min_size,
-        block_shape=block_shape, halo=halo,
-    )
+    if args.time:
+        duration = time.perf_counter() - start
+        time_dict = {"total_duration[s]": duration}
+        with open(timer_output, "w") as f:
+            json.dump(time_dict, f, indent="\t", separators=(",", ": "))
 
 
 if __name__ == "__main__":
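For reference, `--time` writes a small JSON report ("timer.json", or "<basename>_timer.json" in `--memory` mode). Assuming a finished run, it can be read back like this (placeholder path):

```python
import json

# The report currently holds a single key with the total runtime in seconds.
with open("results/timer.json") as f:
    timing = json.load(f)
print(timing["total_duration[s]"])
```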