Skip to content

Commit ae4f087

Browse files
committed
fix: repair tests
1 parent 500a4b2 commit ae4f087

File tree

3 files changed

+19
-13
lines changed

3 files changed

+19
-13
lines changed

src/nifreeze/data/dmri.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ class DWI(BaseDataset):
4949
be unwarped.
5050
"""
5151
gradients = attr.ib(default=None, repr=_data_repr, eq=attr.cmp_using(eq=_cmp))
52-
"""A 2D numpy array of the gradient table in RAS+B format (Nx4)."""
52+
"""A 2D numpy array of the gradient table (4xN)."""
5353
eddy_xfms = attr.ib(default=None)
5454
"""List of transforms to correct for estimatted eddy current distortions."""
5555

@@ -73,12 +73,12 @@ def __getitem__(
7373
The corresponding per-volume motion affine(s) or `None` if identity transform(s).
7474
gradient : np.ndarray
7575
The corresponding gradient(s), which may have shape ``(4,)`` if a single volume
76-
or ``(k, 4)`` if multiple volumes, or None if gradients are not available.
76+
or ``(4, k)`` if multiple volumes, or None if gradients are not available.
7777
7878
"""
7979

8080
data, affine = super().__getitem__(idx)
81-
return data, affine, self.gradients[idx, ...]
81+
return data, affine, self.gradients[..., idx]
8282

8383
def set_transform(self, index: int, affine: np.ndarray, order: int = 3) -> None:
8484
"""
@@ -106,14 +106,14 @@ def set_transform(self, index: int, affine: np.ndarray, order: int = 3) -> None:
106106
shape=self.dataobj.shape[:3], affine=self.affine
107107
)
108108
xform = Affine(matrix=affine, reference=reference)
109-
bvec = self.gradients[index, :3]
109+
bvec = self.gradients[:3, index]
110110

111111
# invert transform transform b-vector and origin
112112
r_bvec = (~xform).map([bvec, (0.0, 0.0, 0.0)])
113113
# Reset b-vector's origin
114114
new_bvec = r_bvec[1] - r_bvec[0]
115115
# Normalize and update
116-
self.gradients[index, :3] = new_bvec / np.linalg.norm(new_bvec)
116+
self.gradients[:3, index] = new_bvec / np.linalg.norm(new_bvec)
117117

118118
super().set_transform(index, affine, order)
119119

@@ -172,7 +172,7 @@ def to_filename(
172172
with h5py.File(filename, "r+") as out_file:
173173
out_file.attrs["Type"] = "dmri"
174174

175-
def to_nifti(self, filename: Path | str) -> None:
175+
def to_nifti(self, filename: Path | str, insert_b0: bool = False) -> None:
176176
"""
177177
Write a NIfTI 1.0 file to disk, and also write out the gradient table
178178
to sidecar text files (.bvec, .bval).
@@ -183,8 +183,15 @@ def to_nifti(self, filename: Path | str) -> None:
183183
The output NIfTI file path.
184184
185185
"""
186-
# First call the parent's to_nifti to handle the primary NIfTI export.
187-
super().to_nifti(filename)
186+
if not insert_b0:
187+
# Parent's to_nifti to handle the primary NIfTI export.
188+
super().to_nifti(filename)
189+
else:
190+
data = np.concatenate((self.bzero[..., np.newaxis], self.dataobj), axis=-1)
191+
nii = nb.Nifti1Image(data, self.affine, self.datahdr)
192+
if self.datahdr is None:
193+
nii.header.set_xyzt_units("mm")
194+
nii.to_filename(filename)
188195

189196
# Convert filename to a Path object.
190197
out_root = Path(filename).absolute()
@@ -202,8 +209,8 @@ def to_nifti(self, filename: Path | str) -> None:
202209

203210
# Save bvecs and bvals to text files
204211
# Each row of bvecs is one direction (3 rows, N columns).
205-
np.savetxt(bvecs_file, self.gradients[..., :3].T, fmt="%.6f")
206-
np.savetxt(bvals_file, self.gradients[..., -1], fmt="%.6f")
212+
np.savetxt(bvecs_file, self.gradients[:3, ...].T, fmt="%.6f")
213+
np.savetxt(bvals_file, self.gradients[:3, ...], fmt="%.6f")
207214

208215

209216
def load(
@@ -297,7 +304,7 @@ def load(
297304
# We'll assign the filtered gradients below.
298305
)
299306

300-
dwi_obj.gradients = grad[:, gradmsk] if grad.shape[0] == 4 else grad[gradmsk, :]
307+
dwi_obj.gradients = grad[:, gradmsk] if grad.shape[0] == 4 else grad[gradmsk, :].T
301308

302309
# 6) b=0 volume (bzero)
303310
# If the user provided a b0_file, load it

test/test_data_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
import numpy as np
3030
import pytest
3131

32-
from nifreeze import NFDH5_EXT, BaseDataset, load
32+
from nifreeze.data.base import NFDH5_EXT, BaseDataset, load
3333

3434

3535
@pytest.fixture

test/test_data_dmri.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,6 @@ def test_equality_operator(tmp_path):
185185
gradients_file=gradients_fname,
186186
b0_file=b0_fname,
187187
brainmask_file=brainmask_fname,
188-
fmap_file=fieldmap_fname,
189188
b0_thres=b0_thres,
190189
)
191190
hdf5_filename = tmp_path / "test_dwi.h5"

0 commit comments

Comments
 (0)