Skip to content

Commit 6cc3f1d

Browse files
authored
Merge pull request fideus-labs#165 from thewtex/metadata-simplify-3
metadata simplify 3
2 parents 572d1b1 + 8e3ef78 commit 6cc3f1d

File tree

4 files changed

+193
-197
lines changed

4 files changed

+193
-197
lines changed

ngff_zarr/methods/_itk.py

Lines changed: 69 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
_update_previous_dim_factors,
1212
_get_block,
1313
_spatial_dims,
14+
_next_scale_metadata,
15+
_next_block_shape,
1416
)
1517

1618
_image_dims: Tuple[str, str, str, str] = ("x", "y", "z", "t")
@@ -142,47 +144,47 @@ def _downsample_itk_bin_shrink(
142144
)
143145
previous_image = _align_chunks(previous_image, default_chunks, dim_factors)
144146

147+
translation, scale = _next_scale_metadata(
148+
previous_image, dim_factors, spatial_dims
149+
)
150+
151+
# Blocks 0, ..., N-2 have the same shape
152+
block_0_input = _get_block(previous_image, 0)
153+
next_block_0_shape = _next_block_shape(
154+
previous_image, dim_factors, spatial_dims, block_0_input
155+
)
156+
block_0_size = []
157+
for dim in spatial_dims:
158+
if dim in previous_image.dims:
159+
block_0_size.append(block_0_input.shape[previous_image.dims.index(dim)])
160+
else:
161+
block_0_size.append(1)
162+
block_0_size.reverse()
163+
164+
# Block N-1 may be smaller than preceding blocks
165+
block_neg1_input = _get_block(previous_image, -1)
166+
next_block_neg1_shape = _next_block_shape(
167+
previous_image, dim_factors, spatial_dims, block_neg1_input
168+
)
169+
145170
shrink_factors = [dim_factors[sd] for sd in spatial_dims]
146171

147-
block_0 = _get_block(previous_image, 0)
172+
dtype = block_0_input.dtype
148173

149-
# For consistency for now, do not utilize direction until there is standardized support for
150-
# direction cosines / orientation in OME-NGFF
151-
# block_0.attrs.pop("direction", None)
152-
if "c" in previous_image.dims:
153-
raise ValueError(
154-
"Downsampling with ITK BinShrinkImageFilter does not support channel dimension 'c'. "
155-
"Use ITK Gaussian downsampling instead."
156-
)
157-
block_input = itk.image_from_array(np.ones_like(block_0))
158-
spacing = [previous_image.scale[d] for d in spatial_dims]
159-
block_input.SetSpacing(spacing)
160-
origin = [previous_image.translation[d] for d in spatial_dims]
161-
block_input.SetOrigin(origin)
162-
filt = itk.BinShrinkImageFilter.New(block_input, shrink_factors=shrink_factors)
163-
filt.UpdateOutputInformation()
164-
block_output = filt.GetOutput()
165-
scale = {_image_dims[i]: s for (i, s) in enumerate(block_output.GetSpacing())}
166-
translation = {
167-
_image_dims[i]: s for (i, s) in enumerate(block_output.GetOrigin())
168-
}
169-
dtype = block_output.dtype
170174
output_chunks = list(previous_image.data.chunks)
175+
output_chunks_start = 0
176+
while previous_image.dims[output_chunks_start] not in _spatial_dims:
177+
output_chunks_start += 1
178+
output_chunks = output_chunks[output_chunks_start:]
179+
next_block_0_shape = next_block_0_shape[output_chunks_start:]
171180
for i, c in enumerate(output_chunks):
172181
output_chunks[i] = [
173-
block_output.shape[i],
182+
next_block_0_shape[i],
174183
] * len(c)
175184

