Skip to content

Commit 0fb310a

Browse files
authored
Merge pull request #90 from jhlegarreta/AddMotionVizHelperRoutines
ENH: Add motion visualization helper routines
2 parents 6b187f8 + d123760 commit 0fb310a

File tree

4 files changed

+464
-6
lines changed

4 files changed

+464
-6
lines changed

docs/notebooks/bold_realignment.ipynb

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,9 @@
179179
"metadata": {},
180180
"outputs": [],
181181
"source": [
182+
"from nifreeze.viz.motion_viz import plot_framewise_displacement\n",
183+
"\n",
184+
"\n",
182185
"def plot_profile(image_path, axis=None, indexing=None, cmap=\"gray\", label=None, figsize=(15, 1.7)):\n",
183186
" \"\"\"Plots a single image slice on a given axis or a new figure if axis is None.\"\"\"\n",
184187
" # Load the image\n",
@@ -262,12 +265,9 @@
262265
"\n",
263266
" # Plot the framewise displacement on the first axis\n",
264267
" fd_axis = axes[0]\n",
265-
" timepoints = np.arange(len(afni_fd)) # Assuming afni_fd and nifreeze_fd have the same length\n",
266-
" fd_axis.plot(timepoints, afni_fd, label=\"AFNI 3dVolreg FD\", color=\"blue\")\n",
267-
" fd_axis.plot(timepoints, nifreeze_fd, label=\"nifreeze FD\", color=\"orange\")\n",
268-
" fd_axis.set_ylabel(\"FD (mm)\")\n",
269-
" fd_axis.legend(loc=\"upper right\")\n",
270-
" fd_axis.set_xticks([]) # Hide x-ticks to keep x-axis clean\n",
268+
" _ = plot_framewise_displacement(\n",
269+
" afni_fd, nifreeze_fd, \"AFNI 3dVolreg FD\", \"nifreeze FD\", ax=fd_axis\n",
270+
" )\n",
271271
"\n",
272272
" # Set labels for profile plots if not provided\n",
273273
" if labels is None or isinstance(labels, str):\n",

src/nifreeze/analysis/motion.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
2+
# vi: set ft=python sts=4 ts=4 sw=4 et:
3+
#
4+
# Copyright The NiPreps Developers <[email protected]>
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License.
8+
# You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
#
18+
# We support and encourage derived works from this project, please read
19+
# about our expectations at
20+
#
21+
# https://www.nipreps.org/community/licensing/
22+
#
23+
24+
import numpy as np
25+
26+
27+
def compute_percentage_change(
28+
reference: np.ndarray,
29+
test: np.ndarray,
30+
mask: np.ndarray,
31+
) -> np.ndarray:
32+
"""Compute motion change between reference and test as a percentage.
33+
34+
If a mask is provided, the computation is only provided within the mask.
35+
Also, null values are ignored.
36+
37+
Parameters
38+
----------
39+
reference : :obj:`~numpy.ndarray`
40+
Reference imaging volume.
41+
test : :obj:`~numpy.ndarray`
42+
Test (shifted) imaging volume.
43+
mask : :obj:`~numpy.ndarray`
44+
Mask for value consideration.
45+
46+
Returns
47+
-------
48+
rel_diff : :obj:`~numpy.ndarray`
49+
Motion change between reference and test.
50+
"""
51+
52+
# Avoid divide-by-zero errors
53+
eps = 1e-5
54+
rel_diff = np.zeros_like(reference)
55+
mask = mask.copy()
56+
mask[reference <= eps] = False
57+
rel_diff[mask] = 100 * (test[mask] - reference[mask]) / reference[mask]
58+
59+
return rel_diff

src/nifreeze/viz/motion_viz.py

Lines changed: 241 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,241 @@
1+
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
2+
# vi: set ft=python sts=4 ts=4 sw=4 et:
3+
#
4+
# Copyright The NiPreps Developers <[email protected]>
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License.
8+
# You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
#
18+
# We support and encourage derived works from this project, please read
19+
# about our expectations at
20+
#
21+
# https://www.nipreps.org/community/licensing/
22+
#
23+
24+
from typing import Union
25+
26+
import matplotlib.cm as cm
27+
import matplotlib.colors as mcolors
28+
import matplotlib.pyplot as plt
29+
import numpy as np
30+
import pandas as pd
31+
from matplotlib.axes import Axes
32+
from scipy.ndimage import gaussian_filter
33+
34+
ORIENTATIONS = ["sagittal", "coronal", "axial"]
35+
36+
37+
def _extract_slice(img_data: np.ndarray, orientation: str, slice_idx: int) -> np.ndarray:
38+
"""Extract slice data from the given volume at the given orientation slice index.
39+
40+
Parameters
41+
----------
42+
img_data : :obj:`~numpy.ndarray`
43+
Image data to be sliced.
44+
orientation : :obj:`str`
45+
Orientation. Can be one of obj:`ORIENTATIONS`.
46+
slice_idx : :obj:`int`
47+
Slice index.
48+
49+
Returns
50+
-------
51+
:obj:`~numpy.ndarray`
52+
Image slice.
53+
"""
54+
55+
axis = ORIENTATIONS.index(orientation)
56+
57+
axis_sizw = img_data.shape[axis]
58+
59+
if not (0 <= slice_idx < axis_sizw):
60+
raise IndexError(
61+
f"Slice index {slice_idx} out of bounds for axis {orientation} with size {axis_sizw}"
62+
)
63+
64+
slice_obj: list[int | slice] = [slice(None)] * 3
65+
slice_obj[axis] = slice_idx
66+
slice_2d = img_data[tuple(slice_obj)]
67+
return slice_2d if axis == 2 else np.rot90(slice_2d)
68+
69+
70+
def plot_framewise_displacement(
71+
fd: pd.DataFrame,
72+
labels: list,
73+
cmap_name: str = "viridis",
74+
ax: Union[Axes, None] = None,
75+
) -> Axes:
76+
"""Plot frame-wise displacement data.
77+
78+
Plots the frame-wise displacement data corresponding to different
79+
realizations.
80+
81+
Parameters
82+
----------
83+
fd : :obj:`~pd.DataFrame`
84+
Frame-wise displacement values corresponding.
85+
labels : :obj:`list`
86+
Labels for legend.
87+
cmap_name : str, optional
88+
Colormap name.
89+
ax : :obj:`Axes`, optional
90+
Figure axes.
91+
92+
Returns
93+
-------
94+
ax : :obj:`Axes`
95+
Figure plot axis.
96+
"""
97+
98+
n_cols = len(fd.columns)
99+
n_labels = len(labels)
100+
101+
if n_cols != n_labels:
102+
raise ValueError(
103+
f"The number of realizations and labels does not match: {n_cols}; {n_labels}"
104+
)
105+
106+
if ax is None:
107+
fig, ax = plt.subplots(1, 1, figsize=(10, 6), constrained_layout=True)
108+
109+
# Plot the framewise displacement
110+
n_frames = fd.index.to_numpy()
111+
112+
cmap = cm.get_cmap(cmap_name, n_cols)
113+
colors = [mcolors.to_hex(cmap(i)) for i in range(n_cols)]
114+
115+
for i, col in enumerate(fd.columns):
116+
ax.plot(n_frames, fd[col], label=labels[i], color=colors[i])
117+
118+
ax.set_ylabel("FD (mm)")
119+
ax.legend(loc="upper right")
120+
ax.set_xticks([]) # Hide x-ticks to keep x-axis clean
121+
122+
return ax
123+
124+
125+
def plot_volumewise_motion(
126+
frames: np.ndarray,
127+
motion_params: np.ndarray,
128+
ax: np.ndarray | None = None,
129+
) -> np.ndarray:
130+
"""Plot mean volume-wise motion parameters.
131+
132+
Plots the mean translation and rotation parameters along the ``x``, `y``,
133+
and ``z`` axes.
134+
135+
Parameters
136+
----------
137+
frames : :obj:`~numpy.ndarray`
138+
Frame indices.
139+
motion_params : :obj:`~numpy.ndarray`
140+
Motion parameters.Motion parameters: translation and rotation. Each row
141+
represents one frame, and columns represent each coordinate axis ``x``,
142+
`y``, and ``z``. Translation parameters are followed by rotation
143+
parameters column-wise.
144+
ax : :obj:`~numpy.ndarray`, optional
145+
Figure axes.
146+
147+
Returns
148+
-------
149+
ax : :obj:`~numpy.ndarray`
150+
Figure plot axes array.
151+
"""
152+
153+
if ax is None:
154+
fig, ax = plt.subplots(2, 1, figsize=(10, 6), sharex=True, constrained_layout=True)
155+
156+
# Plot translations
157+
ax[0].plot(frames, motion_params[:, 0], label="x")
158+
ax[0].plot(frames, motion_params[:, 1], label="y")
159+
ax[0].plot(frames, motion_params[:, 2], label="z")
160+
ax[0].set_ylabel("Translation (mm)")
161+
ax[0].legend(loc="upper right")
162+
ax[0].set_title("Translation vs frames")
163+
164+
# Plot rotations
165+
ax[1].plot(frames, motion_params[:, 3], label="Rx")
166+
ax[1].plot(frames, motion_params[:, 4], label="Ry")
167+
ax[1].plot(frames, motion_params[:, 5], label="Rz")
168+
ax[1].set_ylabel("Rotation (deg)")
169+
ax[1].set_xlabel("Time (s)")
170+
ax[1].legend(loc="upper right")
171+
ax[1].set_title("Rotation vs frames")
172+
173+
return ax
174+
175+
176+
def plot_motion_overlay(
177+
rel_diff: np.ndarray,
178+
img_data: np.ndarray,
179+
brain_mask: np.ndarray,
180+
orientation: str,
181+
slice_idx: int,
182+
smooth: bool = True,
183+
ax: Union[Axes, None] = None,
184+
) -> Axes:
185+
"""Plot motion relative difference as an overlay on a given orientation and slice of the imaging data.
186+
187+
The values of the relative difference can optionally be smoothed using a
188+
Gaussian filter for a more appealing visual result.
189+
190+
Parameters
191+
----------
192+
rel_diff : :obj:`~numpy.ndarray`
193+
Relative motion difference.
194+
img_data : :obj:`~numpy.ndarray`
195+
Imaging data.
196+
brain_mask : :obj:`~numpy.ndarray`
197+
Brain mask.
198+
orientation : :obj:`str`
199+
Orientation. Can be one of obj:`ORIENTATIONS`.
200+
slice_idx : :obj:`int`
201+
Slice index to be plot.
202+
smooth : :obj:`bool`, optional
203+
``True`` to smooth the motion relative difference.
204+
ax : :obj:`Axes`, optional
205+
Figure axis.
206+
207+
Returns
208+
-------
209+
ax : :obj:`Axes`
210+
Figure plot axis.
211+
"""
212+
213+
# Check dimensionality
214+
if img_data.shape != rel_diff.shape:
215+
raise IndexError(
216+
f"Dimension mismatch: imaging data shape {img_data.shape}, overlay shape {rel_diff.shape}"
217+
)
218+
219+
# Smooth the relative difference
220+
smoothed_diff = rel_diff
221+
if smooth:
222+
smoothed_diff = gaussian_filter(rel_diff, sigma=1)
223+
224+
# Mask the background
225+
masked_img_data = np.where(brain_mask, img_data, np.nan)
226+
masked_smooth_diff = np.where(brain_mask, smoothed_diff, np.nan)
227+
228+
masked_img_slice = _extract_slice(masked_img_data, orientation, slice_idx)
229+
diff_img_slice = _extract_slice(masked_smooth_diff, orientation, slice_idx)
230+
231+
# Show overlay on a slice
232+
if ax is None:
233+
fig, ax = plt.subplots(1, 1, figsize=(10, 5), constrained_layout=True)
234+
235+
ax.imshow(masked_img_slice, cmap="gray")
236+
im = ax.imshow(diff_img_slice, cmap="bwr", alpha=0.5)
237+
ax.figure.colorbar(im, ax=ax, label="Relative Difference (%)")
238+
ax.set_title("Smoothed Relative Difference Overlay")
239+
ax.axis("off")
240+
241+
return ax

0 commit comments

Comments
 (0)