Skip to content

Commit a4bea6b

Browse files
authored
Merge pull request #71 from jhlegarreta/RefactorFilteringCode
REF: Refactor filters into ``filtering``
2 parents e964b96 + 770ba3d commit a4bea6b

File tree

6 files changed

+353
-43
lines changed

6 files changed

+353
-43
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,7 @@ env = "PYTHONHASHSEED=0"
221221
markers = [
222222
"random_gtab_data: Custom marker for random gtab data tests",
223223
"random_dwi_data: Custom marker for random dwi data tests",
224+
"random_uniform_4d_data: Custom marker for random 4d data tests",
224225
]
225226
filterwarnings = [
226227
"ignore::DeprecationWarning",

src/nifreeze/data/filtering.py

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,12 @@
2828
from scipy.ndimage import median_filter
2929
from skimage.morphology import ball
3030

31+
from nifreeze.data.dmri import DEFAULT_CLIP_PERCENTILE
32+
3133
DEFAULT_DTYPE = "int16"
3234
"""The default image's data type."""
35+
BVAL_ATOL = 100.0
36+
"""b-value tolerance value."""
3337

3438

3539
def advanced_clip(
@@ -96,3 +100,161 @@ def advanced_clip(
96100
data = np.round(255 * data).astype(dtype)
97101

98102
return data
103+
104+
105+
def robust_minmax_normalization(
106+
data: np.ndarray,
107+
mask: np.ndarray | None = None,
108+
p_min: float = 5.0,
109+
p_max: float = 95.0,
110+
inplace: bool = False,
111+
) -> np.ndarray | None:
112+
r"""Normalize min-max percentiles of each volume to the grand min-max
113+
percentiles.
114+
115+
Robust min/max normalization of the volumes in the dataset following:
116+
117+
.. math::
118+
\text{data}_{\text{normalized}} = \frac{(\text{data} - p_{min}) \cdot p_{\text{mean}}}{p_{\text{range}}} + p_{min}^{\text{mean}}
119+
120+
where
121+
122+
.. math::
123+
p_{\text{range}} = p_{max} - p_{min}, \quad p_{\text{mean}} = \frac{1}{N} \sum_{i=1}^N p_{\text{range}_i}, \quad p_{min}^{\text{mean}} = \frac{1}{N} \sum_{i=1}^N p_{5_i}
124+
125+
If a mask is provided, only the data within the mask are considered.
126+
127+
Parameters
128+
----------
129+
data : :obj:`~numpy.ndarray`
130+
Data to be normalized.
131+
mask : :obj:`~numpy.ndarray`, optional
132+
Mask. If provided, only the data within the mask are considered.
133+
p_min : :obj:`float`, optional
134+
The lower percentile value for normalization.
135+
p_max : :obj:`float`, optional
136+
The upper percentile value for normalization.
137+
inplace : :obj:`bool`, optional
138+
If ``False``, the normalization is performed on the original data.
139+
140+
Returns
141+
-------
142+
data : :obj:`~numpy.ndarray` or None
143+
Normalized data or ``None`` if ``inplace`` is ``True``.
144+
"""
145+
146+
normalized = data if inplace else data.copy()
147+
148+
mask = mask if mask is not None else np.ones(data.shape[-1], dtype=bool)
149+
volumes = data[..., mask]
150+
reshape_shape = (-1, volumes.shape[-1]) if mask is None else (-1, sum(mask))
151+
reshaped_data = volumes.reshape(reshape_shape)
152+
p5 = np.percentile(reshaped_data, p_min, axis=0)
153+
p95 = np.percentile(reshaped_data, p_max, axis=0) - p5
154+
normalized[..., mask] = (volumes - p5) * p95.mean() / p95 + p5.mean()
155+
156+
if inplace:
157+
return None
158+
159+
return normalized
160+
161+
162+
def grand_mean_normalization(
163+
data: np.ndarray,
164+
mask: np.ndarray | None = None,
165+
center: float = DEFAULT_CLIP_PERCENTILE,
166+
inplace: bool = False,
167+
) -> np.ndarray | None:
168+
"""Robust grand mean normalization.
169+
170+
Regresses out global signal differences so that data are normalized and
171+
centered around a given value.
172+
173+
If a mask is provided, only the data within the mask are considered.
174+
175+
Parameters
176+
----------
177+
data : :obj:`~numpy.ndarray`
178+
Data to be normalized.
179+
mask : :obj:`~numpy.ndarray`, optional
180+
Mask. If provided, only the data within the mask are considered.
181+
center : float, optional
182+
Central value around which to normalize the data.
183+
inplace : :obj:`bool`, optional
184+
If ``False``, the normalization is performed on the original data.
185+
186+
Returns
187+
-------
188+
data : :obj:`~numpy.ndarray` or None
189+
Normalized data or ``None`` if ``inplace`` is ``True``.
190+
"""
191+
192+
normalized = data if inplace else data.copy()
193+
194+
mask = mask if mask is not None else np.ones(data.shape[-1], dtype=bool)
195+
volumes = data[..., mask]
196+
197+
centers = np.median(volumes, axis=(0, 1, 2))
198+
reference = np.percentile(centers[centers >= 1.0], center)
199+
centers[centers < 1.0] = reference
200+
drift = reference / centers
201+
normalized[..., mask] = volumes * drift
202+
203+
if inplace:
204+
return None
205+
206+
return normalized
207+
208+
209+
def dwi_select_shells(
210+
gradients: np.ndarray,
211+
index: int,
212+
atol_low: float | None = None,
213+
atol_high: float | None = None,
214+
) -> np.ndarray:
215+
"""Select DWI shells around the given index and lower and upper b-value
216+
bounds.
217+
218+
Computes a boolean mask of the DWI shells around the given index with the
219+
provided lower and upper bound b-values.
220+
221+
If ``atol_low`` and ``atol_high`` are both ``None``, the returned shell mask
222+
corresponds to the lengths of the diffusion-sensitizing gradients.
223+
224+
Parameters
225+
----------
226+
gradients : :obj:`~numpy.ndarray`
227+
Gradients.
228+
index : :obj:`int`
229+
Index of the shell data.
230+
atol_low : :obj:`float`, optional
231+
A lower bound for the b-value.
232+
atol_high : :obj:`float`, optional
233+
An upper bound for the b-value.
234+
235+
Returns
236+
-------
237+
shellmask : :obj:`~numpy.ndarray`
238+
Shell mask.
239+
"""
240+
241+
bvalues = gradients[:, -1]
242+
bcenter = bvalues[index]
243+
244+
shellmask = np.ones(len(bvalues), dtype=bool)
245+
shellmask[index] = False # Drop the held-out index
246+
247+
if atol_low is None and atol_high is None:
248+
return shellmask
249+
250+
atol_low = 0 if atol_low is None else atol_low
251+
atol_high = gradients[:, -1].max() if atol_high is None else atol_high
252+
253+
# Keep only b-values within the range defined by atol_high and atol_low
254+
shellmask[bvalues > (bcenter + atol_high)] = False
255+
shellmask[bvalues < (bcenter - atol_low)] = False
256+
257+
if not shellmask.sum():
258+
raise RuntimeError(f"Shell corresponding to index {index} (b={bcenter}) is empty.")
259+
260+
return shellmask

src/nifreeze/model/dmri.py

Lines changed: 17 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,8 @@
2727
from dipy.core.gradients import gradient_table_from_bvals_bvecs
2828
from joblib import Parallel, delayed
2929

30-
from nifreeze.data.dmri import (
31-
DEFAULT_CLIP_PERCENTILE,
32-
DTI_MIN_ORIENTATIONS,
33-
DWI,
34-
)
30+
from nifreeze.data.dmri import DTI_MIN_ORIENTATIONS, DWI
31+
from nifreeze.data.filtering import BVAL_ATOL, dwi_select_shells, grand_mean_normalization
3532
from nifreeze.model.base import BaseModel, ExpectationModel
3633

3734
S0_EPSILON = 1e-6
@@ -215,14 +212,14 @@ def fit_predict(self, index: int | None = None, **kwargs):
215212
class AverageDWIModel(ExpectationModel):
216213
"""A trivial model that returns an average DWI volume."""
217214

218-
__slots__ = ("_th_low", "_th_high", "_detrend")
215+
__slots__ = ("_atol_low", "_atol_high", "_detrend")
219216

220217
def __init__(
221218
self,
222219
dataset: DWI,
223220
stat: str = "median",
224-
th_low: float = 100.0,
225-
th_high: float = 100.0,
221+
atol_low: float = BVAL_ATOL,
222+
atol_high: float = BVAL_ATOL,
226223
detrend: bool = False,
227224
**kwargs,
228225
):
@@ -235,10 +232,10 @@ def __init__(
235232
Reference to a DWI object.
236233
stat : :obj:`str`, optional
237234
Whether the summary statistic to apply is ``"mean"`` or ``"median"``.
238-
th_low : :obj:`float`, optional
235+
atol_low : :obj:`float`, optional
239236
A lower bound for the b-value corresponding to the diffusion weighted images
240237
that will be averaged.
241-
th_high : :obj:`float`, optional
238+
atol_low : :obj:`float`, optional
242239
An upper bound for the b-value corresponding to the diffusion weighted images
243240
that will be averaged.
244241
detrend : :obj:`bool`, optional
@@ -249,8 +246,8 @@ def __init__(
249246
"""
250247
super().__init__(dataset, stat=stat, **kwargs)
251248

252-
self._th_low = th_low
253-
self._th_high = th_high
249+
self._atol_low = atol_low
250+
self._atol_high = atol_high
254251
self._detrend = detrend
255252

256253
def fit_predict(self, index: int | None = None, *_, **kwargs):
@@ -259,31 +256,22 @@ def fit_predict(self, index: int | None = None, *_, **kwargs):
259256
if index is None:
260257
raise RuntimeError(f"Model {self.__class__.__name__} does not allow locking.")
261258

262-
bvalues = self._dataset.gradients[:, -1]
263-
bcenter = bvalues[index]
264-
265-
shellmask = np.ones(len(self._dataset), dtype=bool)
266-
267-
# Keep only bvalues within the range defined by th_high and th_low
268-
shellmask[index] = False
269-
shellmask[bvalues > (bcenter + self._th_high)] = False
270-
shellmask[bvalues < (bcenter - self._th_low)] = False
271-
272-
if not shellmask.sum():
273-
raise RuntimeError(f"Shell corresponding to index {index} (b={bcenter}) is empty.")
259+
shellmask = dwi_select_shells(
260+
self._dataset.gradients,
261+
index,
262+
atol_low=self._atol_low,
263+
atol_high=self._atol_high,
264+
)
274265

275266
shelldata = self._dataset.dataobj[..., shellmask]
276267

277268
# Regress out global signal differences
278269
if self._detrend:
279-
centers = np.median(shelldata, axis=(0, 1, 2))
280-
reference = np.percentile(centers[centers >= 1.0], DEFAULT_CLIP_PERCENTILE)
281-
centers[centers < 1.0] = reference
282-
drift = reference / centers
283-
shelldata = shelldata * drift
270+
shelldata = grand_mean_normalization(shelldata, mask=None)
284271

285272
# Select the summary statistic
286273
avg_func = np.median if self._stat == "median" else np.mean
274+
287275
# Calculate the average
288276
return avg_func(shelldata, axis=-1)
289277

test/conftest.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ def random_number_generator(request):
171171
@pytest.fixture(autouse=True)
172172
def setup_random_uniform_4d_data(request):
173173
"""Automatically generate random data for tests."""
174-
marker = request.node.get_closest_marker("random_uniform_4d_data_generator")
174+
marker = request.node.get_closest_marker("random_uniform_4d_data")
175175

176176
size = (32, 32, 32, 5)
177177
a = 0.0
@@ -187,15 +187,25 @@ def setup_random_uniform_4d_data(request):
187187
def _generate_random_choices(request, values, count):
188188
rng = request.node.rng
189189

190+
values = set(values)
191+
190192
num_elements = len(values)
191193

192-
# Randomly distribute N among the given values
193-
partitions = rng.multinomial(count, np.ones(num_elements) / num_elements)
194+
if count < num_elements:
195+
raise ValueError(
196+
f"Count must be at least the number of unique values to guarantee inclusion\nProvided: {count} and {values}."
197+
)
198+
199+
# Start by assigning one of each value
200+
selected_values = list(values)
201+
202+
# Distribute remaining count: randomly distribute N among the values
203+
remaining = count - num_elements
204+
partitions = rng.multinomial(remaining, np.ones(num_elements) / num_elements)
194205

195-
# Create a list of selected values
196-
selected_values = [
197-
val for val, count in zip(values, partitions, strict=True) for _ in range(count)
198-
]
206+
# Add the remaining values according to the partitions
207+
for val, extra_count in zip(values, partitions, strict=True):
208+
selected_values.extend([val] * extra_count)
199209

200210
return sorted(selected_values)
201211

0 commit comments

Comments
 (0)