Skip to content

Commit 9d3a61f

Browse files
committed
Fixed issue with chunk reference
1 parent a6ff2d2 commit 9d3a61f

File tree

1 file changed

+2
-5
lines changed

1 file changed

+2
-5
lines changed

flamingo_tools/segmentation/unet_prediction.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -73,11 +73,8 @@ def prediction_impl(input_path, input_key, output_folder, model_path, scale, blo
7373

7474
have_cuda = torch.cuda.is_available()
7575

76-
if have_cuda:
77-
chunks = (128, 128, 128)
78-
7976
if block_shape is None:
80-
block_shape = chunks if have_cuda else input_.chunks
77+
block_shape = (128, 128, 128) if have_cuda else input_.chunks
8178
if halo is None:
8279
halo = (16, 32, 32)
8380
if have_cuda:
@@ -91,7 +88,7 @@ def prediction_impl(input_path, input_key, output_folder, model_path, scale, blo
9188
# Compute the global mean and standard deviation.
9289
n_threads = min(16, mp.cpu_count())
9390
mean, std = parallel.mean_and_std(
94-
input_, block_shape=tuple([2* i for i in input_.chunks]), n_threads=n_threads, verbose=True,
91+
input_, block_shape=block_shape, n_threads=n_threads, verbose=True,
9592
mask=image_mask
9693
)
9794
print("Mean and standard deviation computed for the full volume:")

0 commit comments

Comments
 (0)