@@ -213,6 +213,7 @@ def _downsample_itk_gaussian(
213213 ngff_image : NgffImage , default_chunks , out_chunks , scale_factors
214214):
215215 import itk
216+ from itkwasm_downsample import gaussian_kernel_radius
216217
217218 # Optionally run accelerated smoothing with itk-vkfft
218219 if "VkFFTBackend" in dir (itk ):
@@ -237,80 +238,57 @@ def _downsample_itk_gaussian(
237238 )
238239 previous_image = _align_chunks (previous_image , default_chunks , dim_factors )
239240
240- shrink_factors = [ dim_factors [ sd ] for sd in spatial_dims ]
241-
242- # Compute metadata for region splitting
241+ translation , scale = _next_scale_metadata (
242+ previous_image , dim_factors , spatial_dims
243+ )
243244
244245 # Blocks 0, ..., N-2 have the same shape
245246 block_0_input = _get_block (previous_image , 0 )
247+ next_block_0_shape = _next_block_shape (
248+ previous_image , dim_factors , spatial_dims , block_0_input
249+ )
250+ block_0_size = []
251+ for dim in spatial_dims :
252+ if dim in previous_image .dims :
253+ block_0_size .append (block_0_input .shape [previous_image .dims .index (dim )])
254+ else :
255+ block_0_size .append (1 )
256+ block_0_size .reverse ()
257+
246258 # Block N-1 may be smaller than preceding blocks
247259 block_neg1_input = _get_block (previous_image , - 1 )
248-
249- # Compute overlap for Gaussian blurring for all blocks
250- block_0_image = itk .image_from_array (np .ones_like (block_0_input ))
251- input_spacing = [previous_image .scale [d ] for d in spatial_dims ]
252- block_0_image .SetSpacing (input_spacing )
253- input_origin = [previous_image .translation [d ] for d in spatial_dims ]
254- block_0_image .SetOrigin (input_origin )
260+ next_block_neg1_shape = _next_block_shape (
261+ previous_image , dim_factors , spatial_dims , block_neg1_input
262+ )
255263
256264 # pixel units
265+ # Compute metadata for region splitting
266+ shrink_factors = [dim_factors [sd ] for sd in spatial_dims ]
257267 sigma_values = _compute_sigma (shrink_factors )
258- kernel_radius = _compute_itk_gaussian_kernel_radius (
259- itk . size ( block_0_image ), sigma_values
260- )
268+ kernel_radius = gaussian_kernel_radius ( size = block_0_size , sigma = sigma_values )
269+
270+ dtype = block_0_input . dtype
261271
262- # Compute output size and spatial metadata for blocks 0, .., N-2
263- filt = itk .BinShrinkImageFilter .New (
264- block_0_image , shrink_factors = shrink_factors
265- )
266- filt .UpdateOutputInformation ()
267- block_output = filt .GetOutput ()
268- block_0_output_spacing = block_output .GetSpacing ()
269- block_0_output_origin = block_output .GetOrigin ()
270-
271- scale = {_image_dims [i ]: s for (i , s ) in enumerate (block_0_output_spacing )}
272- translation = {_image_dims [i ]: s for (i , s ) in enumerate (block_0_output_origin )}
273- dtype = block_output .dtype
274-
275- computed_size = [
276- int (block_len / shrink_factor )
277- for block_len , shrink_factor in zip (itk .size (block_0_image ), shrink_factors )
278- ]
279- assert all (
280- itk .size (block_output )[dim ] == computed_size [dim ]
281- for dim in range (block_output .ndim )
282- )
283272 output_chunks = list (previous_image .data .chunks )
284- if "t" in previous_image .dims :
285- dims = list (previous_image .dims )
286- t_index = dims .index ("t" )
287- output_chunks .pop (t_index )
273+ output_chunks_start = 0
274+ while previous_image .dims [output_chunks_start ] not in _spatial_dims :
275+ output_chunks_start += 1
276+ output_chunks = output_chunks [output_chunks_start :]
277+ next_block_0_shape = next_block_0_shape [output_chunks_start :]
288278 for i , c in enumerate (output_chunks ):
289279 output_chunks [i ] = [
290- block_output . shape [i ],
280+ next_block_0_shape [i ],
291281 ] * len (c )
292- # Compute output size for block N-1
293- block_neg1_image = itk .image_from_array (np .ones_like (block_neg1_input ))
294- block_neg1_image .SetSpacing (input_spacing )
295- block_neg1_image .SetOrigin (input_origin )
296- filt .SetInput (block_neg1_image )
297- filt .UpdateOutputInformation ()
298- block_output = filt .GetOutput ()
299- computed_size = [
300- int (block_len / shrink_factor )
301- for block_len , shrink_factor in zip (
302- itk .size (block_neg1_image ), shrink_factors
303- )
304- ]
305- assert all (
306- itk .size (block_output )[dim ] == computed_size [dim ]
307- for dim in range (block_output .ndim )
308- )
282+
283+ next_block_neg1_shape = next_block_neg1_shape [output_chunks_start :]
309284 for i in range (len (output_chunks )):
310- output_chunks [i ][- 1 ] = block_output . shape [i ]
285+ output_chunks [i ][- 1 ] = next_block_neg1_shape [i ]
311286 output_chunks [i ] = tuple (output_chunks [i ])
312287 output_chunks = tuple (output_chunks )
313288
289+ if "t" in previous_image .dims :
290+ t_index = previous_image .dims .index ("t" )
291+
314292 if "t" in previous_image .dims :
315293 all_timepoints = []
316294 for timepoint in range (previous_image .data .shape [t_index ]):
0 commit comments