Commit 28d36ff

Allow prediction in memory, time process
1 parent 1d1adee commit 28d36ff

File tree

flamingo_tools/segmentation/unet_prediction.py
scripts/prediction/run_prediction_distance_unet.py

2 files changed: +159 -64 lines changed

flamingo_tools/segmentation/unet_prediction.py

Lines changed: 111 additions & 55 deletions
@@ -67,6 +67,7 @@ def prediction_impl(
     slurm_task_id=0,
     mean=None,
     std=None,
+    mask=None
 ):
     """@private
     """
@@ -79,18 +80,20 @@ def prediction_impl(
 
     input_ = read_image_data(input_path, input_key)
     chunks = getattr(input_, "chunks", (64, 64, 64))
-    mask_path = os.path.join(output_folder, "mask.zarr")
-
-    if os.path.exists(mask_path):
-        image_mask = z5py.File(mask_path, "r")["mask"]
-        # resize mask
-        image_shape = input_.shape
-        mask_shape = image_mask.shape
-        if image_shape != mask_shape:
-            image_mask = ResizedVolume(image_mask, image_shape, order=0)
 
+    if output_folder is None:
+        image_mask = mask
     else:
-        image_mask = None
+        mask_path = os.path.join(output_folder, "mask.zarr")
+        if os.path.exists(mask_path):
+            image_mask = z5py.File(mask_path, "r")["mask"]
+            # resize mask
+            image_shape = input_.shape
+            mask_shape = image_mask.shape
+            if image_shape != mask_shape:
+                image_mask = ResizedVolume(image_mask, image_shape, order=0)
+        else:
+            image_mask = mask
 
     if scale is None or np.isclose(scale, 1):
         original_shape = None
@@ -162,16 +165,8 @@ def postprocess(x):
     else:
         slurm_iteration = list(range(n_blocks))
 
-    output_path = os.path.join(output_folder, "predictions.zarr")
-    with open_file(output_path, "a") as f:
-        output = f.require_dataset(
-            "prediction",
-            shape=output_shape,
-            chunks=output_chunks,
-            compression="gzip",
-            dtype="float32",
-        )
-
+    if output_folder is None:
+        output = np.zeros(output_shape, dtype=np.float32)
         predict_with_halo(
             input_, model,
             gpu_ids=gpu_ids, block_shape=block_shape, halo=halo,
@@ -180,10 +175,37 @@ def postprocess(x):
             iter_list=slurm_iteration,
         )
 
-    return original_shape
+    else:
+        output_path = os.path.join(output_folder, "predictions.zarr")
+        with open_file(output_path, "a") as f:
+            output = f.require_dataset(
+                "prediction",
+                shape=output_shape,
+                chunks=output_chunks,
+                compression="gzip",
+                dtype="float32",
+            )
+
+        predict_with_halo(
+            input_, model,
+            gpu_ids=gpu_ids, block_shape=block_shape, halo=halo,
+            output=output, preprocess=preprocess, postprocess=postprocess,
+            mask=image_mask,
+            iter_list=slurm_iteration,
+        )
+
+    if output_folder is None:
+        return original_shape, output
+    else:
+        return original_shape, None
 
 
-def find_mask(input_path: str, input_key: Optional[str], output_folder: str, seg_class: Optional[str] = "sgn") -> None:
+def find_mask(
+    input_path: str,
+    input_key: Optional[str],
+    output_folder: Optional[str],
+    seg_class: Optional[str] = "sgn"
+) -> None:
     """Determine the mask for running prediction.
 
     The mask corresponds to data that contains actual signal and not just noise.
@@ -197,9 +219,6 @@ def find_mask(input_path: str, input_key: Optional[str], output_folder: str, seg
         output_folder: The output folder for storing the mask data.
         seg_class: Specifier for exclusion criterias for mask generation.
     """
-    mask_path = os.path.join(output_folder, "mask.zarr")
-    f = z5py.File(mask_path, "a")
-
     # set parameters for the exclusion of chunks within mask generation
     if seg_class == "sgn":
         upper_percentile = 95
@@ -214,18 +233,24 @@ def find_mask(input_path: str, input_key: Optional[str], output_folder: str, seg
         min_intensity = 200
         print("Calculating mask with default values.")
 
-    mask_key = "mask"
-    if mask_key in f:
-        return
-
     raw = read_image_data(input_path, input_key)
     chunks = getattr(raw, "chunks", (64, 64, 64))
 
     block_shape = tuple(2 * ch for ch in chunks)
     blocking = nt.blocking([0, 0, 0], raw.shape, block_shape)
     n_blocks = blocking.numberOfBlocks
 
-    ds_mask = f.create_dataset(mask_key, shape=raw.shape, compression="gzip", dtype="uint8", chunks=block_shape)
+    if output_folder is None:
+        ds_mask = np.zeros(raw.shape, dtype=np.uint64)
+
+    else:
+        mask_path = os.path.join(output_folder, "mask.zarr")
+        f = z5py.File(mask_path, "a")
+        mask_key = "mask"
+        if mask_key in f:
+            return
+
+        ds_mask = f.create_dataset(mask_key, shape=raw.shape, compression="gzip", dtype="uint8", chunks=block_shape)
 
     # TODO more sophisticated criterion?!
     def find_mask_block(block_id):
@@ -240,15 +265,20 @@ def find_mask_block(block_id):
     with futures.ThreadPoolExecutor(n_threads) as tp:
         list(tqdm(tp.map(find_mask_block, range(n_blocks)), total=n_blocks))
 
+    if output_folder is None:
+        return ds_mask
+    else:
+        return None
+
 
 def distance_watershed_implementation(
     input_path: str,
-    output_folder: str,
-    min_size: int,
+    output_folder: Optional[str] = None,
+    min_size: int = 1000,
     center_distance_threshold: float = 0.4,
     boundary_distance_threshold: Optional[float] = None,
     fg_threshold: float = 0.5,
-    original_shape: Optional[Tuple[int, int, int]] = None,
+    original_shape: Optional[Tuple[int, int, int]] = None
 ) -> None:
     """Parallel implementation of the distance-prediction based watershed.
 
@@ -262,7 +292,10 @@ def distance_watershed_implementation(
         fg_threshold: The threshold applied to the foreground prediction for deriving the watershed mask.
         original_shape: The original shape to resize the segmentation to.
     """
-    input_ = open_file(input_path, "r")["prediction"]
+    if isinstance(input_path, str):
+        input_ = open_file(input_path, "r")["prediction"]
+    else:
+        input_ = input_path
 
     # Limit the number of cores for parallelization.
     n_threads = min(16, mp.cpu_count())
@@ -280,13 +313,17 @@ def distance_watershed_implementation(
     # center_distances = SimpleTransformationWrapper(center_distances, transformation=smoothing)
     # boundary_distances = SimpleTransformationWrapper(boundary_distances, transformation=smoothing)
 
-    # Allocate an zarr array for the seeds.
-    block_shape = center_distances.chunks
-    seed_path = os.path.join(output_folder, "seeds.zarr")
-    seed_file = open_file(os.path.join(seed_path), "a")
-    seeds = seed_file.require_dataset(
-        "seeds", shape=center_distances.shape, chunks=block_shape, compression="gzip", dtype="uint64"
-    )
+    # Allocate the (zarr) array for the seeds.
+    if output_folder is None:
+        block_shape = (20, 128, 128)
+        seeds = np.zeros(center_distances.shape, dtype=np.uint64)
+    else:
+        block_shape = center_distances.chunks
+        seed_path = os.path.join(output_folder, "seeds.zarr")
+        seed_file = open_file(os.path.join(seed_path), "a")
+        seeds = seed_file.require_dataset(
+            "seeds", shape=center_distances.shape, chunks=block_shape, compression="gzip", dtype="uint64"
+        )
 
     # Compute the seed inputs:
     # First, threshold the center distances.
@@ -301,12 +338,15 @@ def distance_watershed_implementation(
         data=seed_inputs, out=seeds, block_shape=block_shape, mask=mask, verbose=True, n_threads=n_threads
     )
 
-    # Allocate the zarr array for the segmentation.
-    seg_path = os.path.join(output_folder, "segmentation.zarr" if original_shape is None else "seg_downscaled.zarr")
-    seg_file = open_file(seg_path, "a")
-    seg = seg_file.create_dataset(
-        "segmentation", shape=seeds.shape, chunks=block_shape, compression="gzip", dtype="uint64"
-    )
+    # Allocate the (zarr) array for the segmentation.
+    if output_folder is None:
+        seg = np.zeros(seeds.shape, dtype=np.uint64)
+    else:
+        seg_path = os.path.join(output_folder, "segmentation.zarr" if original_shape is None else "seg_downscaled.zarr")
+        seg_file = open_file(seg_path, "a")
+        seg = seg_file.create_dataset(
+            "segmentation", shape=seeds.shape, chunks=block_shape, compression="gzip", dtype="uint64"
+        )
 
     # Compute the segmentation with a seeded watershed
     halo = (2, 8, 8)
@@ -341,6 +381,11 @@ def write_block(block_id):
     with futures.ThreadPoolExecutor(n_threads) as tp:
         tp.map(write_block, range(blocking.numberOfBlocks))
 
+    if output_folder is None:
+        return seg
+    else:
+        return None
+
 
 def calc_mean_and_std(input_path: str, input_key: str, output_folder: str) -> None:
     """Calculate mean and standard deviation of the input volume.
@@ -372,7 +417,7 @@ def calc_mean_and_std(input_path: str, input_key: str, output_folder: str) -> No
 def run_unet_prediction(
     input_path: str,
     input_key: Optional[str],
-    output_folder: str,
+    output_folder: Optional[str],
     model_path: str,
     min_size: int,
     scale: Optional[float] = None,
@@ -403,22 +448,33 @@ def run_unet_prediction(
         fg_threshold: The threshold applied to the foreground prediction for deriving the watershed mask.
        seg_class: Specifier for exclusion criterias for mask generation.
     """
-    os.makedirs(output_folder, exist_ok=True)
+    if output_folder is not None:
+        os.makedirs(output_folder, exist_ok=True)
 
     if use_mask:
-        find_mask(input_path, input_key, output_folder, seg_class=seg_class)
-    original_shape = prediction_impl(
-        input_path, input_key, output_folder, model_path, scale, block_shape, halo
+        mask = find_mask(input_path, input_key, output_folder=output_folder, seg_class=seg_class)
+    else:
+        mask = None
+
+    original_shape, prediction = prediction_impl(
+        input_path=input_path, input_key=input_key, output_folder=output_folder, model_path=model_path, scale=scale,
+        block_shape=block_shape, halo=halo, mask=mask
     )
 
-    pmap_out = os.path.join(output_folder, "predictions.zarr")
-    distance_watershed_implementation(
+    if output_folder is None:
+        pmap_out = prediction
+    else:
+        pmap_out = os.path.join(output_folder, "predictions.zarr")
+
+    segmentation = distance_watershed_implementation(
         pmap_out, output_folder, min_size=min_size, original_shape=original_shape,
         center_distance_threshold=center_distance_threshold,
        boundary_distance_threshold=boundary_distance_threshold,
-        fg_threshold=fg_threshold,
+        fg_threshold=fg_threshold
    )
 
+    return segmentation
+
 
 #
 # ---Workflow for parallel prediction using slurm---
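
With this commit, run_unet_prediction can execute the whole pipeline in memory: passing output_folder=None makes find_mask, prediction_impl, and distance_watershed_implementation operate on numpy arrays and return the segmentation instead of writing zarr datasets. A minimal usage sketch, assuming the module is importable as flamingo_tools.segmentation.unet_prediction; the input and model paths are placeholders:

from flamingo_tools.segmentation.unet_prediction import run_unet_prediction

# In-memory mode: with output_folder=None the mask, prediction, seeds, and
# segmentation are held as numpy arrays instead of zarr datasets on disk.
segmentation = run_unet_prediction(
    "/data/cochlea_volume.tif",  # placeholder input path
    None,                        # input key; None for tif input
    output_folder=None,          # triggers the new in-memory code path
    model_path="/models/distance_unet",  # placeholder model path
    min_size=1000,
    block_shape=(64, 256, 256),
    halo=(16, 64, 64),
)
print(segmentation.shape, segmentation.dtype)  # uint64 label volume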

scripts/prediction/run_prediction_distance_unet.py

Lines changed: 48 additions & 9 deletions
@@ -5,7 +5,11 @@
 prediction, and segmentation.
 """
 import argparse
+import json
+import time
+import os
 
+import imageio.v3 as imageio
 import torch
 import z5py
 
@@ -19,7 +23,12 @@ def main():
     parser.add_argument("-m", "--model", required=True)
     parser.add_argument("-k", "--input_key", default=None)
     parser.add_argument("-s", "--scale", default=None, type=float, help="Downscale the image by the given factor.")
-    parser.add_argument("-b", "--block_shape", default=None, type=int, nargs=3)
+    parser.add_argument("-b", "--block_shape", default=None, type=str)
+    parser.add_argument("--halo", default=None, type=str)
+    parser.add_argument("--memory", action="store_true", help="Perform prediction in memory and save output as tif.")
+    parser.add_argument("--time", action="store_true", help="Time prediction process.")
+    parser.add_argument("--seg_class", default=None, type=str,
+                        help="Segmentation class to load parameters for masking input.")
 
     args = parser.parse_args()
 
@@ -36,21 +45,51 @@ def main():
         if args.block_shape is None:
             block_shape = (64, 256, 256) if have_cuda else (64, 64, 64)
         else:
-            block_shape = tuple(args.block_shape)
-        halo = (16, 64, 64) if have_cuda else (8, 32, 32)
+            block_shape = tuple(json.loads(args.block_shape))
+
     else:
         if args.block_shape is None:
             chunks = z5py.File(args.input, "r")[args.input_key].chunks
             block_shape = tuple([2 * ch for ch in chunks]) if have_cuda else tuple(chunks)
         else:
-            block_shape = tuple(args.block_shape)
+            block_shape = json.loads(args.block_shape)
+
+    if args.halo is None:
         halo = (16, 64, 64) if have_cuda else (8, 32, 32)
+    else:
+        halo = tuple(json.loads(args.halo))
+
+    if args.time:
+        start = time.perf_counter()
+
+    if args.memory:
+        segmentation = run_unet_prediction(
+            args.input, args.input_key, output_folder=None, model_path=args.model,
+            scale=scale, min_size=min_size,
+            block_shape=block_shape, halo=halo,
+            seg_class=args.seg_class,
+        )
+
+        abs_path = os.path.abspath(args.input)
+        basename = ".".join(os.path.basename(abs_path).split(".")[:-1])
+        output_path = os.path.join(args.output_folder, basename + "_seg.tif")
+        imageio.imwrite(output_path, segmentation, compression="zlib")
+        timer_output = os.path.join(args.output_folder, basename + "_timer.json")
+
+    else:
+        run_unet_prediction(
+            args.input, args.input_key, output_folder=args.output_folder, model_path=args.model,
+            scale=scale, min_size=min_size,
+            block_shape=block_shape, halo=halo,
+            seg_class=args.seg_class,
+        )
+        timer_output = os.path.join(args.output_folder, "timer.json")
 
-    run_unet_prediction(
-        args.input, args.input_key, args.output_folder, args.model,
-        scale=scale, min_size=min_size,
-        block_shape=block_shape, halo=halo,
-    )
+    if args.time:
+        duration = time.perf_counter() - start
+        time_dict = {"total_duration[s]": duration}
+        with open(timer_output, "w") as f:
+            json.dump(time_dict, f, indent='\t', separators=(',', ': '))
 
 
 if __name__ == "__main__":
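
Since --block_shape and --halo are now parsed with json.loads instead of nargs=3 integers, they are passed as JSON array literals on the command line, e.g. --block_shape "[64, 256, 256]" --halo "[16, 64, 64]". A small sketch of the parsing and of the timer file that --time produces (values are illustrative):

import json
import time

# Command-line values arrive as JSON strings:
block_shape = tuple(json.loads("[64, 256, 256]"))  # -> (64, 256, 256)
halo = tuple(json.loads("[16, 64, 64]"))           # -> (16, 64, 64)

# --time wraps the run in a wall-clock measurement ...
start = time.perf_counter()
# ... run_unet_prediction(...) would be called here ...
duration = time.perf_counter() - start

# ... and writes it as JSON, e.g. {"total_duration[s]": 42.0}
with open("timer.json", "w") as f:
    json.dump({"total_duration[s]": duration}, f, indent='\t', separators=(',', ': '))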
