Skip to content

Commit b4e5173

Browse files
oestebanjhlegarreta
authored andcommitted
enh: add decimator and gaussian filter
1 parent 6b6ba70 commit b4e5173

File tree

2 files changed

+215
-1
lines changed

2 files changed

+215
-1
lines changed

src/nifreeze/data/filtering.py

Lines changed: 147 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,12 @@
2424

2525
from __future__ import annotations
2626

27+
from numbers import Number
28+
2729
import numpy as np
28-
from scipy.ndimage import median_filter
30+
from nibabel import Nifti1Image, load
31+
from scipy.ndimage import gaussian_filter as _gs
32+
from scipy.ndimage import map_coordinates, median_filter
2933
from skimage.morphology import ball
3034

3135
DEFAULT_DTYPE = "int16"
@@ -92,3 +96,145 @@ def advanced_clip(
9296
data = np.round(255 * data).astype(dtype)
9397

9498
return data
99+
100+
101+
def gaussian_filter(
102+
data: np.ndarray,
103+
vox_width: float | tuple[float, float, float],
104+
) -> np.ndarray:
105+
"""
106+
Applies a Gaussian smoothing filter to a n-dimensional array.
107+
108+
This function smooths the input data using a Gaussian filter with a specified
109+
width (sigma) in voxels along each relevant dimension. It automatically
110+
handles different data dimensionalities (2D, 3D, or 4D) and ensures that
111+
smoothing is not applied along the time or orientation dimension (if present
112+
in 4D data).
113+
114+
Parameters
115+
----------
116+
data : :obj:`~numpy.ndarray`
117+
The input data array.
118+
vox_width : :obj:`float` or :obj:`tuple` of three :obj:`float`
119+
The smoothing kernel width (sigma) in voxels. If a single :obj:`float` is provided,
120+
it is applied uniformly across all spatial dimensions. Alternatively, a
121+
tuple of three floats can be provided to specify different sigma values
122+
for each spatial dimension (x, y, z).
123+
124+
Returns
125+
-------
126+
:obj:`~numpy.ndarray`
127+
The smoothed data array.
128+
129+
"""
130+
131+
data = np.squeeze(data) # Drop unused dimensions
132+
ndim = data.ndim
133+
134+
if isinstance(vox_width, Number):
135+
vox_width = tuple([vox_width] * min(3, ndim))
136+
137+
# Do not smooth across time/orientation (if present in 4D data)
138+
if ndim == 4 and len(vox_width) == 3:
139+
vox_width = (*vox_width, 0)
140+
141+
return _gs(data, vox_width)
142+
143+
144+
def decimate(
145+
in_file: str,
146+
factor: int | tuple[int, int, int],
147+
smooth: bool | tuple[int, int, int] = True,
148+
order: int = 3,
149+
nonnegative: bool = True,
150+
) -> Nifti1Image:
151+
"""
152+
Decimates a 3D or 4D Nifti image by a specified downsampling factor.
153+
154+
This function downsamples a Nifti image by averaging voxels within a user-defined
155+
factor in each spatial dimension. It optionally applies Gaussian smoothing
156+
before downsampling to reduce aliasing artifacts. The function also handles
157+
updating the affine transformation matrix to reflect the change in voxel size.
158+
159+
Parameters
160+
----------
161+
in_file : :obj:`str`
162+
Path to the input NIfTI image file.
163+
factor : :obj:`int` or :obj:`tuple`
164+
The downsampling factor. If a single integer is provided, it is applied
165+
uniformly across all spatial dimensions. Alternatively, a tuple of three
166+
integers can be provided to specify different downsampling factors for each
167+
spatial dimension (x, y, z). Values must be greater than 0.
168+
smooth : :obj:`bool` or :obj:`tuple`, optional (default=``True``)
169+
Controls application of Gaussian smoothing before downsampling. If True,
170+
a smoothing kernel size equal to the downsampling factor is applied.
171+
Alternatively, a tuple of three integers can be provided to specify
172+
different smoothing kernel sizes for each spatial dimension. Setting to
173+
False disables smoothing.
174+
order : :obj:`int`, optional (default=3)
175+
The order of the spline interpolation used for downsampling. Higher
176+
orders provide smoother results but are computationally more expensive.
177+
nonnegative : :obj:`bool`, optional (default=``True``)
178+
If True, negative values in the downsampled data are set to zero.
179+
180+
Returns
181+
-------
182+
:obj:`~nibabel.Nifti1Image`
183+
The downsampled NIfTI image object.
184+
185+
"""
186+
187+
imnii = load(in_file)
188+
data = np.squeeze(imnii.get_fdata()) # Remove unused dimensions
189+
datashape = data.shape
190+
ndim = data.ndim
191+
192+
if isinstance(factor, Number):
193+
factor = tuple([factor] * min(3, ndim))
194+
195+
if any(f <= 0 for f in factor[:3]):
196+
raise ValueError("All spatial downsampling factors must be positive.")
197+
198+
if ndim == 4 and len(factor) == 3:
199+
factor = (*factor, 0)
200+
201+
if smooth:
202+
if smooth is True:
203+
smooth = factor
204+
data = gaussian_filter(data, smooth)
205+
206+
# Create downsampled grid
207+
down_grid = np.array(
208+
np.meshgrid(
209+
*[np.arange(_s, step=int(_f) or 1) for _s, _f in zip(datashape, factor)],
210+
indexing="ij",
211+
)
212+
)
213+
new_shape = down_grid.shape[1:]
214+
215+
# Update affine transformation
216+
newaffine = imnii.affine.copy()
217+
newaffine[:3, :3] = np.array(factor[:3]) * newaffine[:3, :3]
218+
219+
# TODO: Update offset so new array is aligned with original
220+
221+
# Resample data on the new grid
222+
resampled = map_coordinates(
223+
data,
224+
down_grid.reshape((ndim, np.prod(new_shape))),
225+
order=order,
226+
mode="constant",
227+
cval=0,
228+
prefilter=True,
229+
).reshape(new_shape)
230+
231+
# Set negative values to zero (optional)
232+
if order > 2 and nonnegative:
233+
resampled[resampled < 0] = 0
234+
235+
# Create new Nifti image with updated information
236+
newnii = Nifti1Image(resampled, newaffine, imnii.header)
237+
newnii.set_sform(newaffine, code=1)
238+
newnii.set_qform(newaffine, code=1)
239+
240+
return newnii

test/test_filtering.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
2+
# vi: set ft=python sts=4 ts=4 sw=4 et:
3+
#
4+
# Copyright 2024 The NiPreps Developers <[email protected]>
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License.
8+
# You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
#
18+
# We support and encourage derived works from this project, please read
19+
# about our expectations at
20+
#
21+
# https://www.nipreps.org/community/licensing/
22+
#
23+
"""Unit tests exercising data filtering utilities."""
24+
import nibabel as nb
25+
import numpy as np
26+
27+
import pytest
28+
29+
from nifreeze.data.filtering import decimate
30+
31+
32+
@pytest.mark.parametrize(
33+
("size", "block_size"),
34+
[
35+
((20, 20, 20), (5, 5, 5),)
36+
],
37+
)
38+
def test_decimation(tmp_path, size, block_size):
39+
"""Exercise decimation."""
40+
41+
# Calculate the number of sub-blocks in each dimension
42+
num_blocks = [s // b for s, b in zip(size, block_size)]
43+
44+
# Create the empty array
45+
voxel_array = np.zeros(size, dtype=int)
46+
47+
# Fill the array with increasing values based on sub-block position
48+
current_block = 0
49+
for k in range(num_blocks[2]):
50+
for j in range(num_blocks[1]):
51+
for i in range(num_blocks[0]):
52+
voxel_array[
53+
i * block_size[0]:(i + 1) * block_size[0],
54+
j * block_size[1]:(j + 1) * block_size[1],
55+
k * block_size[2]:(k + 1) * block_size[2]
56+
] = current_block
57+
current_block += 1
58+
59+
fname = tmp_path / "test_img.nii.gz"
60+
61+
nb.Nifti1Image(voxel_array, None, None).to_filename(fname)
62+
63+
# Need to define test oracle. For now, just see if it doesn't smoke.
64+
decimate(fname, factor=2, smooth=False, order=1)
65+
66+
# out.to_filename(tmp_path / "decimated.nii.gz")
67+
68+
# import pdb; pdb.set_trace()

0 commit comments

Comments
 (0)