Skip to content

Commit 236f7d0

Browse files
committed
Prediction distance unet with multiple GPUs
1 parent 51cc1b9 commit 236f7d0

File tree

2 files changed

+61
-11
lines changed

2 files changed

+61
-11
lines changed

flamingo_tools/segmentation/unet_prediction.py

Lines changed: 57 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,13 @@ def ndim(self):
3737
return self._volume.ndim - 1
3838

3939

40-
def prediction_impl(input_path, input_key, output_folder, model_path, scale, block_shape, halo):
40+
def prediction_impl(input_path, input_key, output_folder, model_path, scale, block_shape, halo, prediction_instances=1, slurm_task_id=0):
4141
with warnings.catch_warnings():
4242
warnings.simplefilter("ignore")
4343
if os.path.isdir(model_path):
4444
model = load_model(model_path)
4545
else:
46-
model = torch.load(model_path)
46+
model = torch.load(model_path, weights_only=False)
4747

4848
mask_path = os.path.join(output_folder, "mask.zarr")
4949
image_mask = z5py.File(mask_path, "r")["mask"]
@@ -65,22 +65,24 @@ def prediction_impl(input_path, input_key, output_folder, model_path, scale, blo
6565
input_ = ResizedVolume(input_, shape=new_shape, order=3)
6666
image_mask = ResizedVolume(image_mask, new_shape, order=0)
6767

68+
chunks = (128, 128, 128)
69+
block_shape = chunks
70+
6871
have_cuda = torch.cuda.is_available()
69-
if block_shape is None:
70-
block_shape = tuple([2 * ch for ch in input_.chunks]) if have_cuda else input_.chunks
71-
if halo is None:
72-
halo = (16, 64, 64) if have_cuda else (16, 32, 32)
72+
assert have_cuda
7373
if have_cuda:
7474
print("Predict with GPU")
7575
gpu_ids = [0]
7676
else:
7777
print("Predict with CPU")
7878
gpu_ids = ["cpu"]
79+
if halo is None:
80+
halo = (16, 32, 32)
7981

8082
# Compute the global mean and standard deviation.
81-
n_threads = min(16, mp.cpu_count())
83+
n_threads = min(2, mp.cpu_count())
8284
mean, std = parallel.mean_and_std(
83-
input_, block_shape=block_shape, n_threads=n_threads, verbose=True,
85+
input_, block_shape=tuple([2* i for i in input_.chunks]), n_threads=n_threads, verbose=True,
8486
mask=image_mask
8587
)
8688
print("Mean and standard deviation computed for the full volume:")
@@ -98,12 +100,24 @@ def postprocess(x):
98100
x[1] = vigra.filters.gaussianSmoothing(x[1], sigma=2.0)
99101
return x
100102

103+
shape = input_.shape
104+
ndim = len(shape)
105+
106+
blocking = nt.blocking([0] * ndim, shape, block_shape)
107+
n_blocks = blocking.numberOfBlocks
108+
iteration_ids = []
109+
if 1 != prediction_instances:
110+
iteration_ids = [x.tolist() for x in np.array_split(list(range(n_blocks)), prediction_instances)]
111+
slurm_iteration = iteration_ids[slurm_task_id]
112+
else:
113+
slurm_iteration = list(range(n_blocks))
114+
101115
output_path = os.path.join(output_folder, "predictions.zarr")
102116
with open_file(output_path, "a") as f:
103117
output = f.require_dataset(
104118
"prediction",
105119
shape=(3,) + input_.shape,
106-
chunks=(1,) + block_shape,
120+
chunks=(1,) + chunks,
107121
compression="gzip",
108122
dtype="float32",
109123
)
@@ -113,6 +127,7 @@ def postprocess(x):
113127
gpu_ids=gpu_ids, block_shape=block_shape, halo=halo,
114128
output=output, preprocess=preprocess, postprocess=postprocess,
115129
mask=image_mask,
130+
iter_list=slurm_iteration,
116131
)
117132

118133
return original_shape
@@ -228,14 +243,45 @@ def run_unet_prediction(
228243
output_folder, model_path,
229244
min_size, scale=None,
230245
block_shape=None, halo=None,
246+
prediction_instances=1,
247+
):
248+
if prediction_instances > 1:
249+
run_unet_prediction_slurm(
250+
input_path, input_key, output_folder, model_path,
251+
scale=scale, block_shape=block_shape, halo=halo,
252+
prediction_instances=prediction_instances,
253+
)
254+
else:
255+
os.makedirs(output_folder, exist_ok=True)
256+
257+
find_mask(input_path, input_key, output_folder)
258+
259+
original_shape = prediction_impl(
260+
input_path, input_key, output_folder, model_path, scale, block_shape, halo
261+
)
262+
263+
pmap_out = os.path.join(output_folder, "predictions.zarr")
264+
segmentation_impl(pmap_out, output_folder, min_size=min_size, original_shape=original_shape)
265+
266+
def run_unet_prediction_slurm(
267+
input_path, input_key, output_folder, model_path,
268+
scale=None,
269+
block_shape=None, halo=None, prediction_instances=1,
231270
):
232271
os.makedirs(output_folder, exist_ok=True)
272+
prediction_instances = int(prediction_instances)
273+
slurm_task_id = os.environ.get("SLURM_ARRAY_TASK_ID")
274+
if slurm_task_id is not None:
275+
slurm_task_id = int(slurm_task_id)
233276

234277
find_mask(input_path, input_key, output_folder)
235278

236279
original_shape = prediction_impl(
237-
input_path, input_key, output_folder, model_path, scale, block_shape, halo
280+
input_path, input_key, output_folder, model_path, scale, block_shape, halo, prediction_instances, slurm_task_id
238281
)
239282

283+
# does NOT need GPU, FIXME: only run on CPU
284+
def run_unet_segmentation_slurm(output_folder, min_size):
285+
min_size = int(min_size)
240286
pmap_out = os.path.join(output_folder, "predictions.zarr")
241-
segmentation_impl(pmap_out, output_folder, min_size=min_size, original_shape=original_shape)
287+
segmentation_impl(pmap_out, output_folder, min_size=min_size)

scripts/prediction/run_prediction_distance_unet.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ def main():
1616
parser.add_argument("-m", "--model", required=True)
1717
parser.add_argument("-k", "--input_key", default=None)
1818
parser.add_argument("-s", "--scale", default=None, type=float, help="Downscale the image by the given factor.")
19+
parser.add_argument("-n", "--number_gpu", default=1, type=int, help="Number of GPUs to use in parallel.")
1920

2021
args = parser.parse_args()
2122

@@ -36,10 +37,13 @@ def main():
3637
block_shape = tuple([2 * ch for ch in chunks]) if have_cuda else tuple(chunks)
3738
halo = (16, 64, 64) if have_cuda else (8, 32, 32)
3839

40+
prediction_instances = args.number_gpu if have_cuda else 1
41+
3942
run_unet_prediction(
4043
args.input, args.input_key, args.output_folder, args.model,
4144
scale=scale, min_size=min_size,
4245
block_shape=block_shape, halo=halo,
46+
prediction_instances=prediction_instances,
4347
)
4448

4549

0 commit comments

Comments
 (0)