diff --git a/flamingo_tools/segmentation/unet_prediction.py b/flamingo_tools/segmentation/unet_prediction.py index fde52a0..ca56d09 100644 --- a/flamingo_tools/segmentation/unet_prediction.py +++ b/flamingo_tools/segmentation/unet_prediction.py @@ -10,6 +10,7 @@ import vigra import torch import z5py +import json from elf.wrapper import ThresholdWrapper, SimpleTransformationWrapper from elf.wrapper.resized_volume import ResizedVolume @@ -18,6 +19,11 @@ from torch_em.util.prediction import predict_with_halo from tqdm import tqdm +""" +Prediction using distance U-Net. +Parallelization using multiple GPUs is currently only possible by calling functions directly. +Functions for the parallelization end with '_slurm' and divide the process into preprocessing, prediction, and segmentation. +""" class SelectChannel(SimpleTransformationWrapper): def __init__(self, volume, channel): @@ -37,13 +43,13 @@ def ndim(self): return self._volume.ndim - 1 -def prediction_impl(input_path, input_key, output_folder, model_path, scale, block_shape, halo): +def prediction_impl(input_path, input_key, output_folder, model_path, scale, block_shape, halo, prediction_instances=1, slurm_task_id=0, mean=None, std=None): with warnings.catch_warnings(): warnings.simplefilter("ignore") if os.path.isdir(model_path): model = load_model(model_path) else: - model = torch.load(model_path) + model = torch.load(model_path, weights_only=False) mask_path = os.path.join(output_folder, "mask.zarr") image_mask = z5py.File(mask_path, "r")["mask"] @@ -66,10 +72,11 @@ def prediction_impl(input_path, input_key, output_folder, model_path, scale, blo image_mask = ResizedVolume(image_mask, new_shape, order=0) have_cuda = torch.cuda.is_available() + if block_shape is None: - block_shape = tuple([2 * ch for ch in input_.chunks]) if have_cuda else input_.chunks + block_shape = (128, 128, 128) if have_cuda else input_.chunks if halo is None: - halo = (16, 64, 64) if have_cuda else (16, 32, 32) + halo = (16, 32, 32) if 
def calc_mean_and_std(input_path, input_key, output_folder):
    """Calculate the mean and standard deviation of the full volume.

    The statistics are computed with ``parallel.mean_and_std``, restricted by the
    mask stored under the key "mask" in ``<output_folder>/mask.zarr``. The result
    is written to ``<output_folder>/mean_std.json`` with the keys "mean" and "std".

    Args:
        input_path: Path to the input volume.
        input_key: Key of the dataset inside the input file. If None, the file
            is read fully into memory with ``imageio.imread``.
        output_folder: Folder that contains ``mask.zarr`` and receives
            ``mean_std.json``.
    """
    json_file = os.path.join(output_folder, "mean_std.json")
    mask_path = os.path.join(output_folder, "mask.zarr")
    image_mask = z5py.File(mask_path, "r")["mask"]

    if input_key is None:
        input_ = imageio.imread(input_path)
    else:
        input_ = open_file(input_path, "r")[input_key]

    # Arrays loaded via imageio are plain in-memory arrays without a `chunks`
    # attribute (and some datasets report `chunks=None`), so fall back to a
    # default chunk shape instead of failing with an AttributeError/TypeError.
    chunks = getattr(input_, "chunks", None)
    if chunks is None:
        chunks = (128,) * input_.ndim
    block_shape = tuple(2 * ch for ch in chunks)

    # Compute the global mean and standard deviation.
    n_threads = min(16, mp.cpu_count())
    mean, std = parallel.mean_and_std(
        input_, block_shape=block_shape, n_threads=n_threads, verbose=True,
        mask=image_mask
    )

    # Cast to built-in floats: NumPy scalar types such as float32 cannot be
    # serialized by the json module.
    stats = {"mean": float(mean), "std": float(std)}
    with open(json_file, "w") as f:
        json.dump(stats, f)
def main(input_path, output_path):
    """Convert a tif file to n5 format.

    If no output_path is supplied (empty string), the output file is created in
    the same directory as the input, using the input file name with the
    extension replaced by ``.n5``.

    :param str input_path: Input tif
    :param str output_path: Output path for n5 format
    """
    if not os.path.isfile(input_path):
        sys.exit("Input file does not exist.")

    # Compare the extension case-insensitively so that e.g. ".Tif" is accepted too.
    stem, extension = os.path.splitext(os.path.basename(input_path))
    if extension.lower() not in (".tif", ".tiff"):
        sys.exit("Input file must be in tif format.")

    if output_path == "":
        # Derive the directory with os.path instead of splitting the path string
        # on the file stem: splitting breaks when the stem also occurs in a
        # parent folder name or when the file name contains additional dots.
        input_dir = os.path.dirname(os.path.abspath(input_path))
        output_path = os.path.join(input_dir, stem + ".n5")

    img = imageio.imread(input_path)
    pybdv.make_bdv(img, output_path)
def main(input_file, output_dir, input_key, resolution, coords, roi_halo, s3):
    """Extract a block around a center coordinate and save it as an n5 file.

    The center coordinate is divided by ``resolution`` and rounded to obtain the
    voxel position; the halo is applied to that position without rescaling.
    The crop is written to ``<output_dir>/<basename>_crop_<x>-<y>-<z>.n5`` under
    the key "raw".

    :param str input_file: Path to the input n5 file (or the path inside the bucket when using S3)
    :param str output_dir: Output directory for the cropped n5 file. If empty, the directory
        of the input file is used (the current working directory for S3 inputs).
    :param str input_key: Key for accessing volume in n5 format, e.g. 'setup0/s0'
    :param float resolution: Resolution of input data in micrometer
    :param str coords: Center coordinates of extracted 3D volume in format 'x,y,z'
    :param str roi_halo: ROI halo of extracted 3D volume in format 'x,y,z'
    :param bool s3: Flag for using an S3 bucket
    :raises ValueError: If coords and roi_halo have a different number of entries
    :raises FileNotFoundError: If the requested path does not exist in the S3 bucket
    """
    center = [int(c) for c in coords.split(",")]
    halo = [int(r) for r in roi_halo.split(",")]
    if len(center) != len(halo):
        # zip() below would otherwise silently drop the surplus dimensions.
        raise ValueError(
            f"coords ({len(center)} values) and roi_halo ({len(halo)} values) must have the same length."
        )

    # The file name keeps the coordinate in the order given by the user.
    coord_string = "-".join(str(c) for c in center)

    # Dimensions are inversed to view in MoBIE (x y z) -> (z y x)
    center.reverse()
    halo.reverse()

    input_content = [part for part in input_file.split("/") if part]

    if s3:
        basename = input_content[0] + "_" + input_content[-1].split(".")[0]
        # There is no local input directory for a bucket path.
        input_dir = os.getcwd()
    else:
        # splitext keeps interior dots and also handles names without extension.
        basename = os.path.splitext(input_content[-1])[0]
        input_dir = os.path.dirname(os.path.abspath(input_file.rstrip("/")))

    if output_dir == "":
        output_dir = input_dir

    output_file = os.path.join(output_dir, basename + "_crop_" + coord_string + ".n5")

    voxel_center = np.round(np.array(center) / resolution).astype(np.int32)

    # Clamp the bounds at 0: a negative slice bound would be interpreted
    # relative to the end of the axis and select the wrong region.
    roi = tuple(
        slice(max(0, int(co) - rh), max(0, int(co) + rh))
        for co, rh in zip(voxel_center, halo)
    )

    if s3:
        # Define S3 bucket and OME-Zarr dataset path
        bucket_name = "cochlea-lightsheet"
        zarr_path = f"{bucket_name}/{input_file}"

        # Create an S3 filesystem
        fs = s3fs.S3FileSystem(
            client_kwargs={"endpoint_url": "https://s3.fs.gwdg.de"},
            anon=False
        )

        if not fs.exists(zarr_path):
            # Fail early with a clear message instead of continuing with a missing path.
            raise FileNotFoundError(f"Path does not exist in the S3 bucket: {zarr_path}")

        # Open the OME-Zarr dataset
        store = zarr.storage.FSStore(zarr_path, fs=fs)
        print(f"Opening file {zarr_path} from the S3 bucket.")

        with zarr.open(store, mode="r") as f:
            raw = f[input_key][roi]

    else:
        with z5py.File(input_file, "r") as f:
            raw = f[input_key][roi]

    with z5py.File(output_file, "w") as f_out:
        f_out.create_dataset("raw", data=raw, compression="gzip")
size for counting object") + args = parser.parse_args() + + seg_path = os.path.join(args.output_folder, "segmentation.zarr") + seg_key = "segmentation" + + file = open_file(seg_path, mode='r') + dataset = file[seg_key] + + ids, counts = unique(dataset, return_counts=True) + + # You can change the minimal size for objects to be counted here: + min_size = args.min_size + + counts = counts[counts > min_size] + print("Number of objects:", len(counts)) + +if __name__ == "__main__": + main() diff --git a/scripts/prediction/run_prediction_distance_unet.py b/scripts/prediction/run_prediction_distance_unet.py index 93bb32a..fe30ae2 100644 --- a/scripts/prediction/run_prediction_distance_unet.py +++ b/scripts/prediction/run_prediction_distance_unet.py @@ -6,6 +6,11 @@ sys.path.append("../..") +""" +Prediction using distance U-Net. +Parallelization using multiple GPUs is currently only possible by calling functions located in segmentation/unet_prediction.py directly. +Functions for the parallelization end with '_slurm' and divide the process into preprocessing, prediction, and segmentation. 
def main(input_path, output_folder, scale, input_key, interpolation_order):
    """Resize a volume by a scale factor and save the result in n5 format.

    Each axis of the input is resized to ``round(size / scale)``. The result is
    written chunk by chunk to ``<output_folder>/<basename>_resized.n5`` under
    ``input_key``.

    Args:
        input_path: Path to the input volume. Files ending in .tif/.tiff are read
            with imageio, everything else is opened with ``open_file``.
        output_folder: Folder in which the resized n5 file is created.
        scale: Scale of the input data; must be positive.
        input_key: Key of the dataset in the input file; also used as the key of
            the output dataset.
        interpolation_order: Interpolation order passed to ``ResizedVolume``.

    Raises:
        ValueError: If ``scale`` is not positive.
    """
    if scale <= 0:
        raise ValueError(f"scale must be positive, got {scale}.")

    # The command line parser delivers the order as a float, but it is an integer quantity.
    interpolation_order = int(interpolation_order)

    if input_path.lower().endswith((".tif", ".tiff")):
        input_ = imageio.imread(input_path)
        input_chunks = None
    else:
        input_ = open_file(input_path, "r")[input_key]
        input_chunks = input_.chunks

    shape = input_.shape
    ndim = len(shape)

    # In-memory tif data (and unchunked datasets) have no chunk shape; use a default.
    if input_chunks is None:
        input_chunks = (128,) * ndim

    # splitext keeps interior dots of the file name intact.
    basename = os.path.splitext(os.path.basename(os.path.abspath(input_path)))[0]
    os.makedirs(output_folder, exist_ok=True)
    output_path = os.path.join(output_folder, basename + "_resized.n5")

    # Limit the number of cores for parallelization.
    n_threads = min(16, mp.cpu_count())

    new_shape = tuple(int(round(sh / scale)) for sh in shape)
    resized_volume = ResizedVolume(input_, new_shape, order=interpolation_order)

    # Use a context manager so the output container is closed once all chunks are written.
    with open_file(output_path, mode="a") as output:
        output_dataset = output.create_dataset(
            input_key, shape=new_shape, dtype=input_.dtype,
            chunks=input_chunks, compression="gzip"
        )
        blocking = nt.blocking([0] * ndim, new_shape, input_chunks)
        n_blocks = blocking.numberOfBlocks

        def copy_chunk(block_index):
            # Resize and write one output block.
            block = blocking.getBlock(block_index)
            volume_index = tuple(slice(begin, end) for begin, end in zip(block.begin, block.end))
            output_dataset[volume_index] = resized_volume[volume_index]

        with futures.ThreadPoolExecutor(n_threads) as resize_pool:
            # list() drains the iterator so that worker exceptions are raised here.
            list(tqdm(resize_pool.map(copy_chunk, range(n_blocks)), total=n_blocks))