Skip to content

Commit 4ffe15e

Browse files
Fix tests
1 parent 326bc96 commit 4ffe15e

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

flamingo_tools/segmentation/unet_prediction.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -149,17 +149,16 @@ def postprocess(x):
149149
else:
150150
postprocess = None if output_channels > 1 else lambda x: x.squeeze()
151151

152+
gpu_ids, block_shape, halo = _get_device_and_tiling(block_shape, halo, input_)
153+
shape = input_.shape
154+
ndim = len(shape)
152155
if output_channels > 1:
153156
output_shape = (output_channels,) + input_.shape
154157
output_chunks = (1,) + block_shape
155158
else:
156159
output_shape = input_.shape
157160
output_chunks = block_shape
158161

159-
shape = input_.shape
160-
ndim = len(shape)
161-
gpu_ids, block_shape, halo = _get_device_and_tiling(block_shape, halo, input_)
162-
163162
blocking = nt.blocking([0] * ndim, shape, block_shape)
164163
n_blocks = blocking.numberOfBlocks
165164
if prediction_instances != 1:

0 commit comments

Comments
 (0)