Skip to content

Commit 4394134

Browse files
committed
Fixed issue for setting chunks
1 parent dad0b31 commit 4394134

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

flamingo_tools/segmentation/unet_prediction.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,11 +60,14 @@ def prediction_impl(input_path, input_key, output_folder, model_path, scale, blo
6060

6161
if input_key is None:
6262
input_ = imageio.imread(input_path)
63+
chunks = (64, 64, 64)
6364
elif s3 is not None:
6465
with zarr.open(input_path, mode="r") as f:
6566
input_ = f[input_key]
67+
chunks = input_.chunks()
6668
else:
6769
input_ = open_file(input_path, "r")[input_key]
70+
chunks = (64, 64, 64)
6871

6972
if scale is None or scale == 1:
7073
original_shape = None
@@ -95,7 +98,7 @@ def prediction_impl(input_path, input_key, output_folder, model_path, scale, blo
9598
# Compute the global mean and standard deviation.
9699
n_threads = min(16, mp.cpu_count())
97100
mean, std = parallel.mean_and_std(
98-
input_, block_shape=tuple([2* i for i in input_.chunks]), n_threads=n_threads, verbose=True,
101+
input_, block_shape=tuple([2* i for i in chunks]), n_threads=n_threads, verbose=True,
99102
mask=image_mask
100103
)
101104
print("Mean and standard deviation computed for the full volume:")
@@ -163,7 +166,7 @@ def find_mask(input_path, input_key, output_folder, s3=None):
163166
else:
164167
fin = open_file(input_path, "r")
165168
raw = fin[input_key]
166-
chunks = raw.chunks
169+
chunks = (64, 64, 64)
167170

168171
block_shape = tuple(2 * ch for ch in chunks)
169172
blocking = nt.blocking([0, 0, 0], raw.shape, block_shape)

0 commit comments

Comments
 (0)