Skip to content

Commit 7e4a085

Browse files
jhlegarretaoesteban
andcommitted
ENH: Add framewise displacement peak detection function
Add framewise displacement peak detection function. Take advantage of the commit to document the `nifreeze.analysis.motion` module. Co-authored-by: Oscar Esteban <[email protected]>
1 parent 0fb310a commit 7e4a085

File tree

3 files changed

+103
-0
lines changed

3 files changed

+103
-0
lines changed

src/nifreeze/analysis/filtering.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
2+
# vi: set ft=python sts=4 ts=4 sw=4 et:
3+
#
4+
# Copyright The NiPreps Developers <[email protected]>
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License.
8+
# You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
#
18+
# We support and encourage derived works from this project, please read
19+
# about our expectations at
20+
#
21+
# https://www.nipreps.org/community/licensing/
22+
#
23+
"""Analysis data filtering."""
24+
25+
import numpy as np
26+
27+
28+
def normalize(x: np.ndarray):
29+
r"""Normalize data using the z-score.
30+
31+
The z-score normalization is computed as:
32+
33+
.. math::
34+
35+
z_i = \frac{x_i - \mu}{\sigma}
36+
37+
where $x_i$ is the framewise displacement at point $i$, $\mu$ is the mean
38+
of all values, $\sigma$ is the standard deviation of the values, and $z_i$
39+
is the normalized z-score.
40+
41+
Parameters
42+
----------
43+
x : :obj:`~numpy.ndarray`
44+
Data to be normalized.
45+
46+
Returns
47+
-------
48+
:obj:`~numpy.ndarray`
49+
Normalized data.
50+
"""
51+
52+
return (x - np.mean(x)) / np.std(x)

src/nifreeze/analysis/motion.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,10 @@
2020
#
2121
# https://www.nipreps.org/community/licensing/
2222
#
23+
"""Motion analysis."""
2324

2425
import numpy as np
26+
from scipy.stats import zscore
2527

2628

2729
def compute_percentage_change(
@@ -57,3 +59,33 @@ def compute_percentage_change(
5759
rel_diff[mask] = 100 * (test[mask] - reference[mask]) / reference[mask]
5860

5961
return rel_diff
62+
63+
64+
def identify_spikes(fd: np.ndarray, threshold: float = 2.0):
65+
"""Identify motion spikes in framewise displacement data.
66+
67+
Identifies high-motion frames as timepoint exceeding a given threshold value
68+
based on z-score normalized framewise displacement (FD) values.
69+
70+
Parameters
71+
----------
72+
fd : :obj:`~numpy.ndarray`
73+
Framewise displacement data.
74+
threshold : :obj:`float`, optional
75+
Threshold value to determine motion spikes.
76+
77+
Returns
78+
-------
79+
indices : :obj:`~numpy.ndarray`
80+
Indices of identified motion spikes.
81+
mask : :obj:`~numpy.ndarray`
82+
Mask of identified motion spikes.
83+
"""
84+
85+
# Normalize (z-score)
86+
fd_norm = zscore(fd)
87+
88+
mask = fd_norm > threshold
89+
indices = np.where(mask)[0]
90+
91+
return indices, mask

test/test_analysis.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
compute_z_score,
3232
identify_bland_altman_salient_data,
3333
)
34+
from nifreeze.analysis.motion import identify_spikes
3435

3536

3637
def test_compute_z_score():
@@ -141,3 +142,21 @@ def test_identify_bland_altman_salient_data():
141142

142143
assert len(salient_data[BASalientEntity.RIGHT_INDICES.value]) == top_n
143144
assert len(salient_data[BASalientEntity.RIGHT_MASK.value]) == len(_data1)
145+
146+
147+
def test_identify_spikes(request):
148+
rng = request.node.rng
149+
150+
n_samples = 450
151+
152+
fd = rng.normal(0, 5, n_samples)
153+
threshold = 2.0
154+
155+
expected_indices = np.asarray([5, 57, 85, 100, 127, 180, 191, 202, 335, 393, 409])
156+
expected_mask = np.zeros(n_samples, dtype=bool)
157+
expected_mask[expected_indices] = True
158+
159+
obtained_indices, obtained_mask = identify_spikes(fd, threshold=threshold)
160+
161+
assert np.array_equal(obtained_indices, expected_indices)
162+
assert np.array_equal(obtained_mask, expected_mask)

0 commit comments

Comments
 (0)