2 changes: 1 addition & 1 deletion flamingo_tools/segmentation/__init__.py
@@ -1,2 +1,2 @@
from .unet_prediction import run_unet_prediction
from .unet_prediction import run_unet_prediction, run_unet_prediction_slurm
from .postprocessing import filter_isolated_objects
118 changes: 104 additions & 14 deletions flamingo_tools/segmentation/unet_prediction.py
@@ -10,6 +10,7 @@
import vigra
import torch
import z5py
import json

from elf.wrapper import ThresholdWrapper, SimpleTransformationWrapper
from elf.wrapper.resized_volume import ResizedVolume
@@ -37,13 +38,13 @@ def ndim(self):
return self._volume.ndim - 1


def prediction_impl(input_path, input_key, output_folder, model_path, scale, block_shape, halo):
def prediction_impl(input_path, input_key, output_folder, model_path, scale, block_shape, halo, prediction_instances=1, slurm_task_id=0, mean=None, std=None):
with warnings.catch_warnings():
warnings.simplefilter("ignore")
if os.path.isdir(model_path):
model = load_model(model_path)
else:
model = torch.load(model_path)
model = torch.load(model_path, weights_only=False)

mask_path = os.path.join(output_folder, "mask.zarr")
image_mask = z5py.File(mask_path, "r")["mask"]
@@ -65,24 +66,27 @@ def prediction_impl(input_path, input_key, output_folder, model_path, scale, blo
input_ = ResizedVolume(input_, shape=new_shape, order=3)
image_mask = ResizedVolume(image_mask, new_shape, order=0)

chunks = (128, 128, 128)
block_shape = chunks

have_cuda = torch.cuda.is_available()
if block_shape is None:
block_shape = tuple([2 * ch for ch in input_.chunks]) if have_cuda else input_.chunks
if halo is None:
halo = (16, 64, 64) if have_cuda else (16, 32, 32)
assert have_cuda
if have_cuda:
print("Predict with GPU")
gpu_ids = [0]
else:
print("Predict with CPU")
gpu_ids = ["cpu"]

# Compute the global mean and standard deviation.
n_threads = min(16, mp.cpu_count())
mean, std = parallel.mean_and_std(
input_, block_shape=block_shape, n_threads=n_threads, verbose=True,
mask=image_mask
)
if halo is None:
halo = (16, 32, 32)

if mean is None or std is None:
# Compute the global mean and standard deviation.
n_threads = min(16, mp.cpu_count())
mean, std = parallel.mean_and_std(
input_, block_shape=tuple([2 * i for i in input_.chunks]), n_threads=n_threads, verbose=True,
mask=image_mask
)
print("Mean and standard deviation computed for the full volume:")
print(mean, std)

@@ -98,12 +102,24 @@ def postprocess(x):
x[1] = vigra.filters.gaussianSmoothing(x[1], sigma=2.0)
return x

shape = input_.shape
ndim = len(shape)

blocking = nt.blocking([0] * ndim, shape, block_shape)
n_blocks = blocking.numberOfBlocks
iteration_ids = []
if prediction_instances != 1:
iteration_ids = [x.tolist() for x in np.array_split(list(range(n_blocks)), prediction_instances)]
slurm_iteration = iteration_ids[slurm_task_id]
else:
slurm_iteration = list(range(n_blocks))

output_path = os.path.join(output_folder, "predictions.zarr")
with open_file(output_path, "a") as f:
output = f.require_dataset(
"prediction",
shape=(3,) + input_.shape,
chunks=(1,) + block_shape,
chunks=(1,) + chunks,
compression="gzip",
dtype="float32",
)
@@ -113,6 +129,7 @@ def postprocess(x):
gpu_ids=gpu_ids, block_shape=block_shape, halo=halo,
output=output, preprocess=preprocess, postprocess=postprocess,
mask=image_mask,
iter_list=slurm_iteration,
)

return original_shape
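A minimal sketch of how the `np.array_split` call above distributes blocks across array tasks; the block count and instance count are illustrative:

```python
# Illustrative only: 10 blocks split across 3 prediction instances,
# mirroring the np.array_split logic in prediction_impl.
import numpy as np

n_blocks = 10
prediction_instances = 3
iteration_ids = [x.tolist() for x in np.array_split(list(range(n_blocks)), prediction_instances)]
print(iteration_ids)     # [[0, 1, 2, 3], [4, 5, 6], [7, 8, 9]]
print(iteration_ids[1])  # blocks handled by SLURM array task 1
```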
@@ -223,6 +240,30 @@ def write_block(block_id):
tp.map(write_block, range(blocking.numberOfBlocks))


def calc_mean_and_std(input_path, input_key, output_folder):
"""
Calculate the mean and standard deviation of the full volume.
The values are saved in 'mean_std.json' within the output folder.
"""
json_file = os.path.join(output_folder, "mean_std.json")
mask_path = os.path.join(output_folder, "mask.zarr")
image_mask = z5py.File(mask_path, "r")["mask"]

if input_key is None:
input_ = imageio.imread(input_path)
else:
input_ = open_file(input_path, "r")[input_key]

# Compute the global mean and standard deviation.
n_threads = min(16, mp.cpu_count())
mean, std = parallel.mean_and_std(
input_, block_shape=tuple([2 * i for i in input_.chunks]), n_threads=n_threads, verbose=True,
mask=image_mask
)
ddict = {"mean":str(mean), "std": str(std)}
with open(json_file, "w") as f:
json.dump(ddict, f)
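For reference, a minimal roundtrip of the 'mean_std.json' format written above; the numeric values are made up, and note that they are serialized as strings and parsed back with `float()` in `run_unet_prediction_slurm`:

```python
# Sketch of the mean_std.json roundtrip; the numbers are illustrative.
import json

with open("mean_std.json", "w") as f:
    json.dump({"mean": str(412.3), "std": str(97.8)}, f)

with open("mean_std.json") as f:
    d = json.load(f)
mean, std = float(d["mean"]), float(d["std"])
print(mean, std)  # 412.3 97.8
```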

def run_unet_prediction(
input_path, input_key,
output_folder, model_path,
@@ -239,3 +280,52 @@ def run_unet_prediction(

pmap_out = os.path.join(output_folder, "predictions.zarr")
segmentation_impl(pmap_out, output_folder, min_size=min_size, original_shape=original_shape)

def run_unet_prediction_slurm_preprocess(
input_path, input_key, output_folder,
):
"""
Pre-processing for the parallel prediction with U-Net models.
The mask is stored in 'mask.zarr' within the output folder.
The mean and standard deviation are precomputed for later use during prediction
and stored in 'mean_std.json' within the output folder.
"""
find_mask(input_path, input_key, output_folder)
calc_mean_and_std(input_path, input_key, output_folder)

def run_unet_prediction_slurm(
input_path, input_key, output_folder, model_path,
scale=None,
block_shape=None, halo=None, prediction_instances=1,
):
os.makedirs(output_folder, exist_ok=True)
prediction_instances = int(prediction_instances)
slurm_task_id = os.environ.get("SLURM_ARRAY_TASK_ID")

if slurm_task_id is not None:
slurm_task_id = int(slurm_task_id)
else:
raise ValueError("The SLURM_ARRAY_TASK_ID is not set. Ensure that you are using the '-a' option with SBATCH.")

if not os.path.isdir(os.path.join(output_folder, "mask.zarr")):
find_mask(input_path, input_key, output_folder)

# Get the pre-computed mean and standard deviation of the full volume from the JSON file.
if os.path.isfile(os.path.join(output_folder, "mean_std.json")):
with open(os.path.join(output_folder, "mean_std.json")) as f:
d = json.load(f)
mean = float(d["mean"])
std = float(d["std"])
else:
mean = None
std = None

original_shape = prediction_impl(
input_path, input_key, output_folder, model_path, scale, block_shape, halo, prediction_instances, slurm_task_id, mean=mean, std=std
)

# Does NOT need a GPU. FIXME: run this on CPU only.
def run_unet_segmentation_slurm(output_folder, min_size):
min_size = int(min_size)
pmap_out = os.path.join(output_folder, "predictions.zarr")
segmentation_impl(pmap_out, output_folder, min_size=min_size)
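A hedged sketch of how the SLURM entry point can be smoke-tested locally; the paths and model file are hypothetical, and in a real array job `SLURM_ARRAY_TASK_ID` is set by `sbatch -a` rather than by hand:

```python
# Hypothetical local test of the SLURM entry point; emulates array task 0 of 4.
import os
from flamingo_tools.segmentation import run_unet_prediction_slurm

os.environ["SLURM_ARRAY_TASK_ID"] = "0"  # set by sbatch -a in a real array job
run_unet_prediction_slurm(
    "data/volume.n5", "setup0/timepoint0/s0",  # hypothetical input path and key
    "output/", "models/distance_unet.pt",      # hypothetical output folder and model
    prediction_instances=4,
)
```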
33 changes: 33 additions & 0 deletions scripts/convert_tif_to_n5.py
@@ -0,0 +1,33 @@
import os
import sys
import argparse
import pybdv
import imageio.v3 as imageio


def main(input_path, output_path):
if not os.path.isfile(input_path):
sys.exit("Input file does not exist.")

if input_path.split(".")[-1].lower() not in ("tif", "tiff"):
sys.exit("Input file must be in tif format.")

basename = "".join(input_path.split("/")[-1].split(".")[:-1])
input_dir = input_path.split(basename)[0]
input_dir = os.path.abspath(input_dir)

if "" == output_path:
output_path = os.path.join(input_dir, basename + ".n5")
img = imageio.imread(input_path)
pybdv.make_bdv(img, output_path)

if __name__ == "__main__":

parser = argparse.ArgumentParser(
description="Script to transform file from tif into n5 format.")

parser.add_argument('input', type=str, help="Input file")
parser.add_argument('-o', "--output", type=str, default="", help="Output file")

args = parser.parse_args()

main(args.input, args.output)
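The conversion core reduces to two calls; a minimal sketch with hypothetical file names, assuming pybdv is installed:

```python
# Minimal sketch of the conversion; 'stack.tif' is a hypothetical input.
import imageio.v3 as imageio
import pybdv

img = imageio.imread("stack.tif")
pybdv.make_bdv(img, "stack.n5")  # writes a BigDataViewer-style n5 dataset
```

From the command line this corresponds to `python scripts/convert_tif_to_n5.py stack.tif -o stack.n5`.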
68 changes: 68 additions & 0 deletions scripts/extract_block.py
@@ -0,0 +1,68 @@
"""
This script extracts data around an input center coordinate in a given ROI halo.
"""
import os
import argparse

import numpy as np
import z5py


def main(input_file, output_dir, input_key, resolution, coords, roi_halo):
"""
:param str input_file: File path to input folder in n5 format
:param str output_dir: output directory for saving cropped n5 file as <basename>_crop.n5
:param str input_key: Key for accessing volume in n5 format, e.g. 'setup0/s0'
:param float resolution: Resolution of input data in micrometer
:param str coords: Center coordinates of extracted 3D volume in format 'z,y,x'
:param str roi_halo: ROI halo of extracted 3D volume in format 'z,y,x'
"""

coords = [int(r) for r in coords.split(",")]
roi_halo = [int(r) for r in roi_halo.split(",")]

basename = os.path.splitext(os.path.basename(os.path.normpath(input_file)))[0]
input_dir = os.path.dirname(os.path.abspath(os.path.normpath(input_file)))

if "" == output_dir:
output_dir = input_dir

input_key = "setup0/timepoint0/s0"

output_file = os.path.join(output_dir, basename + "_crop" + ".n5")

# Example: M_LR_000167_R, coords = '806,1042,1334'; coords are (z, y, x) relative to the MoBIE view

coords = np.array(coords)
coords = coords / resolution
coords = np.round(coords).astype(np.int32)

roi = tuple(slice(co - rh, co + rh) for co, rh in zip(coords, roi_halo))

with z5py.File(input_file, "r") as f:
raw = f[input_key][roi]

with z5py.File(output_file, "w") as f_out:
f_out.create_dataset("raw", data=raw, compression="gzip")

if __name__ == "__main__":

parser = argparse.ArgumentParser(
description="Script to extract region of interest (ROI) block around center coordinate.")

parser.add_argument('input', type=str, help="Input file in n5 format.")
parser.add_argument('-o', "--output", type=str, default="", help="Output directory")
parser.add_argument('-c', "--coord", type=str, required=True, help="3D center coordinate of the extracted block in format 'z,y,x'. Dimensions are inverted relative to the MoBIE view: (x y z) -> (z y x)")

parser.add_argument('-k', "--input_key", type=str, default="setup0/timepoint0/s0", help="Input key for data in input file")
parser.add_argument('-r', "--resolution", type=float, default=0.38, help="Resolution of the input in micrometers")

parser.add_argument("--roi_halo", type=str, default="128,128,64", help="ROI halo around center coordinate")

args = parser.parse_args()

main(args.input, args.output, args.input_key, args.resolution, args.coord, args.roi_halo)
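A worked example of the center-to-voxel conversion above, using the coordinate from the in-code comment and the default resolution of 0.38 micrometers per voxel:

```python
# Worked example of the ROI computation in main(); values follow the script defaults.
import numpy as np

coords = np.array([806, 1042, 1334])              # (z, y, x) in micrometer
voxel = np.round(coords / 0.38).astype(np.int32)  # -> [2121, 2742, 3511]
roi_halo = (128, 128, 64)
roi = tuple(slice(c - h, c + h) for c, h in zip(voxel, roi_halo))
# roi spans a (256, 256, 128) block centered on the voxel coordinate
```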
32 changes: 32 additions & 0 deletions scripts/prediction/count_cells.py
@@ -0,0 +1,32 @@
import argparse
import os

from elf.parallel import unique
from elf.io import open_file



def main():
parser = argparse.ArgumentParser()
parser.add_argument("-o", "--output_folder", type=str, required=True, help="Output directory containing segmentation.zarr")
parser.add_argument("-m", "--min_size", type=int, default=1000, help="Minimal number of voxel size for counting object")
args = parser.parse_args()

seg_path = os.path.join(args.output_folder, "segmentation.zarr")
seg_key = "segmentation"

file = open_file(seg_path, mode='r')
dataset = file[seg_key]

ids, counts = unique(dataset, return_counts=True)

# Objects smaller than min_size (in voxels) are excluded from the count.
min_size = args.min_size

counts = counts[counts > min_size]
print("Number of objects:", len(counts))

if __name__ == "__main__":
main()
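A small sketch of the size filter with made-up counts; note that a background label (typically id 0), if present in the segmentation, would also pass the filter:

```python
# Illustrative counts per segmentation id; id 0 is a hypothetical background.
import numpy as np

ids = np.array([0, 1, 2, 3])
counts = np.array([5_000_000, 1500, 800, 2300])  # voxels per id
min_size = 1000
kept = counts[counts > min_size]
print("Number of objects:", len(kept))  # 3, background included
```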
26 changes: 20 additions & 6 deletions scripts/prediction/run_prediction_distance_unet.py
@@ -8,14 +8,15 @@


def main():
from flamingo_tools.segmentation import run_unet_prediction
from flamingo_tools.segmentation import run_unet_prediction, run_unet_prediction_slurm

parser = argparse.ArgumentParser()
parser.add_argument("-i", "--input", required=True)
parser.add_argument("-o", "--output_folder", required=True)
parser.add_argument("-m", "--model", required=True)
parser.add_argument("-k", "--input_key", default=None)
parser.add_argument("-s", "--scale", default=None, type=float, help="Downscale the image by the given factor.")
parser.add_argument("-n", "--number_gpu", default=1, type=int, help="Number of GPUs to use in parallel.")

args = parser.parse_args()

@@ -36,11 +37,24 @@ def main():
block_shape = tuple([2 * ch for ch in chunks]) if have_cuda else tuple(chunks)
halo = (16, 64, 64) if have_cuda else (8, 32, 32)

run_unet_prediction(
args.input, args.input_key, args.output_folder, args.model,
scale=scale, min_size=min_size,
block_shape=block_shape, halo=halo,
)
prediction_instances = args.number_gpu if have_cuda else 1

if prediction_instances > 1:
# FIXME: only does the prediction part, no segmentation yet
# FIXME: implement array job
run_unet_prediction_slurm(
args.input, args.input_key, args.output_folder, args.model,
scale=scale,
block_shape=block_shape, halo=halo,
prediction_instances=prediction_instances,
)
else:
run_unet_prediction(
args.input, args.input_key, args.output_folder, args.model,
scale=scale, min_size=min_size,
block_shape=block_shape, halo=halo,
)


if __name__ == "__main__":
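For clarity, a condensed sketch of the intended dispatch in `main()`, assuming chunks of (128, 128, 128) and an illustrative GPU count:

```python
# Condensed dispatch logic; number_gpu=4 and the chunk shape are illustrative.
import torch

have_cuda = torch.cuda.is_available()
chunks = (128, 128, 128)
block_shape = tuple(2 * c for c in chunks) if have_cuda else chunks
halo = (16, 64, 64) if have_cuda else (8, 32, 32)

prediction_instances = 4 if have_cuda else 1   # one instance per GPU
use_slurm_path = prediction_instances > 1      # array job when running multi-GPU
print(block_shape, halo, use_slurm_path)
```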