176-
block_neg1 = _get_block(previous_image, -1)
177-
# block_neg1.attrs.pop("direction", None)
178-
block_input = itk.image_from_array(np.ones_like(block_neg1))
179-
block_input.SetSpacing(spacing)
180-
block_input.SetOrigin(origin)
181-
filt = itk.BinShrinkImageFilter.New(block_input, shrink_factors=shrink_factors)
182-
filt.UpdateOutputInformation()
183-
block_output = filt.GetOutput()
185+
next_block_neg1_shape = next_block_neg1_shape[output_chunks_start:]
184186
for i in range(len(output_chunks)):
185-
output_chunks[i][-1] = block_output.shape[i]
187+
output_chunks[i][-1] = next_block_neg1_shape[i]
186188
output_chunks[i] = tuple(output_chunks[i])
187189
output_chunks = tuple(output_chunks)
188190

@@ -211,6 +213,7 @@ def _downsample_itk_gaussian(
211213
ngff_image: NgffImage, default_chunks, out_chunks, scale_factors
212214
):
213215
import itk
216+
from itkwasm_downsample import gaussian_kernel_radius
214217

215218
# Optionally run accelerated smoothing with itk-vkfft
216219
if "VkFFTBackend" in dir(itk):
@@ -235,80 +238,57 @@ def _downsample_itk_gaussian(
235238
)
236239
previous_image = _align_chunks(previous_image, default_chunks, dim_factors)
237240

238-
shrink_factors = [dim_factors[sd] for sd in spatial_dims]
239-
240-
# Compute metadata for region splitting
241+
translation, scale = _next_scale_metadata(
242+
previous_image, dim_factors, spatial_dims
243+
)
241244

242245
# Blocks 0, ..., N-2 have the same shape
243246
block_0_input = _get_block(previous_image, 0)
247+
next_block_0_shape = _next_block_shape(
248+
previous_image, dim_factors, spatial_dims, block_0_input
249+
)
250+
block_0_size = []
251+
for dim in spatial_dims:
252+
if dim in previous_image.dims:
253+
block_0_size.append(block_0_input.shape[previous_image.dims.index(dim)])
254+
else:
255+
block_0_size.append(1)
256+
block_0_size.reverse()
257+
244258
# Block N-1 may be smaller than preceding blocks
245259
block_neg1_input = _get_block(previous_image, -1)
246-
247-
# Compute overlap for Gaussian blurring for all blocks
248-
block_0_image = itk.image_from_array(np.ones_like(block_0_input))
249-
input_spacing = [previous_image.scale[d] for d in spatial_dims]
250-
block_0_image.SetSpacing(input_spacing)
251-
input_origin = [previous_image.translation[d] for d in spatial_dims]
252-
block_0_image.SetOrigin(input_origin)
260+
next_block_neg1_shape = _next_block_shape(
261+
previous_image, dim_factors, spatial_dims, block_neg1_input
262+
)
253263

254264
# pixel units
265+
# Compute metadata for region splitting
266+
shrink_factors = [dim_factors[sd] for sd in spatial_dims]
255267
sigma_values = _compute_sigma(shrink_factors)
256-
kernel_radius = _compute_itk_gaussian_kernel_radius(
257-
itk.size(block_0_image), sigma_values
258-
)
268+
kernel_radius = gaussian_kernel_radius(size=block_0_size, sigma=sigma_values)
269+
270+
dtype = block_0_input.dtype
259271

