68 changes: 57 additions & 11 deletions flamingo_tools/segmentation/unet_prediction.py
@@ -37,13 +37,13 @@ def ndim(self):
return self._volume.ndim - 1


def prediction_impl(input_path, input_key, output_folder, model_path, scale, block_shape, halo):
def prediction_impl(input_path, input_key, output_folder, model_path, scale, block_shape, halo, prediction_instances=1, slurm_task_id=0):
with warnings.catch_warnings():
warnings.simplefilter("ignore")
if os.path.isdir(model_path):
model = load_model(model_path)
else:
model = torch.load(model_path)
model = torch.load(model_path, weights_only=False)

mask_path = os.path.join(output_folder, "mask.zarr")
image_mask = z5py.File(mask_path, "r")["mask"]
@@ -65,22 +65,24 @@ def prediction_impl(input_path, input_key, output_folder, model_path, scale, blo
input_ = ResizedVolume(input_, shape=new_shape, order=3)
image_mask = ResizedVolume(image_mask, new_shape, order=0)

chunks = (128, 128, 128)
block_shape = chunks

have_cuda = torch.cuda.is_available()
if block_shape is None:
block_shape = tuple([2 * ch for ch in input_.chunks]) if have_cuda else input_.chunks
if halo is None:
halo = (16, 64, 64) if have_cuda else (16, 32, 32)
assert have_cuda
if have_cuda:
print("Predict with GPU")
gpu_ids = [0]
else:
print("Predict with CPU")
gpu_ids = ["cpu"]
if halo is None:
halo = (16, 32, 32)

# Compute the global mean and standard deviation.
n_threads = min(16, mp.cpu_count())
n_threads = min(2, mp.cpu_count())
mean, std = parallel.mean_and_std(
input_, block_shape=block_shape, n_threads=n_threads, verbose=True,
input_, block_shape=tuple([2* i for i in input_.chunks]), n_threads=n_threads, verbose=True,
mask=image_mask
)
print("Mean and standard deviation computed for the full volume:")
@@ -98,12 +100,24 @@ def postprocess(x):
x[1] = vigra.filters.gaussianSmoothing(x[1], sigma=2.0)
return x

shape = input_.shape
ndim = len(shape)

blocking = nt.blocking([0] * ndim, shape, block_shape)
n_blocks = blocking.numberOfBlocks
iteration_ids = []
if 1 != prediction_instances:
iteration_ids = [x.tolist() for x in np.array_split(list(range(n_blocks)), prediction_instances)]
slurm_iteration = iteration_ids[slurm_task_id]
else:
slurm_iteration = list(range(n_blocks))

output_path = os.path.join(output_folder, "predictions.zarr")
with open_file(output_path, "a") as f:
output = f.require_dataset(
"prediction",
shape=(3,) + input_.shape,
chunks=(1,) + block_shape,
chunks=(1,) + chunks,
compression="gzip",
dtype="float32",
)
@@ -113,6 +127,7 @@ def postprocess(x):
gpu_ids=gpu_ids, block_shape=block_shape, halo=halo,
output=output, preprocess=preprocess, postprocess=postprocess,
mask=image_mask,
iter_list=slurm_iteration,
)

return original_shape
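
The block partitioning above hands every SLURM array task a contiguous share of the block indices. A minimal standalone sketch of the split (hypothetical numbers: 10 blocks across 3 prediction instances; the variable names mirror the diff):

import numpy as np

n_blocks = 10
prediction_instances = 3
slurm_task_id = 1  # what SLURM_ARRAY_TASK_ID would be for the second array task

# np.array_split tolerates an uneven division, unlike np.split.
iteration_ids = [x.tolist() for x in np.array_split(list(range(n_blocks)), prediction_instances)]
# iteration_ids == [[0, 1, 2, 3], [4, 5, 6], [7, 8, 9]]
slurm_iteration = iteration_ids[slurm_task_id]  # [4, 5, 6]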
@@ -228,14 +243,45 @@ def run_unet_prediction(
output_folder, model_path,
min_size, scale=None,
block_shape=None, halo=None,
prediction_instances=1,
):
if prediction_instances > 1:
run_unet_prediction_slurm(
input_path, input_key, output_folder, model_path,
scale=scale, block_shape=block_shape, halo=halo,
prediction_instances=prediction_instances,
)
else:
os.makedirs(output_folder, exist_ok=True)

find_mask(input_path, input_key, output_folder)

original_shape = prediction_impl(
input_path, input_key, output_folder, model_path, scale, block_shape, halo
)

pmap_out = os.path.join(output_folder, "predictions.zarr")
segmentation_impl(pmap_out, output_folder, min_size=min_size, original_shape=original_shape)

def run_unet_prediction_slurm(
input_path, input_key, output_folder, model_path,
scale=None,
block_shape=None, halo=None, prediction_instances=1,
):
os.makedirs(output_folder, exist_ok=True)
prediction_instances = int(prediction_instances)
slurm_task_id = os.environ.get("SLURM_ARRAY_TASK_ID")
if slurm_task_id is not None:
slurm_task_id = int(slurm_task_id)

find_mask(input_path, input_key, output_folder)

original_shape = prediction_impl(
input_path, input_key, output_folder, model_path, scale, block_shape, halo
input_path, input_key, output_folder, model_path, scale, block_shape, halo, prediction_instances, slurm_task_id
)

# does NOT need GPU, FIXME: only run on CPU
def run_unet_segmentation_slurm(output_folder, min_size):
min_size = int(min_size)
pmap_out = os.path.join(output_folder, "predictions.zarr")
segmentation_impl(pmap_out, output_folder, min_size=min_size, original_shape=original_shape)
segmentation_impl(pmap_out, output_folder, min_size=min_size)
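
Taken together, the new functions suggest a two-stage workflow: one GPU prediction job per array task, then a single CPU-only segmentation job. A hypothetical driver (paths, key, and min_size are placeholders; it assumes the module is importable as flamingo_tools.segmentation.unet_prediction and that SLURM_ARRAY_TASK_ID is set, e.g. via sbatch --array=0-3 for four instances):

from flamingo_tools.segmentation.unet_prediction import (
    run_unet_prediction_slurm,
    run_unet_segmentation_slurm,
)

# Stage 1: executed by every array task; each task predicts its share of the blocks on a GPU.
run_unet_prediction_slurm(
    "/path/to/volume.n5", "setup0/timepoint0/s0",
    "/path/to/output_folder", "/path/to/model.pt",
    scale=None, block_shape=None, halo=None,
    prediction_instances=4,
)

# Stage 2: executed once after all array tasks have finished; does not need a GPU.
run_unet_segmentation_slurm("/path/to/output_folder", min_size=1000)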
33 changes: 33 additions & 0 deletions scripts/convert_tif_to_n5.py
@@ -0,0 +1,33 @@
import os, sys
import argparse
import pybdv
import imageio.v3 as imageio


def main(input_path, output_path):
    if not os.path.isfile(input_path):
        sys.exit("Input file does not exist.")

    if input_path.split(".")[-1] not in ["TIFF", "TIF", "tiff", "tif"]:
        sys.exit("Input file must be in tif format.")

    basename = "".join(os.path.basename(input_path).split(".")[:-1])
    input_dir = os.path.dirname(os.path.abspath(input_path))

    if output_path == "":
        output_path = os.path.join(input_dir, basename + ".n5")
    img = imageio.imread(input_path)
    pybdv.make_bdv(img, output_path)


if __name__ == "__main__":

    parser = argparse.ArgumentParser(
        description="Script to transform a file from tif into n5 format.")

    parser.add_argument('input', type=str, help="Input file")
    parser.add_argument('-o', "--output", type=str, default="", help="Output file")

    args = parser.parse_args()

    main(args.input, args.output)
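
A quick way to sanity-check the converted volume (a sketch; it assumes pybdv wrote the default setup and timepoint, which matches the setup0/timepoint0/s0 key used elsewhere in this PR, and the file name is a placeholder):

from elf.io import open_file

converted = open_file("stack.n5", "r")["setup0/timepoint0/s0"]
print(converted.shape, converted.dtype)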
4 changes: 4 additions & 0 deletions scripts/prediction/run_prediction_distance_unet.py
@@ -16,6 +16,7 @@ def main():
parser.add_argument("-m", "--model", required=True)
parser.add_argument("-k", "--input_key", default=None)
parser.add_argument("-s", "--scale", default=None, type=float, help="Downscale the image by the given factor.")
parser.add_argument("-n", "--number_gpu", default=1, type=int, help="Number of GPUs to use in parallel.")

args = parser.parse_args()

@@ -36,10 +37,13 @@ def main():
block_shape = tuple([2 * ch for ch in chunks]) if have_cuda else tuple(chunks)
halo = (16, 64, 64) if have_cuda else (8, 32, 32)

prediction_instances = args.number_gpu if have_cuda else 1

run_unet_prediction(
args.input, args.input_key, args.output_folder, args.model,
scale=scale, min_size=min_size,
block_shape=block_shape, halo=halo,
prediction_instances=prediction_instances,
)


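During blockwise prediction the halo is typically added on both sides of each block, so the effective network input per block is block_shape + 2 * halo along every axis. A small worked example with the GPU halo above (block shapes are hypothetical; note that prediction_impl currently fixes its own chunks of 128**3):

halo = (16, 64, 64)

# Script-level GPU default: blocks of twice the chunk size, e.g. for 128**3 chunks.
block_shape = (256, 256, 256)
print(tuple(bs + 2 * h for bs, h in zip(block_shape, halo)))  # (288, 384, 384)

# With the 128**3 blocks that prediction_impl currently enforces.
block_shape = (128, 128, 128)
print(tuple(bs + 2 * h for bs, h in zip(block_shape, halo)))  # (160, 256, 256)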
63 changes: 63 additions & 0 deletions scripts/resize_wrongly_scaled_cochleas.py
@@ -0,0 +1,63 @@
import argparse
import os

import multiprocessing as mp
from concurrent import futures

import nifty.tools as nt
from tqdm import tqdm

from elf.wrapper.resized_volume import ResizedVolume
from elf.io import open_file


def main(input_path, output_folder, scale, input_key, interpolation_order):
    input_ = open_file(input_path, "r")[input_key]

    abs_path = os.path.abspath(input_path)
    basename = "".join(os.path.basename(abs_path).split(".")[:-1])
    output_path = os.path.join(output_folder, basename + "_resized.n5")

    shape = input_.shape
    ndim = len(shape)

    # Limit the number of cores for parallelization.
    n_threads = min(16, mp.cpu_count())

    new_shape = tuple(int(round(sh / scale)) for sh in shape)

    resized_volume = ResizedVolume(input_, new_shape, order=interpolation_order)

    output = open_file(output_path, mode="a")
    dataset = output.create_dataset(
        input_key, shape=new_shape, dtype=input_.dtype, chunks=input_.chunks, compression="gzip"
    )
    blocking = nt.blocking([0] * ndim, new_shape, input_.chunks)

    def copy_chunk(block_index):
        block = blocking.getBlock(block_index)
        volume_index = tuple(slice(begin, end) for (begin, end) in zip(block.begin, block.end))
        dataset[volume_index] = resized_volume[volume_index]

    with futures.ThreadPoolExecutor(n_threads) as resize_pool:
        list(tqdm(resize_pool.map(copy_chunk, range(blocking.numberOfBlocks)), total=blocking.numberOfBlocks))


if __name__ == "__main__":

    parser = argparse.ArgumentParser(
        description="Script for resizing microscopy data in n5 format.")

    parser.add_argument('input_file', type=str, help="Input file")
    parser.add_argument('output_folder', type=str, help="Output folder. Default resized output is <basename>_resized.n5")

    parser.add_argument('-s', "--scale", type=float, default=0.38, help="Scale of input. Re-scaled to 1.")
    parser.add_argument('-k', "--input_key", type=str, default="setup0/timepoint0/s0", help="Input key for n5 file.")
    parser.add_argument('-i', "--interpolation_order", type=int, default=3, help="Interpolation order.")

    args = parser.parse_args()

    main(args.input_file, args.output_folder, args.scale, args.input_key, args.interpolation_order)
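
The --scale flag appears to give the voxel scale of the wrongly scaled input relative to a target scale of 1, so the volume is resampled by a factor of 1 / scale. A small worked example of the shape computation (hypothetical input shape, default scale of 0.38):

scale = 0.38
shape = (100, 200, 300)
new_shape = tuple(int(round(sh / scale)) for sh in shape)
print(new_shape)  # (263, 526, 789)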