
Commit 8e3ef78

ENH: Simplify ITK_GAUSSIAN metadata computation
1 parent 6a62c2e

1 file changed: +35 −57 lines changed

ngff_zarr/methods/_itk.py

Lines changed: 35 additions & 57 deletions
@@ -213,6 +213,7 @@ def _downsample_itk_gaussian(
     ngff_image: NgffImage, default_chunks, out_chunks, scale_factors
 ):
     import itk
+    from itkwasm_downsample import gaussian_kernel_radius
 
     # Optionally run accelerated smoothing with itk-vkfft
     if "VkFFTBackend" in dir(itk):
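The only change in this first hunk is the new import: the kernel radius used to size the Gaussian smoothing overlap now comes from the itkwasm-downsample package instead of being derived through temporary ITK image objects. A minimal sketch of the call as it appears later in the diff (the size and sigma values here are illustrative only):

    from itkwasm_downsample import gaussian_kernel_radius

    # Illustrative block size (x, y, z order) and per-axis sigma values in pixel units.
    radius = gaussian_kernel_radius(size=[128, 128, 64], sigma=[0.74, 0.74, 0.74])
    # One kernel radius per dimension; the old code read the size from itk.size().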
@@ -237,80 +238,57 @@ def _downsample_itk_gaussian(
         )
         previous_image = _align_chunks(previous_image, default_chunks, dim_factors)
 
-        shrink_factors = [dim_factors[sd] for sd in spatial_dims]
-
-        # Compute metadata for region splitting
+        translation, scale = _next_scale_metadata(
+            previous_image, dim_factors, spatial_dims
+        )
 
         # Blocks 0, ..., N-2 have the same shape
         block_0_input = _get_block(previous_image, 0)
+        next_block_0_shape = _next_block_shape(
+            previous_image, dim_factors, spatial_dims, block_0_input
+        )
+        block_0_size = []
+        for dim in spatial_dims:
+            if dim in previous_image.dims:
+                block_0_size.append(block_0_input.shape[previous_image.dims.index(dim)])
+            else:
+                block_0_size.append(1)
+        block_0_size.reverse()
+
         # Block N-1 may be smaller than preceding blocks
         block_neg1_input = _get_block(previous_image, -1)
-
-        # Compute overlap for Gaussian blurring for all blocks
-        block_0_image = itk.image_from_array(np.ones_like(block_0_input))
-        input_spacing = [previous_image.scale[d] for d in spatial_dims]
-        block_0_image.SetSpacing(input_spacing)
-        input_origin = [previous_image.translation[d] for d in spatial_dims]
-        block_0_image.SetOrigin(input_origin)
+        next_block_neg1_shape = _next_block_shape(
+            previous_image, dim_factors, spatial_dims, block_neg1_input
+        )
 
         # pixel units
+        # Compute metadata for region splitting
+        shrink_factors = [dim_factors[sd] for sd in spatial_dims]
         sigma_values = _compute_sigma(shrink_factors)
-        kernel_radius = _compute_itk_gaussian_kernel_radius(
-            itk.size(block_0_image), sigma_values
-        )
+        kernel_radius = gaussian_kernel_radius(size=block_0_size, sigma=sigma_values)
+
+        dtype = block_0_input.dtype
 
-        # Compute output size and spatial metadata for blocks 0, .., N-2
-        filt = itk.BinShrinkImageFilter.New(
-            block_0_image, shrink_factors=shrink_factors
-        )
-        filt.UpdateOutputInformation()
-        block_output = filt.GetOutput()
-        block_0_output_spacing = block_output.GetSpacing()
-        block_0_output_origin = block_output.GetOrigin()
-
-        scale = {_image_dims[i]: s for (i, s) in enumerate(block_0_output_spacing)}
-        translation = {_image_dims[i]: s for (i, s) in enumerate(block_0_output_origin)}
-        dtype = block_output.dtype
-
-        computed_size = [
-            int(block_len / shrink_factor)
-            for block_len, shrink_factor in zip(itk.size(block_0_image), shrink_factors)
-        ]
-        assert all(
-            itk.size(block_output)[dim] == computed_size[dim]
-            for dim in range(block_output.ndim)
-        )
         output_chunks = list(previous_image.data.chunks)
-        if "t" in previous_image.dims:
-            dims = list(previous_image.dims)
-            t_index = dims.index("t")
-            output_chunks.pop(t_index)
+        output_chunks_start = 0
+        while previous_image.dims[output_chunks_start] not in _spatial_dims:
+            output_chunks_start += 1
+        output_chunks = output_chunks[output_chunks_start:]
+        next_block_0_shape = next_block_0_shape[output_chunks_start:]
         for i, c in enumerate(output_chunks):
             output_chunks[i] = [
-                block_output.shape[i],
+                next_block_0_shape[i],
             ] * len(c)
-        # Compute output size for block N-1
-        block_neg1_image = itk.image_from_array(np.ones_like(block_neg1_input))
-        block_neg1_image.SetSpacing(input_spacing)
-        block_neg1_image.SetOrigin(input_origin)
-        filt.SetInput(block_neg1_image)
-        filt.UpdateOutputInformation()
-        block_output = filt.GetOutput()
-        computed_size = [
-            int(block_len / shrink_factor)
-            for block_len, shrink_factor in zip(
-                itk.size(block_neg1_image), shrink_factors
-            )
-        ]
-        assert all(
-            itk.size(block_output)[dim] == computed_size[dim]
-            for dim in range(block_output.ndim)
-        )
+
+        next_block_neg1_shape = next_block_neg1_shape[output_chunks_start:]
         for i in range(len(output_chunks)):
-            output_chunks[i][-1] = block_output.shape[i]
+            output_chunks[i][-1] = next_block_neg1_shape[i]
             output_chunks[i] = tuple(output_chunks[i])
         output_chunks = tuple(output_chunks)
 
+        if "t" in previous_image.dims:
+            t_index = previous_image.dims.index("t")
+
         if "t" in previous_image.dims:
             all_timepoints = []
             for timepoint in range(previous_image.data.shape[t_index]):
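The larger hunk removes the BinShrinkImageFilter round-trips that were only run to read back output spacing, origin, and size; `_next_scale_metadata` and `_next_block_shape` now supply those values directly from the shrink factors. A minimal sketch of the arithmetic this relies on, assuming the ITK shrink conventions the old code depended on (the helper names below are hypothetical, not the repository's implementations):

    def next_scale_metadata_sketch(scale, translation, dim_factors, spatial_dims):
        # Output spacing grows by the shrink factor; the origin shifts by half of the
        # added spacing so the first output sample stays centered on its input block,
        # which is what BinShrinkImageFilter reported via UpdateOutputInformation().
        next_scale = {d: scale[d] * dim_factors[d] for d in spatial_dims}
        next_translation = {
            d: translation[d] + 0.5 * (dim_factors[d] - 1) * scale[d]
            for d in spatial_dims
        }
        return next_translation, next_scale

    def next_block_shape_sketch(block_shape, dims, dim_factors):
        # Each spatial axis shrinks by integer division -- the same relation the
        # removed assertions checked against itk.size(block_output).
        return tuple(
            length // dim_factors[dim] if dim in dim_factors else length
            for length, dim in zip(block_shape, dims)
        )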

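For reference, the sigma values passed to gaussian_kernel_radius are expressed in pixel units. A common anti-aliasing choice for a shrink factor k, and as far as I can tell the rule `_compute_sigma` follows, is sigma = sqrt(k^2 - 1) / (2 * sqrt(2 * ln 2)); a small illustrative sketch under that assumption:

    import math

    def compute_sigma_sketch(shrink_factors):
        # sigma = sqrt(k^2 - 1) / (2 * sqrt(2 * ln 2)), per axis, in pixel units.
        denom = 2 * math.sqrt(2 * math.log(2))
        return [math.sqrt(k**2 - 1) / denom for k in shrink_factors]

    # Example: shrinking by 2 along each spatial axis.
    print(compute_sigma_sketch([2, 2, 2]))  # roughly [0.735, 0.735, 0.735]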