Skip to content

Commit d26f2c6

Browse files
committed
Small changes and documentation
1 parent 5473462 commit d26f2c6

File tree

5 files changed

+44
-42
lines changed

5 files changed

+44
-42
lines changed
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
from .unet_prediction import run_unet_prediction, run_unet_prediction_slurm
1+
from .unet_prediction import run_unet_prediction
22
from .postprocessing import filter_isolated_objects

flamingo_tools/segmentation/unet_prediction.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,11 @@
1919
from torch_em.util.prediction import predict_with_halo
2020
from tqdm import tqdm
2121

22+
"""
23+
Prediction using distance U-Net.
24+
Parallelization using multiple GPUs is currently only possible by calling functions directly.
25+
Functions for the parallelization end with '_slurm' and divide the process into preprocessing, prediction, and segmentation.
26+
"""
2227

2328
class SelectChannel(SimpleTransformationWrapper):
2429
def __init__(self, volume, channel):
@@ -281,7 +286,9 @@ def run_unet_prediction(
281286
pmap_out = os.path.join(output_folder, "predictions.zarr")
282287
segmentation_impl(pmap_out, output_folder, min_size=min_size, original_shape=original_shape)
283288

284-
def run_unet_prediction_slurm_preprocess(
289+
#---Workflow for parallel prediction using slurm---
290+
291+
def run_unet_prediction_preprocess_slurm(
285292
input_path, input_key, output_folder,
286293
):
287294
"""

scripts/convert_tif_to_n5.py

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,29 +5,36 @@
55

66

77
def main(input_path, output_path):
8-
if not os.path.isfile(input_path):
9-
sys.exit("Input file does not exist.")
8+
"""
9+
Convert tif file to n5 format.
10+
If no output_path is supplied, the output file is created in the same directory as the input.
1011
11-
if input_path.split(".")[-1] not in ["TIFF", "TIF", "tiff", "tif"]:
12-
sys.exit("Input file must be in tif format.")
12+
:param str input_path: Input tif
13+
:param str output_path: Output path for n5 format
14+
"""
15+
if not os.path.isfile(input_path):
16+
sys.exit("Input file does not exist.")
1317

14-
basename = "".join(input_path.split("/")[-1].split(".")[:-1])
15-
input_dir = input_path.split(basename)[0]
16-
input_dir = os.path.abspath(input_dir)
18+
if input_path.split(".")[-1] not in ["TIFF", "TIF", "tiff", "tif"]:
19+
sys.exit("Input file must be in tif format.")
1720

18-
if "" == output_path:
19-
output_path = os.path.join(input_dir, basename + ".n5")
20-
img = imageio.imread(input_path)
21-
pybdv.make_bdv(img, output_path)
21+
basename = "".join(input_path.split("/")[-1].split(".")[:-1])
22+
input_dir = input_path.split(basename)[0]
23+
input_dir = os.path.abspath(input_dir)
24+
25+
if "" == output_path:
26+
output_path = os.path.join(input_dir, basename + ".n5")
27+
img = imageio.imread(input_path)
28+
pybdv.make_bdv(img, output_path)
2229

2330
if __name__ == "__main__":
2431

25-
parser = argparse.ArgumentParser(
26-
description="Script to transform file from tif into n5 format.")
32+
parser = argparse.ArgumentParser(
33+
description="Script to transform file from tif into n5 format.")
2734

28-
parser.add_argument('input', type=str, help="Input file")
29-
parser.add_argument('-o', "--output", type=str, default="", help="Output file")
35+
parser.add_argument('input', type=str, help="Input file")
36+
parser.add_argument('-o', "--output", type=str, default="", help="Output file. Default: <basename>.n5")
3037

31-
args = parser.parse_args()
38+
args = parser.parse_args()
3239

33-
main(args.input, args.output)
40+
main(args.input, args.output)

scripts/extract_block.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import os
22
import argparse
33
import numpy as np
4-
import h5py
54
import z5py
65

76
"""
@@ -31,8 +30,6 @@ def main(input_file, output_dir, input_key, resolution, coords, roi_halo):
3130
if "" == output_dir:
3231
output_dir = input_dir
3332

34-
input_key = "setup0/timepoint0/s0"
35-
3633
output_file = os.path.join(output_dir, basename + "_crop" + ".n5")
3734

3835
#M_LR_000167_R, coords = '806,1042,1334', coords = (z, y, x) compared to MoBIE view

scripts/prediction/run_prediction_distance_unet.py

Lines changed: 11 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,21 @@
66

77
sys.path.append("../..")
88

9+
"""
10+
Prediction using distance U-Net.
11+
Parallelization using multiple GPUs is currently only possible by calling functions located in segmentation/unet_prediction.py directly.
12+
Functions for the parallelization end with '_slurm' and divide the process into preprocessing, prediction, and segmentation.
13+
"""
914

1015
def main():
11-
from flamingo_tools.segmentation import run_unet_prediction, run_unet_prediction_slurm
16+
from flamingo_tools.segmentation import run_unet_prediction
1217

1318
parser = argparse.ArgumentParser()
1419
parser.add_argument("-i", "--input", required=True)
1520
parser.add_argument("-o", "--output_folder", required=True)
1621
parser.add_argument("-m", "--model", required=True)
1722
parser.add_argument("-k", "--input_key", default=None)
1823
parser.add_argument("-s", "--scale", default=None, type=float, help="Downscale the image by the given factor.")
19-
parser.add_argument("-n", "--number_gpu", default=1, type=int, help="Number of GPUs to use in parallel.")
2024

2125
args = parser.parse_args()
2226

@@ -37,24 +41,11 @@ def main():
3741
block_shape = tuple([2 * ch for ch in chunks]) if have_cuda else tuple(chunks)
3842
halo = (16, 64, 64) if have_cuda else (8, 32, 32)
3943

40-
prediction_instances = args.number_gpu if have_cuda else 1
41-
42-
if 1 > prediction_instances:
43-
# FIXME: only does prediction part, no segmentation yet
44-
# FIXME: implement array job
45-
run_unet_prediction_slurm(
46-
args.input, args.input_key, args.output_folder, args.model,
47-
scale=scale,
48-
block_shape=block_shape, halo=halo,
49-
prediction_instances=prediction_instances,
50-
)
51-
else:
52-
53-
run_unet_prediction(
54-
args.input, args.input_key, args.output_folder, args.model,
55-
scale=scale, min_size=min_size,
56-
block_shape=block_shape, halo=halo,
57-
)
44+
run_unet_prediction(
45+
args.input, args.input_key, args.output_folder, args.model,
46+
scale=scale, min_size=min_size,
47+
block_shape=block_shape, halo=halo,
48+
)
5849

5950

6051
if __name__ == "__main__":

0 commit comments

Comments
 (0)