260-
# Compute output size and spatial metadata for blocks 0, .., N-2
261-
filt = itk.BinShrinkImageFilter.New(
262-
block_0_image, shrink_factors=shrink_factors
263-
)
264-
filt.UpdateOutputInformation()
265-
block_output = filt.GetOutput()
266-
block_0_output_spacing = block_output.GetSpacing()
267-
block_0_output_origin = block_output.GetOrigin()
268-
269-
scale = {_image_dims[i]: s for (i, s) in enumerate(block_0_output_spacing)}
270-
translation = {_image_dims[i]: s for (i, s) in enumerate(block_0_output_origin)}
271-
dtype = block_output.dtype
272-
273-
computed_size = [
274-
int(block_len / shrink_factor)
275-
for block_len, shrink_factor in zip(itk.size(block_0_image), shrink_factors)
276-
]
277-
assert all(
278-
itk.size(block_output)[dim] == computed_size[dim]
279-
for dim in range(block_output.ndim)
280-
)
281272
output_chunks = list(previous_image.data.chunks)
282-
if "t" in previous_image.dims:
283-
dims = list(previous_image.dims)
284-
t_index = dims.index("t")
285-
output_chunks.pop(t_index)
273+
output_chunks_start = 0
274+
while previous_image.dims[output_chunks_start] not in _spatial_dims:
275+
output_chunks_start += 1
276+
output_chunks = output_chunks[output_chunks_start:]
277+
next_block_0_shape = next_block_0_shape[output_chunks_start:]
286278
for i, c in enumerate(output_chunks):
287279
output_chunks[i] = [
288-
block_output.shape[i],
280+
next_block_0_shape[i],
289281
] * len(c)
290-
# Compute output size for block N-1
291-
block_neg1_image = itk.image_from_array(np.ones_like(block_neg1_input))
292-
block_neg1_image.SetSpacing(input_spacing)
293-
block_neg1_image.SetOrigin(input_origin)
294-
filt.SetInput(block_neg1_image)
295-
filt.UpdateOutputInformation()
296-
block_output = filt.GetOutput()
297-
computed_size = [
298-
int(block_len / shrink_factor)
299-
for block_len, shrink_factor in zip(
300-
itk.size(block_neg1_image), shrink_factors
301-
)
302-
]
303-
assert all(
304-
itk.size(block_output)[dim] == computed_size[dim]
305-
for dim in range(block_output.ndim)
306-
)
282+
283+
next_block_neg1_shape = next_block_neg1_shape[output_chunks_start:]
307284
for i in range(len(output_chunks)):
308-
output_chunks[i][-1] = block_output.shape[i]
285+
output_chunks[i][-1] = next_block_neg1_shape[i]
309286
output_chunks[i] = tuple(output_chunks[i])
310287
output_chunks = tuple(output_chunks)
311288

289+
if "t" in previous_image.dims:
290+
t_index = previous_image.dims.index("t")
291+
312292
if "t" in previous_image.dims:
313293
all_timepoints = []
314294
for timepoint in range(previous_image.data.shape[t_index]):

ngff_zarr/methods/_itkwasm.py

Lines changed: 31 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
_get_block,
1515
_spatial_dims,
1616
_spatial_dims_last_zyx,
17+
_next_scale_metadata,
18+
_next_block_shape,
1719
)
1820

