Skip to content

Commit 1d0ed5a

Browse files
authored
Merge pull request #38 from computational-cell-analytics/time_prediction_in_memory
The code is used to segment SGN and IHC crops in memory. This allows us to test and compare the network's performance to that of other segmentation options, such as StarDist, Cellpose, and Micro-Sam. Additional options were added to the distance U-Net prediction to allow the customization of the boundary distance threshold for the IHC segmentation.
2 parents bf1079c + b24546f commit 1d0ed5a

File tree

2 files changed

+173
-64
lines changed

2 files changed

+173
-64
lines changed

flamingo_tools/segmentation/unet_prediction.py

Lines changed: 111 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ def prediction_impl(
6767
slurm_task_id=0,
6868
mean=None,
6969
std=None,
70+
mask=None
7071
):
7172
"""@private
7273
"""
@@ -79,18 +80,20 @@ def prediction_impl(
7980

8081
input_ = read_image_data(input_path, input_key)
8182
chunks = getattr(input_, "chunks", (64, 64, 64))
82-
mask_path = os.path.join(output_folder, "mask.zarr")
83-
84-
if os.path.exists(mask_path):
85-
image_mask = z5py.File(mask_path, "r")["mask"]
86-
# resize mask
87-
image_shape = input_.shape
88-
mask_shape = image_mask.shape
89-
if image_shape != mask_shape:
90-
image_mask = ResizedVolume(image_mask, image_shape, order=0)
9183

84+
if output_folder is None:
85+
image_mask = mask
9286
else:
93-
image_mask = None
87+
mask_path = os.path.join(output_folder, "mask.zarr")
88+
if os.path.exists(mask_path):
89+
image_mask = z5py.File(mask_path, "r")["mask"]
90+
# resize mask
91+
image_shape = input_.shape
92+
mask_shape = image_mask.shape
93+
if image_shape != mask_shape:
94+
image_mask = ResizedVolume(image_mask, image_shape, order=0)
95+
else:
96+
image_mask = mask
9497

9598
if scale is None or np.isclose(scale, 1):
9699
original_shape = None
@@ -162,16 +165,8 @@ def postprocess(x):
162165
else:
163166
slurm_iteration = list(range(n_blocks))
164167

165-
output_path = os.path.join(output_folder, "predictions.zarr")
166-
with open_file(output_path, "a") as f:
167-
output = f.require_dataset(
168-
"prediction",
169-
shape=output_shape,
170-
chunks=output_chunks,
171-
compression="gzip",
172-
dtype="float32",
173-
)
174-
168+
if output_folder is None:
169+
output = np.zeros(output_shape, dtype=np.float32)
175170
predict_with_halo(
176171
input_, model,
177172
gpu_ids=gpu_ids, block_shape=block_shape, halo=halo,
@@ -180,10 +175,37 @@ def postprocess(x):
180175
iter_list=slurm_iteration,
181176
)
182177

183-
return original_shape
178+
else:
179+
output_path = os.path.join(output_folder, "predictions.zarr")
180+
with open_file(output_path, "a") as f:
181+
output = f.require_dataset(
182+
"prediction",
183+
shape=output_shape,
184+
chunks=output_chunks,
185+
compression="gzip",
186+
dtype="float32",
187+
)
188+
189+
predict_with_halo(
190+
input_, model,
191+
gpu_ids=gpu_ids, block_shape=block_shape, halo=halo,
192+
output=output, preprocess=preprocess, postprocess=postprocess,
193+
mask=image_mask,
194+
iter_list=slurm_iteration,
195+
)
196+
197+
if output_folder is None:
198+
return original_shape, output
199+
else:
200+
return original_shape, None
184201

185202

186-
def find_mask(input_path: str, input_key: Optional[str], output_folder: str, seg_class: Optional[str] = "sgn") -> None:
203+
def find_mask(
204+
input_path: str,
205+
input_key: Optional[str],
206+
output_folder: Optional[str],
207+
seg_class: Optional[str] = "sgn"
208+
) -> None:
187209
"""Determine the mask for running prediction.
188210
189211
The mask corresponds to data that contains actual signal and not just noise.
@@ -197,9 +219,6 @@ def find_mask(input_path: str, input_key: Optional[str], output_folder: str, seg
197219
output_folder: The output folder for storing the mask data.
198220
seg_class: Specifier for exclusion criterias for mask generation.
199221
"""
200-
mask_path = os.path.join(output_folder, "mask.zarr")
201-
f = z5py.File(mask_path, "a")
202-
203222
# set parameters for the exclusion of chunks within mask generation
204223
if seg_class == "sgn":
205224
upper_percentile = 95
@@ -214,18 +233,24 @@ def find_mask(input_path: str, input_key: Optional[str], output_folder: str, seg
214233
min_intensity = 200
215234
print("Calculating mask with default values.")
216235

217-
mask_key = "mask"
218-
if mask_key in f:
219-
return
220-
221236
raw = read_image_data(input_path, input_key)
222237
chunks = getattr(raw, "chunks", (64, 64, 64))
223238

224239
block_shape = tuple(2 * ch for ch in chunks)
225240
blocking = nt.blocking([0, 0, 0], raw.shape, block_shape)
226241
n_blocks = blocking.numberOfBlocks
227242

228-
ds_mask = f.create_dataset(mask_key, shape=raw.shape, compression="gzip", dtype="uint8", chunks=block_shape)
243+
if output_folder is None:
244+
ds_mask = np.zeros(raw.shape, dtype=np.uint64)
245+
246+
else:
247+
mask_path = os.path.join(output_folder, "mask.zarr")
248+
f = z5py.File(mask_path, "a")
249+
mask_key = "mask"
250+
if mask_key in f:
251+
return
252+
253+
ds_mask = f.create_dataset(mask_key, shape=raw.shape, compression="gzip", dtype="uint8", chunks=block_shape)
229254

230255
# TODO more sophisticated criterion?!
231256
def find_mask_block(block_id):
@@ -240,15 +265,20 @@ def find_mask_block(block_id):
240265
with futures.ThreadPoolExecutor(n_threads) as tp:
241266
list(tqdm(tp.map(find_mask_block, range(n_blocks)), total=n_blocks))
242267

268+
if output_folder is None:
269+
return ds_mask
270+
else:
271+
return None
272+
243273

244274
def distance_watershed_implementation(
245275
input_path: str,
246-
output_folder: str,
247-
min_size: int,
276+
output_folder: Optional[str] = None,
277+
min_size: int = 1000,
248278
center_distance_threshold: float = 0.4,
249279
boundary_distance_threshold: Optional[float] = None,
250280
fg_threshold: float = 0.5,
251-
original_shape: Optional[Tuple[int, int, int]] = None,
281+
original_shape: Optional[Tuple[int, int, int]] = None
252282
) -> None:
253283
"""Parallel implementation of the distance-prediction based watershed.
254284
@@ -262,7 +292,10 @@ def distance_watershed_implementation(
262292
fg_threshold: The threshold applied to the foreground prediction for deriving the watershed mask.
263293
original_shape: The original shape to resize the segmentation to.
264294
"""
265-
input_ = open_file(input_path, "r")["prediction"]
295+
if isinstance(input_path, str):
296+
input_ = open_file(input_path, "r")["prediction"]
297+
else:
298+
input_ = input_path
266299

267300
# Limit the number of cores for parallelization.
268301
n_threads = min(16, mp.cpu_count())
@@ -280,13 +313,17 @@ def distance_watershed_implementation(
280313
# center_distances = SimpleTransformationWrapper(center_distances, transformation=smoothing)
281314
# boundary_distances = SimpleTransformationWrapper(boundary_distances, transformation=smoothing)
282315

283-
# Allocate an zarr array for the seeds.
284-
block_shape = center_distances.chunks
285-
seed_path = os.path.join(output_folder, "seeds.zarr")
286-
seed_file = open_file(os.path.join(seed_path), "a")
287-
seeds = seed_file.require_dataset(
288-
"seeds", shape=center_distances.shape, chunks=block_shape, compression="gzip", dtype="uint64"
289-
)
316+
# Allocate the (zarr) array for the seeds.
317+
if output_folder is None:
318+
block_shape = (20, 128, 128)
319+
seeds = np.zeros(center_distances.shape, dtype=np.uint64)
320+
else:
321+
block_shape = center_distances.chunks
322+
seed_path = os.path.join(output_folder, "seeds.zarr")
323+
seed_file = open_file(os.path.join(seed_path), "a")
324+
seeds = seed_file.require_dataset(
325+
"seeds", shape=center_distances.shape, chunks=block_shape, compression="gzip", dtype="uint64"
326+
)
290327

291328
# Compute the seed inputs:
292329
# First, threshold the center distances.
@@ -301,12 +338,15 @@ def distance_watershed_implementation(
301338
data=seed_inputs, out=seeds, block_shape=block_shape, mask=mask, verbose=True, n_threads=n_threads
302339
)
303340

304-
# Allocate the zarr array for the segmentation.
305-
seg_path = os.path.join(output_folder, "segmentation.zarr" if original_shape is None else "seg_downscaled.zarr")
306-
seg_file = open_file(seg_path, "a")
307-
seg = seg_file.create_dataset(
308-
"segmentation", shape=seeds.shape, chunks=block_shape, compression="gzip", dtype="uint64"
309-
)
341+
# Allocate the (zarr) array for the segmentation.
342+
if output_folder is None:
343+
seg = np.zeros(seeds.shape, dtype=np.uint64)
344+
else:
345+
seg_path = os.path.join(output_folder, "segmentation.zarr" if original_shape is None else "seg_downscaled.zarr")
346+
seg_file = open_file(seg_path, "a")
347+
seg = seg_file.create_dataset(
348+
"segmentation", shape=seeds.shape, chunks=block_shape, compression="gzip", dtype="uint64"
349+
)
310350

311351
# Compute the segmentation with a seeded watershed
312352
halo = (2, 8, 8)
@@ -341,6 +381,11 @@ def write_block(block_id):
341381
with futures.ThreadPoolExecutor(n_threads) as tp:
342382
tp.map(write_block, range(blocking.numberOfBlocks))
343383

384+
if output_folder is None:
385+
return seg
386+
else:
387+
return None
388+
344389

345390
def calc_mean_and_std(input_path: str, input_key: str, output_folder: str) -> None:
346391
"""Calculate mean and standard deviation of the input volume.
@@ -372,7 +417,7 @@ def calc_mean_and_std(input_path: str, input_key: str, output_folder: str) -> No
372417
def run_unet_prediction(
373418
input_path: str,
374419
input_key: Optional[str],
375-
output_folder: str,
420+
output_folder: Optional[str],
376421
model_path: str,
377422
min_size: int,
378423
scale: Optional[float] = None,
@@ -403,22 +448,33 @@ def run_unet_prediction(
403448
fg_threshold: The threshold applied to the foreground prediction for deriving the watershed mask.
404449
seg_class: Specifier for exclusion criterias for mask generation.
405450
"""
406-
os.makedirs(output_folder, exist_ok=True)
451+
if output_folder is not None:
452+
os.makedirs(output_folder, exist_ok=True)
407453

408454
if use_mask:
409-
find_mask(input_path, input_key, output_folder, seg_class=seg_class)
410-
original_shape = prediction_impl(
411-
input_path, input_key, output_folder, model_path, scale, block_shape, halo
455+
mask = find_mask(input_path, input_key, output_folder=output_folder, seg_class=seg_class)
456+
else:
457+
mask = None
458+
459+
original_shape, prediction = prediction_impl(
460+
input_path=input_path, input_key=input_key, output_folder=output_folder, model_path=model_path, scale=scale,
461+
block_shape=block_shape, halo=halo, mask=mask
412462
)
413463

414-
pmap_out = os.path.join(output_folder, "predictions.zarr")
415-
distance_watershed_implementation(
464+
if output_folder is None:
465+
pmap_out = prediction
466+
else:
467+
pmap_out = os.path.join(output_folder, "predictions.zarr")
468+
469+
segmentation = distance_watershed_implementation(
416470
pmap_out, output_folder, min_size=min_size, original_shape=original_shape,
417471
center_distance_threshold=center_distance_threshold,
418472
boundary_distance_threshold=boundary_distance_threshold,
419-
fg_threshold=fg_threshold,
473+
fg_threshold=fg_threshold
420474
)
421475

476+
return segmentation
477+
422478

423479
#
424480
# ---Workflow for parallel prediction using slurm---

scripts/prediction/run_prediction_distance_unet.py

Lines changed: 62 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,11 @@
55
prediction, and segmentation.
66
"""
77
import argparse
8+
import json
9+
import time
10+
import os
811

12+
import imageio.v3 as imageio
913
import torch
1014
import z5py
1115

@@ -19,7 +23,20 @@ def main():
1923
parser.add_argument("-m", "--model", required=True)
2024
parser.add_argument("-k", "--input_key", default=None)
2125
parser.add_argument("-s", "--scale", default=None, type=float, help="Downscale the image by the given factor.")
22-
parser.add_argument("-b", "--block_shape", default=None, type=int, nargs=3)
26+
parser.add_argument("-b", "--block_shape", default=None, type=str)
27+
parser.add_argument("--halo", default=None, type=str)
28+
parser.add_argument("--memory", action="store_true", help="Perform prediction in memory and save output as tif.")
29+
parser.add_argument("--time", action="store_true", help="Time prediction process.")
30+
parser.add_argument("--seg_class", default=None, type=str,
31+
help="Segmentation class to load parameters for masking input.")
32+
parser.add_argument("--center_distance_threshold", default=0.4, type=float,
33+
help="The threshold applied to the distance center predictions to derive seeds.")
34+
parser.add_argument("--boundary_distance_threshold", default=None, type=float,
35+
help="The threshold applied to the boundary predictions to derive seeds. \
36+
By default this is set to 'None', \
37+
in which case the boundary distances are not used for the seeds.")
38+
parser.add_argument("--fg_threshold", default=0.5, type=float,
39+
help="The threshold applied to the foreground prediction for deriving the watershed mask.")
2340

2441
args = parser.parse_args()
2542

@@ -36,21 +53,57 @@ def main():
3653
if args.block_shape is None:
3754
block_shape = (64, 256, 256) if have_cuda else (64, 64, 64)
3855
else:
39-
block_shape = tuple(args.block_shape)
40-
halo = (16, 64, 64) if have_cuda else (8, 32, 32)
56+
block_shape = tuple(json.loads(args.block_shape))
57+
4158
else:
4259
if args.block_shape is None:
4360
chunks = z5py.File(args.input, "r")[args.input_key].chunks
4461
block_shape = tuple([2 * ch for ch in chunks]) if have_cuda else tuple(chunks)
4562
else:
46-
block_shape = tuple(args.block_shape)
63+
block_shape = json.loads(args.block_shape)
64+
65+
if args.halo is None:
4766
halo = (16, 64, 64) if have_cuda else (8, 32, 32)
67+
else:
68+
halo = tuple(json.loads(args.halo))
69+
70+
if args.time:
71+
start = time.perf_counter()
72+
73+
if args.memory:
74+
segmentation = run_unet_prediction(
75+
args.input, args.input_key, output_folder=None, model_path=args.model,
76+
scale=scale, min_size=min_size,
77+
block_shape=block_shape, halo=halo,
78+
seg_class=args.seg_class,
79+
center_distance_threshold = args.center_distance_threshold,
80+
boundary_distance_threshold = args.boundary_distance_threshold,
81+
fg_threshold = args.fg_threshold,
82+
)
83+
84+
abs_path = os.path.abspath(args.input)
85+
basename = ".".join(os.path.basename(abs_path).split(".")[:-1])
86+
output_path = os.path.join(args.output_folder, basename + "_seg.tif")
87+
imageio.imwrite(output_path, segmentation, compression="zlib")
88+
timer_output = os.path.join(args.output_folder, basename + "_timer.json")
89+
90+
else:
91+
run_unet_prediction(
92+
args.input, args.input_key, output_folder=args.output_folder, model_path=args.model,
93+
scale=scale, min_size=min_size,
94+
block_shape=block_shape, halo=halo,
95+
seg_class=args.seg_class,
96+
center_distance_threshold = args.center_distance_threshold,
97+
boundary_distance_threshold = args.boundary_distance_threshold,
98+
fg_threshold = args.fg_threshold,
99+
)
100+
timer_output = os.path.join(args.output_folder, "timer.json")
48101

49-
run_unet_prediction(
50-
args.input, args.input_key, args.output_folder, args.model,
51-
scale=scale, min_size=min_size,
52-
block_shape=block_shape, halo=halo,
53-
)
102+
if args.time:
103+
duration = time.perf_counter() - start
104+
time_dict = {"total_duration[s]": duration}
105+
with open(timer_output, "w") as f:
106+
json.dump(time_dict, f, indent='\t', separators=(',', ': '))
54107

55108

56109
if __name__ == "__main__":

0 commit comments

Comments
 (0)