Skip to content

Commit a6ff2d2

Browse files
committed
Distance U-Net prediction with CPU
1 parent 9e4588a commit a6ff2d2

File tree

1 file changed

+12
-8
lines changed

1 file changed

+12
-8
lines changed

flamingo_tools/segmentation/unet_prediction.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -71,19 +71,21 @@ def prediction_impl(input_path, input_key, output_folder, model_path, scale, blo
7171
input_ = ResizedVolume(input_, shape=new_shape, order=3)
7272
image_mask = ResizedVolume(image_mask, new_shape, order=0)
7373

74-
chunks = (128, 128, 128)
75-
block_shape = chunks
76-
7774
have_cuda = torch.cuda.is_available()
78-
assert have_cuda
75+
76+
if have_cuda:
77+
chunks = (128, 128, 128)
78+
79+
if block_shape is None:
80+
block_shape = chunks if have_cuda else input_.chunks
81+
if halo is None:
82+
halo = (16, 32, 32)
7983
if have_cuda:
8084
print("Predict with GPU")
8185
gpu_ids = [0]
8286
else:
8387
print("Predict with CPU")
8488
gpu_ids = ["cpu"]
85-
if halo is None:
86-
halo = (16, 32, 32)
8789

8890
if None == mean or None == std:
8991
# Compute the global mean and standard deviation.
@@ -124,7 +126,7 @@ def postprocess(x):
124126
output = f.require_dataset(
125127
"prediction",
126128
shape=(3,) + input_.shape,
127-
chunks=(1,) + chunks,
129+
chunks=(1,) + block_shape,
128130
compression="gzip",
129131
dtype="float32",
130132
)
@@ -328,7 +330,9 @@ def run_unet_prediction_slurm(
328330
std = None
329331

330332
original_shape = prediction_impl(
331-
input_path, input_key, output_folder, model_path, scale, block_shape, halo, prediction_instances, slurm_task_id, mean=mean, std=std
333+
input_path, input_key, output_folder, model_path, scale, block_shape, halo,
334+
prediction_instances=prediction_instances, slurm_task_id=slurm_task_id,
335+
mean=mean, std=std,
332336
)
333337

334338
# does NOT need GPU, FIXME: only run on CPU

0 commit comments

Comments
 (0)