Skip to content

Commit 40f0813

Browse files
committed
Fixed requirement of SLURM_ARRAY_TASK_ID
1 parent 8f2dfb1 commit 40f0813

File tree

2 files changed

+28
-23
lines changed

2 files changed

+28
-23
lines changed

flamingo_tools/segmentation/unet_prediction.py

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def prediction_impl(input_path, input_key, output_folder, model_path, scale, blo
8080
halo = (16, 32, 32)
8181

8282
# Compute the global mean and standard deviation.
83-
n_threads = min(2, mp.cpu_count())
83+
n_threads = min(16, mp.cpu_count())
8484
mean, std = parallel.mean_and_std(
8585
input_, block_shape=tuple([2* i for i in input_.chunks]), n_threads=n_threads, verbose=True,
8686
mask=image_mask
@@ -243,25 +243,17 @@ def run_unet_prediction(
243243
output_folder, model_path,
244244
min_size, scale=None,
245245
block_shape=None, halo=None,
246-
prediction_instances=1,
247246
):
248-
if prediction_instances > 1:
249-
run_unet_prediction_slurm(
250-
input_path, input_key, output_folder, model_path,
251-
scale=scale, block_shape=block_shape, halo=halo,
252-
prediction_instances=prediction_instances,
253-
)
254-
else:
255-
os.makedirs(output_folder, exist_ok=True)
247+
os.makedirs(output_folder, exist_ok=True)
256248

257-
find_mask(input_path, input_key, output_folder)
249+
find_mask(input_path, input_key, output_folder)
258250

259-
original_shape = prediction_impl(
260-
input_path, input_key, output_folder, model_path, scale, block_shape, halo
261-
)
251+
original_shape = prediction_impl(
252+
input_path, input_key, output_folder, model_path, scale, block_shape, halo
253+
)
262254

263-
pmap_out = os.path.join(output_folder, "predictions.zarr")
264-
segmentation_impl(pmap_out, output_folder, min_size=min_size, original_shape=original_shape)
255+
pmap_out = os.path.join(output_folder, "predictions.zarr")
256+
segmentation_impl(pmap_out, output_folder, min_size=min_size, original_shape=original_shape)
265257

266258
def run_unet_prediction_slurm(
267259
input_path, input_key, output_folder, model_path,
@@ -271,8 +263,11 @@ def run_unet_prediction_slurm(
271263
os.makedirs(output_folder, exist_ok=True)
272264
prediction_instances = int(prediction_instances)
273265
slurm_task_id = os.environ.get("SLURM_ARRAY_TASK_ID")
266+
274267
if slurm_task_id is not None:
275268
slurm_task_id = int(slurm_task_id)
269+
else:
270+
raise ValueError("The SLURM_ARRAY_TASK_ID is not set. Ensure that you are using the '-a' option with SBATCH.")
276271

277272
find_mask(input_path, input_key, output_folder)
278273

scripts/prediction/run_prediction_distance_unet.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99

1010
def main():
11-
from flamingo_tools.segmentation import run_unet_prediction
11+
from flamingo_tools.segmentation import run_unet_prediction, run_unet_prediction_slurm
1212

1313
parser = argparse.ArgumentParser()
1414
parser.add_argument("-i", "--input", required=True)
@@ -39,12 +39,22 @@ def main():
3939

4040
prediction_instances = args.number_gpu if have_cuda else 1
4141

42-
run_unet_prediction(
43-
args.input, args.input_key, args.output_folder, args.model,
44-
scale=scale, min_size=min_size,
45-
block_shape=block_shape, halo=halo,
46-
prediction_instances=prediction_instances,
47-
)
42+
if 1 > prediction_instances:
43+
# FIXME: only does prediction part, no segmentation yet
44+
# FIXME: implement array job
45+
run_unet_prediction_slurm(
46+
args.input, args.input_key, args.output_folder, args.model,
47+
scale=scale,
48+
block_shape=block_shape, halo=halo,
49+
prediction_instances=prediction_instances,
50+
)
51+
else:
52+
53+
run_unet_prediction(
54+
args.input, args.input_key, args.output_folder, args.model,
55+
scale=scale, min_size=min_size,
56+
block_shape=block_shape, halo=halo,
57+
)
4858

4959

5060
if __name__ == "__main__":

0 commit comments

Comments
 (0)