1921
_image_dims: Tuple[str, str, str, str] = ("x", "y", "z", "t")
@@ -80,8 +82,7 @@ def _itkwasm_chunk_bin_shrink(
8082
def _downsample_itkwasm(
8183
ngff_image: NgffImage, default_chunks, out_chunks, scale_factors, smoothing
8284
):
83-
import itkwasm
84-
from itkwasm_downsample import downsample_bin_shrink, gaussian_kernel_radius
85+
from itkwasm_downsample import gaussian_kernel_radius
8586

8687
multiscales = [
8788
ngff_image,
@@ -98,84 +99,61 @@ def _downsample_itkwasm(
9899
scale_factor, spatial_dims, previous_dim_factors
99100
)
100101
previous_image = _align_chunks(previous_image, default_chunks, dim_factors)
102+
101103
# Operate on a contiguous spatial block
102104
previous_image = _spatial_dims_last_zyx(previous_image)
103105
if tuple(previous_image.dims) != dims:
104106
transposed_dims = True
105107
reorder = [previous_image.dims.index(dim) for dim in dims]
106108

107-
shrink_factors = [dim_factors[sd] for sd in spatial_dims]
108-
109-
# Compute metadata for region splitting
109+
translation, scale = _next_scale_metadata(
110+
previous_image, dim_factors, spatial_dims
111+
)
110112

111113
# Blocks 0, ..., N-2 have the same shape
112114
block_0_input = _get_block(previous_image, 0)
115+
next_block_0_shape = _next_block_shape(
116+
previous_image, dim_factors, spatial_dims, block_0_input
117+
)
118+
block_0_size = []
119+
for dim in spatial_dims:
120+
if dim in previous_image.dims:
121+
block_0_size.append(block_0_input.shape[previous_image.dims.index(dim)])
122+
else:
123+
block_0_size.append(1)
124+
block_0_size.reverse()
125+
113126
# Block N-1 may be smaller than preceding blocks
114127
block_neg1_input = _get_block(previous_image, -1)
128+
next_block_neg1_shape = _next_block_shape(
129+
previous_image, dim_factors, spatial_dims, block_neg1_input
130+
)
115131

116132
# Compute overlap for Gaussian blurring for all blocks
117133
is_vector = previous_image.dims[-1] == "c"
118-
block_0_image = itkwasm.image_from_array(
119-
np.ones_like(block_0_input), is_vector=is_vector
120-
)
121-
input_spacing = [previous_image.scale[d] for d in spatial_dims]
122-
block_0_image.spacing = input_spacing
123-
input_origin = [previous_image.translation[d] for d in spatial_dims]
124-
block_0_image.origin = input_origin
125134

126135
# pixel units
136+
# Compute metadata for region splitting
137+
shrink_factors = [dim_factors[sd] for sd in spatial_dims]
127138
sigma_values = _compute_sigma(shrink_factors)
128-
kernel_radius = gaussian_kernel_radius(
129-
size=block_0_image.size, sigma=sigma_values
130-
)
139+
kernel_radius = gaussian_kernel_radius(size=block_0_size, sigma=sigma_values)
140+
141+
dtype = block_0_input.dtype
131142

132-
# Compute output size and spatial metadata for blocks 0, .., N-2
133-
block_output = downsample_bin_shrink(
134-
block_0_image, shrink_factors, information_only=False
135-
)
136-
block_0_output_spacing = block_output.spacing
137-
block_0_output_origin = block_output.origin
138-
139-
scale = {_image_dims[i]: s for (i, s) in enumerate(block_0_output_spacing)}
140-
translation = {_image_dims[i]: s for (i, s) in enumerate(block_0_output_origin)}
141-
dtype = block_output.data.dtype
142-
143-
computed_size = [
144-
int(block_len / shrink_factor)
145-
for block_len, shrink_factor in zip(block_0_image.size, shrink_factors)
146-
]
147-
assert all(
148-
block_output.size[dim] == computed_size[dim]
149-
for dim in range(len(block_output.size))
150-
)
151143
output_chunks = list(previous_image.data.chunks)
152144
output_chunks_start = 0
153145
while previous_image.dims[output_chunks_start] not in _spatial_dims:
154146
output_chunks_start += 1
155147
output_chunks = output_chunks[output_chunks_start:]
148+
next_block_0_shape = next_block_0_shape[output_chunks_start:]
156149
for i, c in enumerate(output_chunks):
157150
output_chunks[i] = [
158-
block_output.data.shape[i],
151+
next_block_0_shape[i],
159152
] * len(c)
160-
# Compute output size for block N-1
161-
block_neg1_image = itkwasm.image_from_array(
162-
np.ones_like(block_neg1_input), is_vector=is_vector
163-
)
164-
block_neg1_image.spacing = input_spacing
165-
block_neg1_image.origin = input_origin
166-
block_output = downsample_bin_shrink(
167-
block_neg1_image, shrink_factors, information_only=False
168-
)
169-
computed_size = [
170-
int(block_len / shrink_factor)
171-
for block_len, shrink_factor in zip(block_neg1_image.size, shrink_factors)
172-
]
173-
assert all(
174-
block_output.size[dim] == computed_size[dim]
175-
for dim in range(len(block_output.size))
176-
)
153+
154+
next_block_neg1_shape = next_block_neg1_shape[output_chunks_start:]
177155
for i in range(len(output_chunks)):
178-
output_chunks[i][-1] = block_output.data.shape[i]
156+
output_chunks[i][-1] = next_block_neg1_shape[i]
179157
output_chunks[i] = tuple(output_chunks[i])
180158
output_chunks = tuple(output_chunks)
181159

0 commit comments

Comments
 (0)