Skip to content

Commit 359b03b

Browse files
authored
Merge pull request #325 from nipreps/codex/analyze-and-report-gradient-data-update-points
ENH: Adopt row-major gradient tables
2 parents e15d600 + 9dbe77a commit 359b03b

File tree

16 files changed

+164
-96
lines changed

16 files changed

+164
-96
lines changed

docs/notebooks/data_structures.ipynb

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@
9393
}
9494
],
9595
"source": [
96-
"plot_gradients(dwi.gradients.T);"
96+
"plot_gradients(dwi.gradients);"
9797
]
9898
},
9999
{
@@ -192,7 +192,7 @@
192192
}
193193
],
194194
"source": [
195-
"plot_gradients(dwi.gradients.T);"
195+
"plot_gradients(dwi.gradients);"
196196
]
197197
},
198198
{
@@ -213,7 +213,7 @@
213213
],
214214
"source": [
215215
"# Select a b-value\n",
216-
"b2000_gradientmask = dwi.gradients[-1, ...] == 2000\n",
216+
"b2000_gradientmask = dwi.gradients[:, -1] == 2000\n",
217217
"\n",
218218
"# Select b=2000\n",
219219
"data, _, grad = dwi[b2000_gradientmask]\n",

docs/notebooks/pet_motion_estimation.ipynb

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2428,10 +2428,11 @@
24282428
]
24292429
},
24302430
{
2431-
"metadata": {},
24322431
"cell_type": "code",
2433-
"outputs": [],
24342432
"execution_count": null,
2433+
"id": "d3627d44376b27f4",
2434+
"metadata": {},
2435+
"outputs": [],
24352436
"source": [
24362437
"import numpy as np\n",
24372438
"import pandas as pd\n",
@@ -2441,20 +2442,20 @@
24412442
"# Assume `affines` is the list of affine matrices computed earlier\n",
24422443
"motion_parameters = []\n",
24432444
"\n",
2444-
"for idx, affine in enumerate(affines):\n",
2445+
"for _idx, affine in enumerate(affines):\n",
24452446
" tx, ty, tz, rx, ry, rz = extract_motion_parameters(affine)\n",
24462447
" motion_parameters.append([tx, ty, tz, rx, ry, rz])\n",
24472448
"\n",
24482449
"motion_parameters = np.array(motion_parameters)\n",
24492450
"estimated_fd = compute_fd_from_motion(motion_parameters)"
2450-
],
2451-
"id": "d3627d44376b27f4"
2451+
]
24522452
},
24532453
{
2454-
"metadata": {},
24552454
"cell_type": "code",
2456-
"outputs": [],
24572455
"execution_count": null,
2456+
"id": "4f141ebdb1643673",
2457+
"metadata": {},
2458+
"outputs": [],
24582459
"source": [
24592460
"# Set up the matplotlib figure\n",
24602461
"import matplotlib.pyplot as plt\n",
@@ -2466,20 +2467,20 @@
24662467
"plot_volumewise_motion(np.arange(len(estimated_fd)), motion_parameters)\n",
24672468
"\n",
24682469
"plt.show()"
2469-
],
2470-
"id": "4f141ebdb1643673"
2470+
]
24712471
},
24722472
{
2473-
"metadata": {},
24742473
"cell_type": "markdown",
2475-
"source": "For the dataset used in this example, we have access to the ground truth motion parameters that were used to corrupt the motion-free dataset. Let's now plot the ground truth motion to enable a visual comparison with the estimated motion.",
2476-
"id": "e3f45164598d16f0"
2474+
"id": "e3f45164598d16f0",
2475+
"metadata": {},
2476+
"source": "For the dataset used in this example, we have access to the ground truth motion parameters that were used to corrupt the motion-free dataset. Let's now plot the ground truth motion to enable a visual comparison with the estimated motion."
24772477
},
24782478
{
2479-
"metadata": {},
24802479
"cell_type": "code",
2481-
"outputs": [],
24822480
"execution_count": null,
2481+
"id": "1009ea77e1bdd0ee",
2482+
"metadata": {},
2483+
"outputs": [],
24832484
"source": [
24842485
"from nifreeze.viz.motion_viz import plot_volumewise_motion\n",
24852486
"\n",
@@ -2505,14 +2506,13 @@
25052506
"\n",
25062507
"plt.tight_layout()\n",
25072508
"plt.show()"
2508-
],
2509-
"id": "1009ea77e1bdd0ee"
2509+
]
25102510
},
25112511
{
2512-
"metadata": {},
25132512
"cell_type": "markdown",
2514-
"source": "Let's plot the estimated and the ground truth framewise displacement.",
2515-
"id": "113b4b4d1361b5ec"
2513+
"id": "113b4b4d1361b5ec",
2514+
"metadata": {},
2515+
"source": "Let's plot the estimated and the ground truth framewise displacement."
25162516
},
25172517
{
25182518
"cell_type": "code",

docs/usage.rst

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,11 @@ To utilize *NiFreeze* functionalities within your Python module or script, follo
3434
Use the appropriate parameters for the particular imaging modality (e.g.
3535
dMRI, fMRI, or PET) that you are using.
3636

37-
For example, for dMRI data, ensure the gradient table is provided. It
38-
should have one column per diffusion-weighted image. The first three rows
39-
represent the gradient directions, and the last row indicates the timing
40-
and strength of the gradients in units of s/mm² ``[ R A S+ b ]``.
37+
For example, for dMRI data, ensure the gradient table is provided. NiFreeze
38+
expects the table to have one row per diffusion-weighted image, with the
39+
first three columns storing the gradient direction components and the last
40+
column indicating the timing and strength of the gradients in units of
41+
s/mm² ``[ R A S+ b ]``.
4142

4243
.. code-block:: python
4344

src/nifreeze/data/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def __len__(self) -> int:
100100

101101
def _getextra(self, idx: int | slice | tuple | np.ndarray) -> tuple[Unpack[Ts]]:
102102
"""
103-
Extracts extra fields synchronized with the indexed access of the corresponding data object.
103+
Extract extra fields for a given index of the corresponding data object.
104104
105105
Parameters
106106
----------

src/nifreeze/data/dmri.py

Lines changed: 80 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -71,12 +71,43 @@ class DWI(BaseDataset[np.ndarray]):
7171
bzero: np.ndarray = attrs.field(default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp))
7272
"""A *b=0* reference map, preferably obtained by some smart averaging."""
7373
gradients: np.ndarray = attrs.field(default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp))
74-
"""A 2D numpy array of the gradient table (4xN)."""
74+
"""A 2D numpy array of the gradient table (``N`` orientations x ``C`` components)."""
7575
eddy_xfms: list = attrs.field(default=None)
7676
"""List of transforms to correct for estimated eddy current distortions."""
7777

78+
def __attrs_post_init__(self) -> None:
79+
self._normalize_gradients()
80+
81+
def _normalize_gradients(self) -> None:
82+
if self.gradients is None:
83+
return
84+
85+
gradients = np.asarray(self.gradients)
86+
if gradients.ndim != 2:
87+
raise ValueError("Gradient table must be a 2D array")
88+
89+
n_volumes = None
90+
if self.dataobj is not None:
91+
try:
92+
n_volumes = self.dataobj.shape[-1]
93+
except Exception: # pragma: no cover - extremely defensive
94+
n_volumes = None
95+
96+
if n_volumes is not None and gradients.shape[0] != n_volumes:
97+
if gradients.shape[1] == n_volumes:
98+
gradients = gradients.T
99+
else:
100+
raise ValueError(
101+
"Gradient table shape does not match the number of diffusion volumes: "
102+
f"expected {n_volumes} rows, found {gradients.shape[0]}"
103+
)
104+
elif n_volumes is None and gradients.shape[1] > gradients.shape[0]:
105+
gradients = gradients.T
106+
107+
self.gradients = gradients
108+
78109
def _getextra(self, idx: int | slice | tuple | np.ndarray) -> tuple[np.ndarray]:
79-
return (self.gradients[..., idx],)
110+
return (self.gradients[idx, ...],)
80111

81112
# For the sake of the docstring
82113
def __getitem__(
@@ -99,8 +130,8 @@ def __getitem__(
99130
motion_affine : :obj:`~numpy.ndarray` or ``None``
100131
The corresponding per-volume motion affine(s) or ``None`` if identity transform(s).
101132
gradient : :obj:`~numpy.ndarray`
102-
The corresponding gradient(s), which may have shape ``(4,)`` if a single volume
103-
or ``(4, k)`` if multiple volumes, or ``None`` if gradients are not available.
133+
The corresponding gradient(s), which may have shape ``(C,)`` if a single volume
134+
or ``(k, C)`` if multiple volumes, or ``None`` if gradients are not available.
104135
105136
"""
106137

@@ -126,11 +157,11 @@ def from_filename(cls, filename: Path | str) -> Self:
126157

127158
@property
128159
def bvals(self):
129-
return self.gradients[-1, ...]
160+
return self.gradients[..., -1]
130161

131162
@property
132163
def bvecs(self):
133-
return self.gradients[:-1, ...]
164+
return self.gradients[..., :-1]
134165

135166
def get_shells(
136167
self,
@@ -160,14 +191,12 @@ def get_shells(
160191
"""
161192

162193
_, bval_groups, bval_estimated = find_shelling_scheme(
163-
self.gradients[-1, ...],
194+
self.bvals,
164195
num_bins=num_bins,
165196
multishell_nonempty_bin_count_thr=multishell_nonempty_bin_count_thr,
166197
bval_cap=bval_cap,
167198
)
168-
indices = [
169-
np.hstack(np.where(np.isin(self.gradients[-1, ...], bvals))) for bvals in bval_groups
170-
]
199+
indices = [np.where(np.isin(self.bvals, bvals))[0] for bvals in bval_groups]
171200
return list(zip(bval_estimated, indices, strict=True))
172201

173202
def to_filename(
@@ -232,16 +261,14 @@ def to_nifti(
232261
bvals = self.bvals
233262

234263
# Rotate b-vectors if self.motion_affines is not None
235-
bvecs = (
236-
np.array(
237-
[
238-
transform_fsl_bvec(bvec, affine, self.affine, invert=True)
239-
for bvec, affine in zip(self.gradients.T, self.motion_affines, strict=True)
240-
]
241-
).T
242-
if self.motion_affines is not None
243-
else self.bvecs
244-
)
264+
if self.motion_affines is not None:
265+
rotated = [
266+
transform_fsl_bvec(bvec, affine, self.affine, invert=True)
267+
for bvec, affine in zip(self.gradients[:, :3], self.motion_affines, strict=True)
268+
]
269+
bvecs = np.asarray(rotated)
270+
else:
271+
bvecs = self.bvecs
245272

246273
# Parent's to_nifti to handle the primary NIfTI export.
247274
nii = super().to_nifti(
@@ -266,7 +293,7 @@ def to_nifti(
266293
# If inserting a b0 volume is requested, add the corresponding null
267294
# gradient value to the bval/bvec pair
268295
bvals = np.concatenate((np.zeros(1), bvals))
269-
bvecs = np.concatenate((np.zeros(3)[:, np.newaxis], bvecs), axis=-1)
296+
bvecs = np.vstack((np.zeros((1, bvecs.shape[1])), bvecs))
270297

271298
if filename is not None:
272299
# Convert filename to a Path object.
@@ -279,9 +306,8 @@ def to_nifti(
279306
bvecs_file = out_root.with_suffix(".bvec")
280307
bvals_file = out_root.with_suffix(".bval")
281308

282-
# Save bvecs and bvals to text files
283-
# Each row of bvecs is one direction (3 rows, N columns).
284-
np.savetxt(bvecs_file, bvecs, fmt=f"%.{bvecs_dec_places}f")
309+
# Save bvecs and bvals to text files. BIDS expects 3 rows x N columns.
310+
np.savetxt(bvecs_file, bvecs.T, fmt=f"%.{bvecs_dec_places}f")
285311
np.savetxt(bvals_file, bvals[np.newaxis, :], fmt=f"%.{bvals_dec_places}f")
286312

287313
return nii
@@ -313,10 +339,12 @@ def from_nii(
313339
motion_file : :obj:`os.pathlike`, optional
314340
A file containing head motion affine matrices (linear)
315341
gradients_file : :obj:`os.pathlike`, optional
316-
A text file containing the gradients table, shape (4, N) or (N, 4).
317-
If provided, it supersedes any .bvec / .bval combination.
342+
A text file containing the gradients table, shape (N, C) where the last column
343+
stores the b-values. If provided following the column-major convention(C, N),
344+
it will be transposed automatically. If provided, it supersedes any .bvec / .bval
345+
combination.
318346
bvec_file : :obj:`os.pathlike`, optional
319-
A text file containing b-vectors, shape (3, N).
347+
A text file containing b-vectors, shape (N, 3) or (3, N).
320348
bval_file : :obj:`os.pathlike`, optional
321349
A text file containing b-values, shape (N,).
322350
b0_file : :obj:`os.pathlike`, optional
@@ -359,31 +387,48 @@ def from_nii(
359387
stacklevel=2,
360388
)
361389
elif bvec_file and bval_file:
362-
bvecs = np.loadtxt(bvec_file, dtype="float32") # shape (3, N)
363-
if bvecs.shape[0] != 3 and bvecs.shape[1] == 3:
390+
bvecs = np.loadtxt(bvec_file, dtype="float32")
391+
if bvecs.ndim == 1:
392+
bvecs = bvecs[np.newaxis, :]
393+
if bvecs.shape[1] != 3 and bvecs.shape[0] == 3:
364394
bvecs = bvecs.T
365395

366-
bvals = np.loadtxt(bval_file, dtype="float32") # shape (N,)
367-
# Stack to shape (4, N)
368-
grad = np.vstack((bvecs, bvals))
396+
bvals = np.loadtxt(bval_file, dtype="float32")
397+
if bvals.ndim > 1:
398+
bvals = np.squeeze(bvals)
399+
grad = np.column_stack((bvecs, bvals))
369400
else:
370401
raise RuntimeError(
371402
"No gradient data provided. "
372403
"Please specify either a gradients_file or (bvec_file & bval_file)."
373404
)
374405

406+
if grad.ndim == 1:
407+
grad = grad[np.newaxis, :]
408+
409+
if grad.shape[1] < 2:
410+
raise ValueError("Gradient table must have at least two columns (direction + b-value).")
411+
412+
if grad.shape[1] != 4:
413+
if grad.shape[0] == 4:
414+
grad = grad.T
415+
else:
416+
raise ValueError(
417+
"Gradient table must have four columns (3 direction components and one b-value)."
418+
)
419+
375420
# 3) Create the DWI instance. We'll filter out volumes where b-value > b0_thres
376421
# as "DW volumes" if the user wants to store only the high-b volumes here
377-
gradmsk = (grad[-1] if grad.shape[0] == 4 else grad[:, -1]) > b0_thres
422+
gradmsk = grad[:, -1] > b0_thres
378423

379-
# The shape checking is somewhat flexible: (4, N) or (N, 4)
380424
dwi_obj = DWI(
381425
dataobj=fulldata[..., gradmsk],
382426
affine=img.affine,
383427
# We'll assign the filtered gradients below.
384428
)
385429

386-
dwi_obj.gradients = grad[:, gradmsk] if grad.shape[0] == 4 else grad[gradmsk, :].T
430+
dwi_obj.gradients = grad[gradmsk, :]
431+
dwi_obj._normalize_gradients()
387432

388433
# 4) b=0 volume (bzero)
389434
# If the user provided a b0_file, load it

src/nifreeze/data/filtering.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ def dwi_select_shells(
240240
Parameters
241241
----------
242242
gradients : :obj:`~numpy.ndarray`
243-
Gradients.
243+
Gradients arranged as ``(N, C)`` with the last column storing b-values.
244244
index : :obj:`int`
245245
Index of the shell data.
246246
atol_low : :obj:`float`, optional

src/nifreeze/model/_dipy.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,20 @@
3838
)
3939

4040

41+
def _cartesian_components(gtab: GradientTable | np.ndarray) -> np.ndarray:
42+
"""Return the gradient directions as Cartesian components."""
43+
44+
if hasattr(gtab, "bvecs"):
45+
components = gtab.bvecs
46+
else:
47+
components = np.asarray(gtab)
48+
if components.ndim == 1:
49+
components = components[np.newaxis, :]
50+
if components.shape[-1] > 3:
51+
components = components[:, :-1]
52+
return components
53+
54+
4155
def gp_prediction(
4256
model: GaussianProcessRegressor,
4357
gtab: GradientTable | np.ndarray,
@@ -69,7 +83,7 @@ def gp_prediction(
6983
7084
"""
7185

72-
X = gtab.bvecs.T if hasattr(gtab, "bvecs") else np.asarray(gtab)
86+
X = _cartesian_components(gtab)
7387

7488
# Check it's fitted as they do in sklearn internally
7589
# https://github.com/scikit-learn/scikit-learn/blob/972e17fe1aa12d481b120ad4a3dc076bae736931/\
@@ -167,7 +181,7 @@ def fit(
167181
# Extract b-vecs: scikit-learn wants (n_samples, n_features)
168182
# where n_features is 3, and n_samples the different diffusion-encoding
169183
# gradient orientations.
170-
X = gtab.bvecs if hasattr(gtab, "bvecs") else np.asarray(gtab)
184+
X = _cartesian_components(gtab)
171185

172186
# Data must have shape (n_samples, n_targets) where n_samples is
173187
# the number of diffusion-encoding gradient orientations, and n_targets

0 commit comments

Comments
 (0)