166 changes: 111 additions & 55 deletions flamingo_tools/segmentation/unet_prediction.py
@@ -67,6 +67,7 @@ def prediction_impl(
slurm_task_id=0,
mean=None,
std=None,
mask=None
):
"""@private
"""
@@ -79,18 +80,20 @@ def prediction_impl(

input_ = read_image_data(input_path, input_key)
chunks = getattr(input_, "chunks", (64, 64, 64))
mask_path = os.path.join(output_folder, "mask.zarr")

if os.path.exists(mask_path):
image_mask = z5py.File(mask_path, "r")["mask"]
# resize mask
image_shape = input_.shape
mask_shape = image_mask.shape
if image_shape != mask_shape:
image_mask = ResizedVolume(image_mask, image_shape, order=0)

if output_folder is None:
image_mask = mask
else:
image_mask = None
mask_path = os.path.join(output_folder, "mask.zarr")
if os.path.exists(mask_path):
image_mask = z5py.File(mask_path, "r")["mask"]
# resize mask
image_shape = input_.shape
mask_shape = image_mask.shape
if image_shape != mask_shape:
image_mask = ResizedVolume(image_mask, image_shape, order=0)
else:
image_mask = mask

if scale is None or np.isclose(scale, 1):
original_shape = None
@@ -162,16 +165,8 @@ def postprocess(x):
else:
slurm_iteration = list(range(n_blocks))

output_path = os.path.join(output_folder, "predictions.zarr")
with open_file(output_path, "a") as f:
output = f.require_dataset(
"prediction",
shape=output_shape,
chunks=output_chunks,
compression="gzip",
dtype="float32",
)

if output_folder is None:
output = np.zeros(output_shape, dtype=np.float32)
predict_with_halo(
input_, model,
gpu_ids=gpu_ids, block_shape=block_shape, halo=halo,
@@ -180,10 +175,37 @@ def postprocess(x):
iter_list=slurm_iteration,
)

return original_shape
else:
output_path = os.path.join(output_folder, "predictions.zarr")
with open_file(output_path, "a") as f:
output = f.require_dataset(
"prediction",
shape=output_shape,
chunks=output_chunks,
compression="gzip",
dtype="float32",
)

predict_with_halo(
input_, model,
gpu_ids=gpu_ids, block_shape=block_shape, halo=halo,
output=output, preprocess=preprocess, postprocess=postprocess,
mask=image_mask,
iter_list=slurm_iteration,
)

if output_folder is None:
return original_shape, output
else:
return original_shape, None
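
Taken together, prediction_impl now has a dual contract: with an output folder it writes the prediction to predictions.zarr inside that folder and returns (original_shape, None); with output_folder=None it allocates the output array in memory and returns (original_shape, prediction). A minimal caller sketch of the two modes (paths and parameter values are illustrative, assuming the module-level imports of this file):

# In-memory mode: no output folder, the prediction comes back as a numpy array.
original_shape, prediction = prediction_impl(
    input_path="volume.n5", input_key="raw", output_folder=None,
    model_path="model.pt", scale=None, block_shape=(64, 256, 256),
    halo=(16, 64, 64), mask=None,
)

# On-disk mode: the prediction lands in out/predictions.zarr, None is returned.
original_shape, no_result = prediction_impl(
    input_path="volume.n5", input_key="raw", output_folder="out",
    model_path="model.pt", scale=None, block_shape=(64, 256, 256),
    halo=(16, 64, 64), mask=None,
)
assert no_result is None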


def find_mask(input_path: str, input_key: Optional[str], output_folder: str, seg_class: Optional[str] = "sgn") -> None:
def find_mask(
input_path: str,
input_key: Optional[str],
output_folder: Optional[str],
seg_class: Optional[str] = "sgn"
) -> Optional[np.ndarray]:
"""Determine the mask for running prediction.
The mask corresponds to data that contains actual signal and not just noise.
@@ -197,9 +219,6 @@ def find_mask(input_path: str, input_key: Optional[str], output_folder: str, seg
output_folder: The output folder for storing the mask data.
seg_class: Specifier for exclusion criteria for mask generation.

Returns:
The mask if no output folder is given, otherwise None.
"""
mask_path = os.path.join(output_folder, "mask.zarr")
f = z5py.File(mask_path, "a")

# set parameters for the exclusion of chunks within mask generation
if seg_class == "sgn":
upper_percentile = 95
@@ -214,18 +233,24 @@ def find_mask(input_path: str, input_key: Optional[str], output_folder: str, seg
min_intensity = 200
print("Calculating mask with default values.")

mask_key = "mask"
if mask_key in f:
return

raw = read_image_data(input_path, input_key)
chunks = getattr(raw, "chunks", (64, 64, 64))

block_shape = tuple(2 * ch for ch in chunks)
blocking = nt.blocking([0, 0, 0], raw.shape, block_shape)
n_blocks = blocking.numberOfBlocks

ds_mask = f.create_dataset(mask_key, shape=raw.shape, compression="gzip", dtype="uint8", chunks=block_shape)
if output_folder is None:
ds_mask = np.zeros(raw.shape, dtype=np.uint8)

else:
mask_path = os.path.join(output_folder, "mask.zarr")
f = z5py.File(mask_path, "a")
mask_key = "mask"
if mask_key in f:
return

ds_mask = f.create_dataset(mask_key, shape=raw.shape, compression="gzip", dtype="uint8", chunks=block_shape)

# TODO more sophisticated criterion?!
def find_mask_block(block_id):
@@ -240,15 +265,20 @@ def find_mask_block(block_id):
with futures.ThreadPoolExecutor(n_threads) as tp:
list(tqdm(tp.map(find_mask_block, range(n_blocks)), total=n_blocks))

if output_folder is None:
return ds_mask
else:
return None
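
The body of the per-block criterion is collapsed in this diff (note the "TODO more sophisticated criterion?!" above). A plausible sketch consistent with the parameters defined earlier, not the verbatim implementation: a block is kept if a high intensity percentile clears the minimum intensity.

def find_mask_block(block_id):
    # Hypothetical criterion: mark the block as foreground if its
    # upper-percentile intensity exceeds min_intensity for this seg_class.
    block = blocking.getBlock(block_id)
    bb = tuple(slice(beg, end) for beg, end in zip(block.begin, block.end))
    data = raw[bb]
    if np.percentile(data, upper_percentile) > min_intensity:
        ds_mask[bb] = 1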


def distance_watershed_implementation(
input_path: str,
output_folder: str,
min_size: int,
output_folder: Optional[str] = None,
min_size: int = 1000,
center_distance_threshold: float = 0.4,
boundary_distance_threshold: Optional[float] = None,
fg_threshold: float = 0.5,
original_shape: Optional[Tuple[int, int, int]] = None,
original_shape: Optional[Tuple[int, int, int]] = None
) -> Optional[np.ndarray]:
"""Parallel implementation of the distance-prediction based watershed.
@@ -262,7 +292,10 @@ def distance_watershed_implementation(
fg_threshold: The threshold applied to the foreground prediction for deriving the watershed mask.
original_shape: The original shape to resize the segmentation to.
"""
input_ = open_file(input_path, "r")["prediction"]
if isinstance(input_path, str):
input_ = open_file(input_path, "r")["prediction"]
else:
input_ = input_path

# Limit the number of cores for parallelization.
n_threads = min(16, mp.cpu_count())
@@ -280,13 +313,17 @@ def distance_watershed_implementation(
# center_distances = SimpleTransformationWrapper(center_distances, transformation=smoothing)
# boundary_distances = SimpleTransformationWrapper(boundary_distances, transformation=smoothing)

# Allocate an zarr array for the seeds.
block_shape = center_distances.chunks
seed_path = os.path.join(output_folder, "seeds.zarr")
seed_file = open_file(os.path.join(seed_path), "a")
seeds = seed_file.require_dataset(
"seeds", shape=center_distances.shape, chunks=block_shape, compression="gzip", dtype="uint64"
)
# Allocate the (zarr) array for the seeds.
if output_folder is None:
block_shape = (20, 128, 128)
seeds = np.zeros(center_distances.shape, dtype=np.uint64)
else:
block_shape = center_distances.chunks
seed_path = os.path.join(output_folder, "seeds.zarr")
seed_file = open_file(os.path.join(seed_path), "a")
seeds = seed_file.require_dataset(
"seeds", shape=center_distances.shape, chunks=block_shape, compression="gzip", dtype="uint64"
)

# Compute the seed inputs:
# First, threshold the center distances.
@@ -301,12 +338,15 @@ def distance_watershed_implementation(
data=seed_inputs, out=seeds, block_shape=block_shape, mask=mask, verbose=True, n_threads=n_threads
)

# Allocate the zarr array for the segmentation.
seg_path = os.path.join(output_folder, "segmentation.zarr" if original_shape is None else "seg_downscaled.zarr")
seg_file = open_file(seg_path, "a")
seg = seg_file.create_dataset(
"segmentation", shape=seeds.shape, chunks=block_shape, compression="gzip", dtype="uint64"
)
# Allocate the (zarr) array for the segmentation.
if output_folder is None:
seg = np.zeros(seeds.shape, dtype=np.uint64)
else:
seg_path = os.path.join(output_folder, "segmentation.zarr" if original_shape is None else "seg_downscaled.zarr")
seg_file = open_file(seg_path, "a")
seg = seg_file.create_dataset(
"segmentation", shape=seeds.shape, chunks=block_shape, compression="gzip", dtype="uint64"
)

# Compute the segmentation with a seeded watershed
halo = (2, 8, 8)
@@ -341,6 +381,11 @@ def write_block(block_id):
with futures.ThreadPoolExecutor(n_threads) as tp:
tp.map(write_block, range(blocking.numberOfBlocks))

if output_folder is None:
return seg
else:
return None
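
For intuition, the block-parallel pipeline above (threshold the center distances, label connected components as seeds, run a seeded watershed inside the foreground mask, filter by size) corresponds to the following single-volume, in-memory sketch. It uses skimage rather than the parallel primitives called here, follows the default setting in which boundary distances are not used for seeds, and assumes the common convention that small predicted center distances mark object centers:

import numpy as np
from skimage.measure import label
from skimage.segmentation import watershed

def distance_watershed_in_memory(
    center_distances, boundary_distances, foreground,
    center_distance_threshold=0.4, fg_threshold=0.5, min_size=1000,
):
    # Seeds: connected components where the predicted center distance is small.
    seeds = label(center_distances < center_distance_threshold)
    # Restrict the watershed to the predicted foreground.
    mask = foreground > fg_threshold
    # Grow the seeds along the boundary distances inside the mask.
    seg = watershed(boundary_distances, markers=seeds, mask=mask)
    # Remove objects below the minimum size, analogous to min_size.
    ids, sizes = np.unique(seg, return_counts=True)
    seg[np.isin(seg, ids[(sizes < min_size) & (ids != 0)])] = 0
    return seg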


def calc_mean_and_std(input_path: str, input_key: str, output_folder: str) -> None:
"""Calculate mean and standard deviation of the input volume.
@@ -372,7 +417,7 @@ def calc_mean_and_std(input_path: str, input_key: str, output_folder: str) -> No
def run_unet_prediction(
input_path: str,
input_key: Optional[str],
output_folder: str,
output_folder: Optional[str],
model_path: str,
min_size: int,
scale: Optional[float] = None,
@@ -403,22 +448,33 @@ def run_unet_prediction(
fg_threshold: The threshold applied to the foreground prediction for deriving the watershed mask.
seg_class: Specifier for exclusion criteria for mask generation.

Returns:
The segmentation if no output folder is given, otherwise None.
"""
os.makedirs(output_folder, exist_ok=True)
if output_folder is not None:
os.makedirs(output_folder, exist_ok=True)

if use_mask:
find_mask(input_path, input_key, output_folder, seg_class=seg_class)
original_shape = prediction_impl(
input_path, input_key, output_folder, model_path, scale, block_shape, halo
mask = find_mask(input_path, input_key, output_folder=output_folder, seg_class=seg_class)
else:
mask = None

original_shape, prediction = prediction_impl(
input_path=input_path, input_key=input_key, output_folder=output_folder, model_path=model_path, scale=scale,
block_shape=block_shape, halo=halo, mask=mask
)

pmap_out = os.path.join(output_folder, "predictions.zarr")
distance_watershed_implementation(
if output_folder is None:
pmap_out = prediction
else:
pmap_out = os.path.join(output_folder, "predictions.zarr")

segmentation = distance_watershed_implementation(
pmap_out, output_folder, min_size=min_size, original_shape=original_shape,
center_distance_threshold=center_distance_threshold,
boundary_distance_threshold=boundary_distance_threshold,
fg_threshold=fg_threshold,
fg_threshold=fg_threshold
)

return segmentation
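
A hedged usage sketch of the new in-memory mode (paths are hypothetical; keyword names follow the function body above):

from flamingo_tools.segmentation.unet_prediction import run_unet_prediction

# Runs masking, prediction, and watershed without writing zarr containers;
# the segmentation is returned as an array instead.
segmentation = run_unet_prediction(
    input_path="cochlea.n5", input_key="raw", output_folder=None,
    model_path="unet-model", min_size=1000,
    block_shape=(64, 256, 256), halo=(16, 64, 64), seg_class="sgn",
)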


#
# ---Workflow for parallel prediction using slurm---
71 changes: 62 additions & 9 deletions scripts/prediction/run_prediction_distance_unet.py
@@ -5,7 +5,11 @@
prediction, and segmentation.
"""
import argparse
import json
import time
import os

import imageio.v3 as imageio
import torch
import z5py

@@ -19,7 +23,20 @@ def main():
parser.add_argument("-m", "--model", required=True)
parser.add_argument("-k", "--input_key", default=None)
parser.add_argument("-s", "--scale", default=None, type=float, help="Downscale the image by the given factor.")
parser.add_argument("-b", "--block_shape", default=None, type=int, nargs=3)
parser.add_argument("-b", "--block_shape", default=None, type=str)
parser.add_argument("--halo", default=None, type=str)
parser.add_argument("--memory", action="store_true", help="Perform prediction in memory and save output as tif.")
parser.add_argument("--time", action="store_true", help="Time prediction process.")
parser.add_argument("--seg_class", default=None, type=str,
help="Segmentation class to load parameters for masking input.")
parser.add_argument("--center_distance_threshold", default=0.4, type=float,
help="The threshold applied to the distance center predictions to derive seeds.")
parser.add_argument("--boundary_distance_threshold", default=None, type=float,
help="The threshold applied to the boundary predictions to derive seeds. \
By default this is set to 'None', \
in which case the boundary distances are not used for the seeds.")
parser.add_argument("--fg_threshold", default=0.5, type=float,
help="The threshold applied to the foreground prediction for deriving the watershed mask.")

args = parser.parse_args()

@@ -36,21 +53,57 @@ def main():
if args.block_shape is None:
block_shape = (64, 256, 256) if have_cuda else (64, 64, 64)
else:
block_shape = tuple(args.block_shape)
halo = (16, 64, 64) if have_cuda else (8, 32, 32)
block_shape = tuple(json.loads(args.block_shape))

else:
if args.block_shape is None:
chunks = z5py.File(args.input, "r")[args.input_key].chunks
block_shape = tuple([2 * ch for ch in chunks]) if have_cuda else tuple(chunks)
else:
block_shape = tuple(args.block_shape)
block_shape = tuple(json.loads(args.block_shape))

if args.halo is None:
halo = (16, 64, 64) if have_cuda else (8, 32, 32)
else:
halo = tuple(json.loads(args.halo))
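
Since --block_shape and --halo are now plain strings, they are parsed with json.loads rather than argparse's nargs; on the command line they are passed as JSON lists, e.g. --block_shape "[64, 256, 256]". A quick illustration of the parsing:

import json

block_shape = tuple(json.loads("[64, 256, 256]"))
halo = tuple(json.loads("[16, 64, 64]"))
print(block_shape, halo)  # (64, 256, 256) (16, 64, 64)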

if args.time:
start = time.perf_counter()

if args.memory:
segmentation = run_unet_prediction(
args.input, args.input_key, output_folder=None, model_path=args.model,
scale=scale, min_size=min_size,
block_shape=block_shape, halo=halo,
seg_class=args.seg_class,
center_distance_threshold=args.center_distance_threshold,
boundary_distance_threshold=args.boundary_distance_threshold,
fg_threshold=args.fg_threshold,
)

abs_path = os.path.abspath(args.input)
basename = ".".join(os.path.basename(abs_path).split(".")[:-1])
output_path = os.path.join(args.output_folder, basename + "_seg.tif")
imageio.imwrite(output_path, segmentation, compression="zlib")
timer_output = os.path.join(args.output_folder, basename + "_timer.json")

else:
run_unet_prediction(
args.input, args.input_key, output_folder=args.output_folder, model_path=args.model,
scale=scale, min_size=min_size,
block_shape=block_shape, halo=halo,
seg_class=args.seg_class,
center_distance_threshold=args.center_distance_threshold,
boundary_distance_threshold=args.boundary_distance_threshold,
fg_threshold=args.fg_threshold,
)
timer_output = os.path.join(args.output_folder, "timer.json")

run_unet_prediction(
args.input, args.input_key, args.output_folder, args.model,
scale=scale, min_size=min_size,
block_shape=block_shape, halo=halo,
)
if args.time:
duration = time.perf_counter() - start
time_dict = {"total_duration[s]": duration}
with open(timer_output, "w") as f:
json.dump(time_dict, f, indent='\t', separators=(',', ': '))
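
An illustrative invocation of the updated script, assuming -i and -o are the input and output-folder flags defined in the part of the parser not shown in this diff. With --memory and --time it writes <basename>_seg.tif and <basename>_timer.json to the output folder; the timer file holds a single entry such as {"total_duration[s]": 812.4} (value illustrative):

python run_prediction_distance_unet.py \
    -i cochlea.n5 -k raw -o out/ -m unet-model \
    --block_shape "[64, 256, 256]" --halo "[16, 64, 64]" \
    --memory --time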


if __name__ == "__main__":