1111 _update_previous_dim_factors ,
1212 _get_block ,
1313 _spatial_dims ,
14+ _next_scale_metadata ,
15+ _next_block_shape ,
1416)
1517
1618_image_dims : Tuple [str , str , str , str ] = ("x" , "y" , "z" , "t" )
@@ -142,47 +144,47 @@ def _downsample_itk_bin_shrink(
142144 )
143145 previous_image = _align_chunks (previous_image , default_chunks , dim_factors )
144146
147+ translation , scale = _next_scale_metadata (
148+ previous_image , dim_factors , spatial_dims
149+ )
150+
151+ # Blocks 0, ..., N-2 have the same shape
152+ block_0_input = _get_block (previous_image , 0 )
153+ next_block_0_shape = _next_block_shape (
154+ previous_image , dim_factors , spatial_dims , block_0_input
155+ )
156+ block_0_size = []
157+ for dim in spatial_dims :
158+ if dim in previous_image .dims :
159+ block_0_size .append (block_0_input .shape [previous_image .dims .index (dim )])
160+ else :
161+ block_0_size .append (1 )
162+ block_0_size .reverse ()
163+
164+ # Block N-1 may be smaller than preceding blocks
165+ block_neg1_input = _get_block (previous_image , - 1 )
166+ next_block_neg1_shape = _next_block_shape (
167+ previous_image , dim_factors , spatial_dims , block_neg1_input
168+ )
169+
145170 shrink_factors = [dim_factors [sd ] for sd in spatial_dims ]
146171
147- block_0 = _get_block ( previous_image , 0 )
172+ dtype = block_0_input . dtype
148173
149- # For consistency for now, do not utilize direction until there is standardized support for
150- # direction cosines / orientation in OME-NGFF
151- # block_0.attrs.pop("direction", None)
152- if "c" in previous_image .dims :
153- raise ValueError (
154- "Downsampling with ITK BinShrinkImageFilter does not support channel dimension 'c'. "
155- "Use ITK Gaussian downsampling instead."
156- )
157- block_input = itk .image_from_array (np .ones_like (block_0 ))
158- spacing = [previous_image .scale [d ] for d in spatial_dims ]
159- block_input .SetSpacing (spacing )
160- origin = [previous_image .translation [d ] for d in spatial_dims ]
161- block_input .SetOrigin (origin )
162- filt = itk .BinShrinkImageFilter .New (block_input , shrink_factors = shrink_factors )
163- filt .UpdateOutputInformation ()
164- block_output = filt .GetOutput ()
165- scale = {_image_dims [i ]: s for (i , s ) in enumerate (block_output .GetSpacing ())}
166- translation = {
167- _image_dims [i ]: s for (i , s ) in enumerate (block_output .GetOrigin ())
168- }
169- dtype = block_output .dtype
170174 output_chunks = list (previous_image .data .chunks )
175+ output_chunks_start = 0
176+ while previous_image .dims [output_chunks_start ] not in _spatial_dims :
177+ output_chunks_start += 1
178+ output_chunks = output_chunks [output_chunks_start :]
179+ next_block_0_shape = next_block_0_shape [output_chunks_start :]
171180 for i , c in enumerate (output_chunks ):
172181 output_chunks [i ] = [
173- block_output . shape [i ],
182+ next_block_0_shape [i ],
174183 ] * len (c )
175184
176- block_neg1 = _get_block (previous_image , - 1 )
177- # block_neg1.attrs.pop("direction", None)
178- block_input = itk .image_from_array (np .ones_like (block_neg1 ))
179- block_input .SetSpacing (spacing )
180- block_input .SetOrigin (origin )
181- filt = itk .BinShrinkImageFilter .New (block_input , shrink_factors = shrink_factors )
182- filt .UpdateOutputInformation ()
183- block_output = filt .GetOutput ()
185+ next_block_neg1_shape = next_block_neg1_shape [output_chunks_start :]
184186 for i in range (len (output_chunks )):
185- output_chunks [i ][- 1 ] = block_output . shape [i ]
187+ output_chunks [i ][- 1 ] = next_block_neg1_shape [i ]
186188 output_chunks [i ] = tuple (output_chunks [i ])
187189 output_chunks = tuple (output_chunks )
188190
@@ -211,6 +213,7 @@ def _downsample_itk_gaussian(
211213 ngff_image : NgffImage , default_chunks , out_chunks , scale_factors
212214):
213215 import itk
216+ from itkwasm_downsample import gaussian_kernel_radius
214217
215218 # Optionally run accelerated smoothing with itk-vkfft
216219 if "VkFFTBackend" in dir (itk ):
@@ -235,80 +238,57 @@ def _downsample_itk_gaussian(
235238 )
236239 previous_image = _align_chunks (previous_image , default_chunks , dim_factors )
237240
238- shrink_factors = [ dim_factors [ sd ] for sd in spatial_dims ]
239-
240- # Compute metadata for region splitting
241+ translation , scale = _next_scale_metadata (
242+ previous_image , dim_factors , spatial_dims
243+ )
241244
242245 # Blocks 0, ..., N-2 have the same shape
243246 block_0_input = _get_block (previous_image , 0 )
247+ next_block_0_shape = _next_block_shape (
248+ previous_image , dim_factors , spatial_dims , block_0_input
249+ )
250+ block_0_size = []
251+ for dim in spatial_dims :
252+ if dim in previous_image .dims :
253+ block_0_size .append (block_0_input .shape [previous_image .dims .index (dim )])
254+ else :
255+ block_0_size .append (1 )
256+ block_0_size .reverse ()
257+
244258 # Block N-1 may be smaller than preceding blocks
245259 block_neg1_input = _get_block (previous_image , - 1 )
246-
247- # Compute overlap for Gaussian blurring for all blocks
248- block_0_image = itk .image_from_array (np .ones_like (block_0_input ))
249- input_spacing = [previous_image .scale [d ] for d in spatial_dims ]
250- block_0_image .SetSpacing (input_spacing )
251- input_origin = [previous_image .translation [d ] for d in spatial_dims ]
252- block_0_image .SetOrigin (input_origin )
260+ next_block_neg1_shape = _next_block_shape (
261+ previous_image , dim_factors , spatial_dims , block_neg1_input
262+ )
253263
254264 # pixel units
265+ # Compute metadata for region splitting
266+ shrink_factors = [dim_factors [sd ] for sd in spatial_dims ]
255267 sigma_values = _compute_sigma (shrink_factors )
256- kernel_radius = _compute_itk_gaussian_kernel_radius (
257- itk . size ( block_0_image ), sigma_values
258- )
268+ kernel_radius = gaussian_kernel_radius ( size = block_0_size , sigma = sigma_values )
269+
270+ dtype = block_0_input . dtype
259271
260- # Compute output size and spatial metadata for blocks 0, .., N-2
261- filt = itk .BinShrinkImageFilter .New (
262- block_0_image , shrink_factors = shrink_factors
263- )
264- filt .UpdateOutputInformation ()
265- block_output = filt .GetOutput ()
266- block_0_output_spacing = block_output .GetSpacing ()
267- block_0_output_origin = block_output .GetOrigin ()
268-
269- scale = {_image_dims [i ]: s for (i , s ) in enumerate (block_0_output_spacing )}
270- translation = {_image_dims [i ]: s for (i , s ) in enumerate (block_0_output_origin )}
271- dtype = block_output .dtype
272-
273- computed_size = [
274- int (block_len / shrink_factor )
275- for block_len , shrink_factor in zip (itk .size (block_0_image ), shrink_factors )
276- ]
277- assert all (
278- itk .size (block_output )[dim ] == computed_size [dim ]
279- for dim in range (block_output .ndim )
280- )
281272 output_chunks = list (previous_image .data .chunks )
282- if "t" in previous_image .dims :
283- dims = list (previous_image .dims )
284- t_index = dims .index ("t" )
285- output_chunks .pop (t_index )
273+ output_chunks_start = 0
274+ while previous_image .dims [output_chunks_start ] not in _spatial_dims :
275+ output_chunks_start += 1
276+ output_chunks = output_chunks [output_chunks_start :]
277+ next_block_0_shape = next_block_0_shape [output_chunks_start :]
286278 for i , c in enumerate (output_chunks ):
287279 output_chunks [i ] = [
288- block_output . shape [i ],
280+ next_block_0_shape [i ],
289281 ] * len (c )
290- # Compute output size for block N-1
291- block_neg1_image = itk .image_from_array (np .ones_like (block_neg1_input ))
292- block_neg1_image .SetSpacing (input_spacing )
293- block_neg1_image .SetOrigin (input_origin )
294- filt .SetInput (block_neg1_image )
295- filt .UpdateOutputInformation ()
296- block_output = filt .GetOutput ()
297- computed_size = [
298- int (block_len / shrink_factor )
299- for block_len , shrink_factor in zip (
300- itk .size (block_neg1_image ), shrink_factors
301- )
302- ]
303- assert all (
304- itk .size (block_output )[dim ] == computed_size [dim ]
305- for dim in range (block_output .ndim )
306- )
282+
283+ next_block_neg1_shape = next_block_neg1_shape [output_chunks_start :]
307284 for i in range (len (output_chunks )):
308- output_chunks [i ][- 1 ] = block_output . shape [i ]
285+ output_chunks [i ][- 1 ] = next_block_neg1_shape [i ]
309286 output_chunks [i ] = tuple (output_chunks [i ])
310287 output_chunks = tuple (output_chunks )
311288
289+ if "t" in previous_image .dims :
290+ t_index = previous_image .dims .index ("t" )
291+
312292 if "t" in previous_image .dims :
313293 all_timepoints = []
314294 for timepoint in range (previous_image .data .shape [t_index ]):
0 commit comments