Commit f12f9d3

Merge pull request #22 from computational-cell-analytics/parallelize_prediction
Parallelize prediction
2 parents: f1578b9 + 5a5207d, commit f12f9d3

6 files changed: 367 additions & 10 deletions


flamingo_tools/segmentation/unet_prediction.py

Lines changed: 107 additions & 10 deletions
@@ -10,6 +10,7 @@
 import vigra
 import torch
 import z5py
+import json
 
 from elf.wrapper import ThresholdWrapper, SimpleTransformationWrapper
 from elf.wrapper.resized_volume import ResizedVolume
@@ -18,6 +19,11 @@
 from torch_em.util.prediction import predict_with_halo
 from tqdm import tqdm
 
+"""
+Prediction using distance U-Net.
+Parallelization using multiple GPUs is currently only possible by calling functions directly.
+Functions for the parallelization end with '_slurm' and divide the process into preprocessing, prediction, and segmentation.
+"""
 
 class SelectChannel(SimpleTransformationWrapper):
     def __init__(self, volume, channel):
@@ -37,13 +43,13 @@ def ndim(self):
         return self._volume.ndim - 1
 
 
-def prediction_impl(input_path, input_key, output_folder, model_path, scale, block_shape, halo):
+def prediction_impl(input_path, input_key, output_folder, model_path, scale, block_shape, halo, prediction_instances=1, slurm_task_id=0, mean=None, std=None):
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
         if os.path.isdir(model_path):
             model = load_model(model_path)
         else:
-            model = torch.load(model_path)
+            model = torch.load(model_path, weights_only=False)
 
     mask_path = os.path.join(output_folder, "mask.zarr")
     image_mask = z5py.File(mask_path, "r")["mask"]
@@ -66,23 +72,25 @@ def prediction_impl(input_path, input_key, output_folder, model_path, scale, blo
         image_mask = ResizedVolume(image_mask, new_shape, order=0)
 
     have_cuda = torch.cuda.is_available()
+
     if block_shape is None:
-        block_shape = tuple([2 * ch for ch in input_.chunks]) if have_cuda else input_.chunks
+        block_shape = (128, 128, 128) if have_cuda else input_.chunks
     if halo is None:
-        halo = (16, 64, 64) if have_cuda else (16, 32, 32)
+        halo = (16, 32, 32)
     if have_cuda:
         print("Predict with GPU")
         gpu_ids = [0]
     else:
         print("Predict with CPU")
         gpu_ids = ["cpu"]
 
-    # Compute the global mean and standard deviation.
-    n_threads = min(16, mp.cpu_count())
-    mean, std = parallel.mean_and_std(
-        input_, block_shape=block_shape, n_threads=n_threads, verbose=True,
-        mask=image_mask
-    )
+    if mean is None or std is None:
+        # Compute the global mean and standard deviation.
+        n_threads = min(16, mp.cpu_count())
+        mean, std = parallel.mean_and_std(
+            input_, block_shape=block_shape, n_threads=n_threads, verbose=True,
+            mask=image_mask
+        )
     print("Mean and standard deviation computed for the full volume:")
     print(mean, std)
 
@@ -98,6 +106,17 @@ def postprocess(x):
         x[1] = vigra.filters.gaussianSmoothing(x[1], sigma=2.0)
         return x
 
+    shape = input_.shape
+    ndim = len(shape)
+
+    blocking = nt.blocking([0] * ndim, shape, block_shape)
+    n_blocks = blocking.numberOfBlocks
+    if prediction_instances != 1:
+        iteration_ids = [x.tolist() for x in np.array_split(list(range(n_blocks)), prediction_instances)]
+        slurm_iteration = iteration_ids[slurm_task_id]
+    else:
+        slurm_iteration = list(range(n_blocks))
+
     output_path = os.path.join(output_folder, "predictions.zarr")
     with open_file(output_path, "a") as f:
         output = f.require_dataset(
@@ -113,6 +132,7 @@ def postprocess(x):
             gpu_ids=gpu_ids, block_shape=block_shape, halo=halo,
             output=output, preprocess=preprocess, postprocess=postprocess,
             mask=image_mask,
+            iter_list=slurm_iteration,
         )
 
     return original_shape
@@ -223,6 +243,30 @@ def write_block(block_id):
         tp.map(write_block, range(blocking.numberOfBlocks))
 
 
+def calc_mean_and_std(input_path, input_key, output_folder):
+    """
+    Calculate mean and standard deviation of full volume.
+    Parameters are saved in 'mean_std.json' within the output folder.
+    """
+    json_file = os.path.join(output_folder, "mean_std.json")
+    mask_path = os.path.join(output_folder, "mask.zarr")
+    image_mask = z5py.File(mask_path, "r")["mask"]
+
+    if input_key is None:
+        input_ = imageio.imread(input_path)
+    else:
+        input_ = open_file(input_path, "r")[input_key]
+
+    # Compute the global mean and standard deviation.
+    n_threads = min(16, mp.cpu_count())
+    mean, std = parallel.mean_and_std(
+        input_, block_shape=tuple([2* i for i in input_.chunks]), n_threads=n_threads, verbose=True,
+        mask=image_mask
+    )
+    ddict = {"mean":mean, "std":std}
+    with open(json_file, "w") as f:
+        json.dump(ddict, f)
+
 def run_unet_prediction(
     input_path, input_key,
     output_folder, model_path,
@@ -239,3 +283,56 @@ def run_unet_prediction(
 
     pmap_out = os.path.join(output_folder, "predictions.zarr")
     segmentation_impl(pmap_out, output_folder, min_size=min_size, original_shape=original_shape)
+
+#---Workflow for parallel prediction using slurm---
+
+def run_unet_prediction_preprocess_slurm(
+    input_path, input_key, output_folder,
+):
+    """
+    Pre-processing for the parallel prediction with U-Net models.
+    Masks are stored in mask.zarr in the output folder.
+    The mean and standard deviation are precomputed for later usage during prediction
+    and stored in a JSON file within the output folder as mean_std.json
+    """
+    find_mask(input_path, input_key, output_folder)
+    calc_mean_and_std(input_path, input_key, output_folder)
+
+def run_unet_prediction_slurm(
+    input_path, input_key, output_folder, model_path,
+    scale=None,
+    block_shape=None, halo=None, prediction_instances=1,
+):
+    os.makedirs(output_folder, exist_ok=True)
+    prediction_instances = int(prediction_instances)
+    slurm_task_id = os.environ.get("SLURM_ARRAY_TASK_ID")
+
+    if slurm_task_id is not None:
+        slurm_task_id = int(slurm_task_id)
+    else:
+        raise ValueError("The SLURM_ARRAY_TASK_ID is not set. Ensure that you are using the '-a' option with SBATCH.")
+
+    if not os.path.isdir(os.path.join(output_folder, "mask.zarr")):
+        find_mask(input_path, input_key, output_folder)
+
+    # get pre-computed mean and standard deviation of full volume from JSON file
+    if os.path.isfile(os.path.join(output_folder, "mean_std.json")):
+        with open(os.path.join(output_folder, "mean_std.json")) as f:
+            d = json.load(f)
+        mean = float(d["mean"])
+        std = float(d["std"])
+    else:
+        mean = None
+        std = None
+
+    original_shape = prediction_impl(
+        input_path, input_key, output_folder, model_path, scale, block_shape, halo,
+        prediction_instances=prediction_instances, slurm_task_id=slurm_task_id,
+        mean=mean, std=std,
+    )
+
+# does NOT need GPU, FIXME: only run on CPU
+def run_unet_segmentation_slurm(output_folder, min_size):
+    min_size = int(min_size)
+    pmap_out = os.path.join(output_folder, "predictions.zarr")
+    segmentation_impl(pmap_out, output_folder, min_size=min_size)
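
Note: the module docstring above states that multi-GPU parallelization works by calling the new '_slurm' functions directly, with one slurm array task per prediction instance. A minimal sketch of how the three stages introduced in this file might be chained; the paths, dataset key, instance count, and min_size below are hypothetical examples, not values from this commit.

import os

from flamingo_tools.segmentation.unet_prediction import (
    run_unet_prediction_preprocess_slurm,
    run_unet_prediction_slurm,
    run_unet_segmentation_slurm,
)

input_path = "/path/to/volume.n5"       # hypothetical
input_key = "setup0/timepoint0/s0"      # hypothetical
output_folder = "/path/to/output"       # hypothetical
model_path = "/path/to/model.pt"        # hypothetical

# Stage 1 (single job): write mask.zarr and precompute mean_std.json.
run_unet_prediction_preprocess_slurm(input_path, input_key, output_folder)

# Stage 2 (slurm array job, e.g. sbatch -a 0-3): each task reads SLURM_ARRAY_TASK_ID
# and predicts its share of the blocks (split with np.array_split in prediction_impl).
# os.environ["SLURM_ARRAY_TASK_ID"] = "0"  # only needed when testing outside slurm
run_unet_prediction_slurm(
    input_path, input_key, output_folder, model_path,
    prediction_instances=4,
)

# Stage 3 (single CPU job): segment the combined predictions.zarr.
run_unet_segmentation_slurm(output_folder, min_size=1000)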

scripts/convert_tif_to_n5.py

Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
+import os, sys
+import argparse
+import pybdv
+import imageio.v3 as imageio
+
+
+def main(input_path, output_path):
+    """
+    Convert tif file to n5 format.
+    If no output_path is supplied, the output file is created in the same directory as the input.
+
+    :param str input_path: Input tif
+    :param str output_path: Output path for n5 format
+    """
+    if not os.path.isfile(input_path):
+        sys.exit("Input file does not exist.")
+
+    if input_path.split(".")[-1] not in ["TIFF", "TIF", "tiff", "tif"]:
+        sys.exit("Input file must be in tif format.")
+
+    basename = "".join(input_path.split("/")[-1].split(".")[:-1])
+    input_dir = input_path.split(basename)[0]
+    input_dir = os.path.abspath(input_dir)
+
+    if "" == output_path:
+        output_path = os.path.join(input_dir, basename + ".n5")
+    img = imageio.imread(input_path)
+    pybdv.make_bdv(img, output_path)
+
+if __name__ == "__main__":
+
+    parser = argparse.ArgumentParser(
+        description="Script to transform file from tif into n5 format.")
+
+    parser.add_argument('input', type=str, help="Input file")
+    parser.add_argument('-o', "--output", type=str, default="", help="Output file. Default: <basename>.n5")
+
+    args = parser.parse_args()
+
+    main(args.input, args.output)
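
Note: a short sketch of reading the converted volume back, assuming pybdv writes the BigDataViewer n5 layout whose full-resolution key matches the 'setup0/timepoint0/s0' default used by scripts/extract_block.py; the file name is a placeholder.

import z5py

# "stack.n5" is a placeholder for the output written by convert_tif_to_n5.py.
with z5py.File("stack.n5", "r") as f:
    data = f["setup0/timepoint0/s0"][:]  # full-resolution dataset
print(data.shape, data.dtype)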

scripts/extract_block.py

Lines changed: 110 additions & 0 deletions
@@ -0,0 +1,110 @@
+import os
+import argparse
+import numpy as np
+import z5py
+import zarr
+
+import s3fs
+
+"""
+This script extracts data around an input center coordinate in a given ROI halo.
+
+The support for using an S3 bucket is currently limited to the lightsheet-cochlea bucket with the endpoint url https://s3.fs.gwdg.de.
+If more use cases appear, the script will be generalized.
+The usage requires the export of the access and the secret access key within the environment before executing the script.
+run the following commands in the shell of your choice, or add them to your ~/.bashrc:
+export AWS_ACCESS_KEY_ID=<access key>
+export AWS_SECRET_ACCESS_KEY=<secret access key>
+"""
+
+
+def main(input_file, output_dir, input_key, resolution, coords, roi_halo, s3):
+    """
+
+    :param str input_file: File path to input folder in n5 format
+    :param str output_dir: output directory for saving cropped n5 file as <basename>_crop.n5
+    :param str input_key: Key for accessing volume in n5 format, e.g. 'setup0/s0'
+    :param float resolution: Resolution of input data in micrometer
+    :param str coords: Center coordinates of extracted 3D volume in format 'x,y,z'
+    :param str roi_halo: ROI halo of extracted 3D volume in format 'x,y,z'
+    :param bool s3: Flag for using an S3 bucket
+    """
+
+    coords = [int(r) for r in coords.split(",")]
+    roi_halo = [int(r) for r in roi_halo.split(",")]
+
+    coord_string = "-".join([str(c) for c in coords])
+
+    # Dimensions are inversed to view in MoBIE (x y z) -> (z y x)
+    coords.reverse()
+    roi_halo.reverse()
+
+    input_content = list(filter(None, input_file.split("/")))
+
+    if s3:
+        basename = input_content[0] + "_" + input_content[-1].split(".")[0]
+    else:
+        basename = "".join(input_content[-1].split(".")[:-1])
+
+    input_dir = input_file.split(basename)[0]
+    input_dir = os.path.abspath(input_dir)
+
+    if output_dir == "":
+        output_dir = input_dir
+
+    output_file = os.path.join(output_dir, basename + "_crop_" + coord_string + ".n5")
+
+    coords = np.array(coords)
+    coords = coords / resolution
+    coords = np.round(coords).astype(np.int32)
+
+    roi = tuple(slice(co - rh, co + rh) for co, rh in zip(coords, roi_halo))
+
+    if s3:
+
+        # Define S3 bucket and OME-Zarr dataset path
+
+        bucket_name = "cochlea-lightsheet"
+        zarr_path = f"{bucket_name}/{input_file}"
+
+        # Create an S3 filesystem
+        fs = s3fs.S3FileSystem(
+            client_kwargs={"endpoint_url": "https://s3.fs.gwdg.de"},
+            anon=False
+        )
+
+        if not fs.exists(zarr_path):
+            print("Error: Path does not exist!")
+
+        # Open the OME-Zarr dataset
+        store = zarr.storage.FSStore(zarr_path, fs=fs)
+        print(f"Opening file {zarr_path} from the S3 bucket.")
+
+        with zarr.open(store, mode="r") as f:
+            raw = f[input_key][roi]
+
+    else:
+        with z5py.File(input_file, "r") as f:
+            raw = f[input_key][roi]
+
+    with z5py.File(output_file, "w") as f_out:
+        f_out.create_dataset("raw", data=raw, compression="gzip")
+
+if __name__ == "__main__":
+
+    parser = argparse.ArgumentParser(
+        description="Script to extract region of interest (ROI) block around center coordinate.")
+
+    parser.add_argument('input', type=str, help="Input file in n5 format.")
+    parser.add_argument('-o', "--output", type=str, default="", help="Output directory")
+    parser.add_argument('-c', "--coord", type=str, required=True, help="3D coordinate in format 'x,y,z' as center of extracted block.")
+
+    parser.add_argument('-k', "--input_key", type=str, default="setup0/timepoint0/s0", help="Input key for data in input file")
+    parser.add_argument('-r', "--resolution", type=float, default=0.38, help="Resolution of input in micrometer")
+
+    parser.add_argument("--roi_halo", type=str, default="128,128,64", help="ROI halo around center coordinate in format 'x,y,z'")
+    parser.add_argument("--s3", action="store_true", help="Use S3 bucket")
+
+    args = parser.parse_args()
+
+    main(args.input, args.output, args.input_key, args.resolution, args.coord, args.roi_halo, args.s3)
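
Note: a minimal sketch of driving the extraction from Python with S3 access, assuming the credentials are set in the environment as the module docstring describes; the bucket-internal path, coordinates, and output directory are placeholders.

import os

os.environ["AWS_ACCESS_KEY_ID"] = "<access key>"             # placeholder, do not hard-code real keys
os.environ["AWS_SECRET_ACCESS_KEY"] = "<secret access key>"  # placeholder

from extract_block import main  # assumes the scripts/ directory is on sys.path

main(
    input_file="M04/PV.ome.zarr",   # placeholder path inside the cochlea-lightsheet bucket
    output_dir=".",                 # writes <basename>_crop_<coords>.n5 here
    input_key="setup0/timepoint0/s0",
    resolution=0.38,
    coords="1200,950,400",          # center coordinate 'x,y,z' in micrometer
    roi_halo="128,128,64",
    s3=True,
)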

scripts/prediction/count_cells.py

Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
+import argparse
+import os
+import sys
+
+from elf.parallel import unique
+from elf.io import open_file
+
+sys.path.append("../..")
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-o", "--output_folder", type=str, required=True, help="Output directory containing segmentation.zarr")
+    parser.add_argument("-m", "--min_size", type=int, default=1000, help="Minimal number of voxel size for counting object")
+    args = parser.parse_args()
+
+    seg_path = os.path.join(args.output_folder, "segmentation.zarr")
+    seg_key = "segmentation"
+
+    file = open_file(seg_path, mode='r')
+    dataset = file[seg_key]
+
+    ids, counts = unique(dataset, return_counts=True)
+
+    # You can change the minimal size for objects to be counted here:
+    min_size = args.min_size
+
+    counts = counts[counts > min_size]
+    print("Number of objects:", len(counts))
+
+if __name__ == "__main__":
+    main()
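
Note: the object count is simply the number of label ids whose voxel count exceeds min_size. A small in-memory illustration of that logic, using np.unique in place of the chunked elf.parallel.unique; the toy labels and min_size are made up.

import numpy as np

segmentation = np.array([
    [0, 0, 1, 1],
    [0, 2, 2, 2],
    [0, 2, 2, 0],
])
ids, counts = np.unique(segmentation, return_counts=True)
# ids -> [0, 1, 2], counts -> [5, 2, 5]
min_size = 3
print("Number of objects:", len(counts[counts > min_size]))  # -> 2
# Note that label 0 passes the size filter here as well; if 0 marks background
# in segmentation.zarr, it may need to be excluded from the count.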

scripts/prediction/run_prediction_distance_unet.py

Lines changed: 5 additions & 0 deletions
@@ -6,6 +6,11 @@
 
 sys.path.append("../..")
 
+"""
+Prediction using distance U-Net.
+Parallelization using multiple GPUs is currently only possible by calling functions located in segmentation/unet_prediction.py directly.
+Functions for the parallelization end with '_slurm' and divide the process into preprocessing, prediction, and segmentation.
+"""
 
 def main():
     from flamingo_tools.segmentation import run_unet_prediction