Skip to content

Commit 69d8849

Browse files
committed
FIX: Fix miscellaneous oversights in the base DWI model
Fix miscellaneous oversights in the base DWI model: - Provide the `None` default value for the `n_jobs` kwarg if not provided in the dictionary. Fixes: ``` n_models = self._fit( index, > n_jobs=kwargs.pop("n_jobs"), **kwargs, ) E KeyError: 'n_jobs' ``` - Make the diffusion gradients have a unit norm in the gtab generation fixture. Fixes: ``` bvecs = np.where(np.isnan(bvecs), 0, bvecs) bvecs_close_to_1 = abs(vector_norm(bvecs) - 1) <= atol if bvecs.shape[1] != 3: raise ValueError("bvecs should be (N, 3)") if not np.all(bvecs_close_to_1[dwi_mask]): > raise ValueError( "The vectors in bvecs should be unit (The tolerance " "can be modified as an input parameter)" ) E ValueError: The vectors in bvecs should be unit (The tolerance can be modified as an input parameter) ``` Add a test to check that DIPY models (e.g. DTI) can be instantiated and that the returned prediction has the expected shape.
1 parent 1b6b681 commit 69d8849

File tree

3 files changed

+30
-2
lines changed

3 files changed

+30
-2
lines changed

src/nifreeze/model/dmri.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ def fit_predict(self, index: int | None = None, **kwargs):
178178
kwargs.pop("omp_nthreads", None) # Drop omp_nthreads
179179
n_models = self._fit(
180180
index,
181-
n_jobs=kwargs.pop("n_jobs"),
181+
n_jobs=kwargs.pop("n_jobs", None),
182182
**kwargs,
183183
)
184184

test/conftest.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import nitransforms as nt
3030
import numpy as np
3131
import pytest
32+
from dipy.core.geometry import normalized_vector
3233
from dipy.io.gradients import read_bvals_bvecs
3334

3435
from nifreeze.data.dmri import DWI
@@ -251,7 +252,9 @@ def setup_random_gtab_data(request):
251252
bvals_shells = _generate_random_choices(request, shells, n_gradients)
252253

253254
bvals = np.hstack([b0s * [0], bvals_shells])
254-
bvecs = np.hstack([np.zeros((3, b0s)), rng.random((3, n_gradients))])
255+
bvecs = np.hstack(
256+
[np.zeros((3, b0s)), normalized_vector(rng.random((3, n_gradients)), axis=0)]
257+
)
255258

256259
return bvals, bvecs
257260

test/test_model.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,31 @@ def test_gp_model(evals, S0, snr, hsph_dirs, bval_shell):
157157
assert prediction.shape == (2,)
158158

159159

160+
@pytest.mark.random_dwi_data(50, (14, 16, 8), True)
161+
def test_dti_model(setup_random_dwi_data):
162+
(
163+
dwi_dataobj,
164+
affine,
165+
brainmask_dataobj,
166+
b0_dataobj,
167+
gradients,
168+
_,
169+
) = setup_random_dwi_data
170+
171+
dataset = DWI(
172+
dataobj=dwi_dataobj,
173+
affine=affine,
174+
brainmask=brainmask_dataobj,
175+
bzero=b0_dataobj,
176+
gradients=gradients,
177+
)
178+
179+
dtimodel = model.DTIModel(dataset)
180+
predicted = dtimodel.fit_predict(4)
181+
182+
assert predicted.shape == dwi_dataobj.shape[:-1]
183+
184+
160185
def test_factory(datadir):
161186
"""Check that the two different initialisations result in the same models"""
162187

0 commit comments

Comments
 (0)