118 changes: 108 additions & 10 deletions flamingo_tools/segmentation/unet_prediction.py
@@ -10,6 +10,7 @@
import vigra
import torch
import z5py
import json

from elf.wrapper import ThresholdWrapper, SimpleTransformationWrapper
from elf.wrapper.resized_volume import ResizedVolume
@@ -18,6 +19,11 @@
from torch_em.util.prediction import predict_with_halo
from tqdm import tqdm

"""
Prediction with the distance U-Net.
Parallelization across multiple GPUs is currently only possible by calling the functions directly.
The functions for parallelization end with '_slurm' and divide the process into preprocessing, prediction, and segmentation.
"""

class SelectChannel(SimpleTransformationWrapper):
def __init__(self, volume, channel):
@@ -37,13 +43,13 @@ def ndim(self):
return self._volume.ndim - 1


def prediction_impl(input_path, input_key, output_folder, model_path, scale, block_shape, halo):
def prediction_impl(input_path, input_key, output_folder, model_path, scale, block_shape, halo,
                    prediction_instances=1, slurm_task_id=0, mean=None, std=None):
with warnings.catch_warnings():
warnings.simplefilter("ignore")
if os.path.isdir(model_path):
model = load_model(model_path)
else:
model = torch.load(model_path)
model = torch.load(model_path, weights_only=False)

mask_path = os.path.join(output_folder, "mask.zarr")
image_mask = z5py.File(mask_path, "r")["mask"]
@@ -66,23 +72,25 @@ def prediction_impl(input_path, input_key, output_folder, model_path, scale, blo
image_mask = ResizedVolume(image_mask, new_shape, order=0)

have_cuda = torch.cuda.is_available()

if block_shape is None:
block_shape = tuple([2 * ch for ch in input_.chunks]) if have_cuda else input_.chunks
block_shape = (128, 128, 128) if have_cuda else input_.chunks
if halo is None:
halo = (16, 64, 64) if have_cuda else (16, 32, 32)
halo = (16, 32, 32)
if have_cuda:
print("Predict with GPU")
gpu_ids = [0]
else:
print("Predict with CPU")
gpu_ids = ["cpu"]

# Compute the global mean and standard deviation.
n_threads = min(16, mp.cpu_count())
mean, std = parallel.mean_and_std(
input_, block_shape=block_shape, n_threads=n_threads, verbose=True,
mask=image_mask
)
if mean is None or std is None:
# Compute the global mean and standard deviation.
n_threads = min(16, mp.cpu_count())
mean, std = parallel.mean_and_std(
input_, block_shape=block_shape, n_threads=n_threads, verbose=True,
mask=image_mask
)
print("Mean and standard deviation computed for the full volume:")
print(mean, std)

@@ -98,6 +106,18 @@ def postprocess(x):
x[1] = vigra.filters.gaussianSmoothing(x[1], sigma=2.0)
return x

shape = input_.shape
ndim = len(shape)

blocking = nt.blocking([0] * ndim, shape, block_shape)
n_blocks = blocking.numberOfBlocks
if prediction_instances != 1:
iteration_ids = [x.tolist() for x in np.array_split(list(range(n_blocks)), prediction_instances)]
slurm_iteration = iteration_ids[slurm_task_id]
else:
slurm_iteration = list(range(n_blocks))

output_path = os.path.join(output_folder, "predictions.zarr")
with open_file(output_path, "a") as f:
output = f.require_dataset(
@@ -113,6 +133,7 @@ def postprocess(x):
gpu_ids=gpu_ids, block_shape=block_shape, halo=halo,
output=output, preprocess=preprocess, postprocess=postprocess,
mask=image_mask,
iter_list=slurm_iteration,
)

return original_shape
@@ -223,6 +244,30 @@ def write_block(block_id):
tp.map(write_block, range(blocking.numberOfBlocks))


def calc_mean_and_std(input_path, input_key, output_folder):
"""
Calculate the mean and standard deviation of the full volume.
The values are saved as 'mean_std.json' in the output folder.
"""
json_file = os.path.join(output_folder, "mean_std.json")
mask_path = os.path.join(output_folder, "mask.zarr")
image_mask = z5py.File(mask_path, "r")["mask"]

if input_key is None:
input_ = imageio.imread(input_path)
else:
input_ = open_file(input_path, "r")[input_key]

# Compute the global mean and standard deviation.
n_threads = min(16, mp.cpu_count())
mean, std = parallel.mean_and_std(
input_, block_shape=tuple([2 * i for i in input_.chunks]), n_threads=n_threads, verbose=True,
mask=image_mask
)
ddict = {"mean": str(mean), "std": str(std)}
with open(json_file, "w") as f:
json.dump(ddict, f)
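For reference, the file written here is a flat JSON object with both statistics stored as strings (the values below are illustrative), which is why run_unet_prediction_slurm converts them back to float when reading the file:

    {"mean": "102.7", "std": "34.1"}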

def run_unet_prediction(
input_path, input_key,
output_folder, model_path,
@@ -239,3 +284,56 @@ def run_unet_prediction(

pmap_out = os.path.join(output_folder, "predictions.zarr")
segmentation_impl(pmap_out, output_folder, min_size=min_size, original_shape=original_shape)

# --- Workflow for parallel prediction using SLURM ---

def run_unet_prediction_preprocess_slurm(
input_path, input_key, output_folder,
):
"""
Pre-processing for the parallel prediction with U-Net models.
The mask is stored as mask.zarr in the output folder.
The mean and standard deviation are precomputed for later use during prediction
and stored as mean_std.json in the output folder.
"""
find_mask(input_path, input_key, output_folder)
calc_mean_and_std(input_path, input_key, output_folder)

def run_unet_prediction_slurm(
input_path, input_key, output_folder, model_path,
scale=None,
block_shape=None, halo=None, prediction_instances=1,
):
os.makedirs(output_folder, exist_ok=True)
prediction_instances = int(prediction_instances)
slurm_task_id = os.environ.get("SLURM_ARRAY_TASK_ID")

if slurm_task_id is not None:
slurm_task_id = int(slurm_task_id)
else:
raise ValueError("The SLURM_ARRAY_TASK_ID is not set. Ensure that you are using the '-a' option with SBATCH.")

if not os.path.isdir(os.path.join(output_folder, "mask.zarr")):
find_mask(input_path, input_key, output_folder)

# get pre-computed mean and standard deviation of full volume from JSON file
if os.path.isfile(os.path.join(output_folder, "mean_std.json")):
with open(os.path.join(output_folder, "mean_std.json")) as f:
d = json.load(f)
mean = float(d["mean"])
std = float(d["std"])
else:
mean = None
std = None

original_shape = prediction_impl(
input_path, input_key, output_folder, model_path, scale, block_shape, halo,
prediction_instances=prediction_instances, slurm_task_id=slurm_task_id,
mean=mean, std=std,
)

# Does NOT need a GPU. FIXME: make sure this only runs on the CPU.
def run_unet_segmentation_slurm(output_folder, min_size):
min_size = int(min_size)
pmap_out = os.path.join(output_folder, "predictions.zarr")
segmentation_impl(pmap_out, output_folder, min_size=min_size)
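For orientation, a minimal sketch of how the three '_slurm' stages could be chained; all paths, keys, and sizes are placeholders, and the prediction stage expects SLURM_ARRAY_TASK_ID to be set, e.g. by submitting it as a job array with 'sbatch -a 0-3' for prediction_instances=4:

from flamingo_tools.segmentation.unet_prediction import (
    run_unet_prediction_preprocess_slurm,
    run_unet_prediction_slurm,
    run_unet_segmentation_slurm,
)

input_path = "/path/to/volume.n5"       # placeholder
input_key = "setup0/timepoint0/s0"      # placeholder
output_folder = "/path/to/output"       # placeholder
model_path = "/path/to/model.pt"        # placeholder

# Stage 1 (CPU job): compute the mask and the global mean/std.
run_unet_prediction_preprocess_slurm(input_path, input_key, output_folder)

# Stage 2 (GPU job array): each array task predicts its share of the blocks.
run_unet_prediction_slurm(input_path, input_key, output_folder, model_path, prediction_instances=4)

# Stage 3 (CPU job): run the segmentation on the stitched predictions.
run_unet_segmentation_slurm(output_folder, min_size=1000)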
40 changes: 40 additions & 0 deletions scripts/convert_tif_to_n5.py
@@ -0,0 +1,40 @@
import os
import sys
import argparse
import pybdv
import imageio.v3 as imageio


def main(input_path, output_path):
"""
Convert tif file to n5 format.
If no output_path is supplied, the output file is created in the same directory as the input.
:param str input_path: Input tif
:param str output_path: Output path for n5 format
"""
if not os.path.isfile(input_path):
sys.exit("Input file does not exist.")

if input_path.split(".")[-1] not in ["TIFF", "TIF", "tiff", "tif"]:
sys.exit("Input file must be in tif format.")

basename = "".join(input_path.split("/")[-1].split(".")[:-1])
input_dir = input_path.split(basename)[0]
input_dir = os.path.abspath(input_dir)

if "" == output_path:
output_path = os.path.join(input_dir, basename + ".n5")
img = imageio.imread(input_path)
pybdv.make_bdv(img, output_path)

if __name__ == "__main__":

parser = argparse.ArgumentParser(
description="Script to transform file from tif into n5 format.")

parser.add_argument('input', type=str, help="Input file")
parser.add_argument('-o', "--output", type=str, default="", help="Output file. Default: <basename>.n5")

args = parser.parse_args()

main(args.input, args.output)
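As a reading sketch, assuming pybdv.make_bdv writes the BigDataViewer n5 layout, the full-resolution data should then be accessible under 'setup0/timepoint0/s0' (the default --input_key of scripts/extract_block.py below); the file name is a placeholder:

import z5py

# Open the converted volume; "example_stack.n5" is a placeholder name.
with z5py.File("example_stack.n5", "r") as f:
    data = f["setup0/timepoint0/s0"][:]
print(data.shape)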
110 changes: 110 additions & 0 deletions scripts/extract_block.py
@@ -0,0 +1,110 @@
import os
import argparse
import numpy as np
import z5py
import zarr

import s3fs

"""
This script extracts data around an input center coordinate with a given ROI halo.

Support for using an S3 bucket is currently limited to the cochlea-lightsheet bucket with the endpoint URL https://s3.fs.gwdg.de.
If more use cases appear, the script will be generalized.
Usage requires exporting the access key and the secret access key in the environment before executing the script.
Run the following commands in the shell of your choice, or add them to your ~/.bashrc:
export AWS_ACCESS_KEY_ID=<access key>
export AWS_SECRET_ACCESS_KEY=<secret access key>
"""


def main(input_file, output_dir, input_key, resolution, coords, roi_halo, s3):
"""

:param str input_file: File path to input folder in n5 format
:param str output_dir: output directory for saving cropped n5 file as <basename>_crop.n5
:param str input_key: Key for accessing volume in n5 format, e.g. 'setup0/s0'
:param float resolution: Resolution of input data in micrometer
:param str coords: Center coordinates of extracted 3D volume in format 'x,y,z'
:param str roi_halo: ROI halo of extracted 3D volume in format 'x,y,z'
:param bool s3: Flag for using an S3 bucket
"""

coords = [int(r) for r in coords.split(",")]
roi_halo = [int(r) for r in roi_halo.split(",")]

coord_string = "-".join([str(c) for c in coords])

# Dimensions are reversed to match the MoBIE view order: (x, y, z) -> (z, y, x)
coords.reverse()
roi_halo.reverse()

input_content = list(filter(None, input_file.split("/")))

if s3:
basename = input_content[0] + "_" + input_content[-1].split(".")[0]
else:
basename = "".join(input_content[-1].split(".")[:-1])

input_dir = input_file.split(basename)[0]
input_dir = os.path.abspath(input_dir)

if output_dir == "":
output_dir = input_dir

output_file = os.path.join(output_dir, basename + "_crop_" + coord_string + ".n5")

coords = np.array(coords)
coords = coords / resolution
coords = np.round(coords).astype(np.int32)

roi = tuple(slice(co - rh, co + rh) for co, rh in zip(coords, roi_halo))

if s3:

# Define S3 bucket and OME-Zarr dataset path

bucket_name = "cochlea-lightsheet"
zarr_path = f"{bucket_name}/{input_file}"

# Create an S3 filesystem
fs = s3fs.S3FileSystem(
client_kwargs={"endpoint_url": "https://s3.fs.gwdg.de"},
anon=False
)

if not fs.exists(zarr_path):
print("Error: Path does not exist!")

# Open the OME-Zarr dataset
store = zarr.storage.FSStore(zarr_path, fs=fs)
print(f"Opening file {zarr_path} from the S3 bucket.")

with zarr.open(store, mode="r") as f:
raw = f[input_key][roi]

else:
with z5py.File(input_file, "r") as f:
raw = f[input_key][roi]

with z5py.File(output_file, "w") as f_out:
f_out.create_dataset("raw", data=raw, compression="gzip")

if __name__ == "__main__":

parser = argparse.ArgumentParser(
description="Script to extract region of interest (ROI) block around center coordinate.")

parser.add_argument('input', type=str, help="Input file in n5 format.")
parser.add_argument('-o', "--output", type=str, default="", help="Output directory")
parser.add_argument('-c', "--coord", type=str, required=True, help="3D coordinate in format 'x,y,z' as center of extracted block.")

parser.add_argument('-k', "--input_key", type=str, default="setup0/timepoint0/s0", help="Input key for data in input file")
parser.add_argument('-r', "--resolution", type=float, default=0.38, help="Resolution of input in micrometer")

parser.add_argument("--roi_halo", type=str, default="128,128,64", help="ROI halo around center coordinate in format 'x,y,z'")
parser.add_argument("--s3", action="store_true", help="Use S3 bucket")

args = parser.parse_args()

main(args.input, args.output, args.input_key, args.resolution, args.coord, args.roi_halo, args.s3)
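To make the coordinate handling above concrete, a small worked example (all values are illustrative): the center is given in micrometer in (x, y, z) order, reversed to (z, y, x), converted to voxel indices via the resolution, and expanded into slices by the ROI halo.

import numpy as np

coords = [100, 200, 50]      # center in micrometer, given as (x, y, z)
roi_halo = [128, 128, 64]    # half-width of the block in voxels, given as (x, y, z)
resolution = 0.38            # micrometer per voxel

coords.reverse()             # (x, y, z) -> (z, y, x), matching the volume axis order
roi_halo.reverse()

center_vox = np.round(np.array(coords) / resolution).astype(np.int32)
roi = tuple(slice(c - h, c + h) for c, h in zip(center_vox, roi_halo))
print(roi)                   # slices that index the volume as volume[roi]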
32 changes: 32 additions & 0 deletions scripts/prediction/count_cells.py
@@ -0,0 +1,32 @@
import argparse
import os
import sys

from elf.parallel import unique
from elf.io import open_file

sys.path.append("../..")


def main():
parser = argparse.ArgumentParser()
parser.add_argument("-o", "--output_folder", type=str, required=True, help="Output directory containing segmentation.zarr")
parser.add_argument("-m", "--min_size", type=int, default=1000, help="Minimal number of voxel size for counting object")
args = parser.parse_args()

seg_path = os.path.join(args.output_folder, "segmentation.zarr")
seg_key = "segmentation"

file = open_file(seg_path, mode='r')
dataset = file[seg_key]

ids, counts = unique(dataset, return_counts=True)

# Exclude the background label, assuming label 0 marks the background in segmentation.zarr.
counts = counts[ids != 0]

# Keep only objects larger than the minimal size given by --min_size.
min_size = args.min_size
counts = counts[counts > min_size]
print("Number of objects:", len(counts))

if __name__ == "__main__":
main()
5 changes: 5 additions & 0 deletions scripts/prediction/run_prediction_distance_unet.py
@@ -6,6 +6,11 @@

sys.path.append("../..")

"""
Prediction with the distance U-Net.
Parallelization across multiple GPUs is currently only possible by calling the functions in segmentation/unet_prediction.py directly.
The functions for parallelization end with '_slurm' and divide the process into preprocessing, prediction, and segmentation.
"""

def main():
from flamingo_tools.segmentation import run_unet_prediction