Skip to content

Commit 1f93599

Browse files
committed
FIX: Fix miscellaneous oversights in the base DWI model
Fix miscellaneous oversights in the base DWI model: - Provide the `None` default value for the `n_jobs` kwarg if not provided in the dictionary. Fixes: ``` n_models = self._fit( index, > n_jobs=kwargs.pop("n_jobs"), **kwargs, ) E KeyError: 'n_jobs' ``` - Make the diffusion gradients have a unit norm in the gtab generation fixture. Fixes: ``` bvecs = np.where(np.isnan(bvecs), 0, bvecs) bvecs_close_to_1 = abs(vector_norm(bvecs) - 1) <= atol if bvecs.shape[1] != 3: raise ValueError("bvecs should be (N, 3)") if not np.all(bvecs_close_to_1[dwi_mask]): > raise ValueError( "The vectors in bvecs should be unit (The tolerance " "can be modified as an input parameter)" ) E ValueError: The vectors in bvecs should be unit (The tolerance can be modified as an input parameter) ``` Add a test to check that DIPY models (e.g. DTI) can be instantiated and that the returned prediction has the expected shape.
1 parent 1b6b681 commit 1f93599

File tree

3 files changed

+30
-2
lines changed

3 files changed

+30
-2
lines changed

src/nifreeze/model/dmri.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ def fit_predict(self, index: int | None = None, **kwargs):
178178
kwargs.pop("omp_nthreads", None) # Drop omp_nthreads
179179
n_models = self._fit(
180180
index,
181-
n_jobs=kwargs.pop("n_jobs"),
181+
n_jobs=kwargs.pop("n_jobs", None),
182182
**kwargs,
183183
)
184184

test/conftest.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import nitransforms as nt
3030
import numpy as np
3131
import pytest
32+
from dipy.core.geometry import normalized_vector
3233
from dipy.io.gradients import read_bvals_bvecs
3334

3435
from nifreeze.data.dmri import DWI
@@ -251,7 +252,7 @@ def setup_random_gtab_data(request):
251252
bvals_shells = _generate_random_choices(request, shells, n_gradients)
252253

253254
bvals = np.hstack([b0s * [0], bvals_shells])
254-
bvecs = np.hstack([np.zeros((3, b0s)), rng.random((3, n_gradients))])
255+
bvecs = np.hstack([np.zeros((3, b0s)), normalized_vector(rng.random((3, n_gradients)), axis=0)])
255256

256257
return bvals, bvecs
257258

test/test_model.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,33 @@ def test_gp_model(evals, S0, snr, hsph_dirs, bval_shell):
157157
assert prediction.shape == (2,)
158158

159159

160+
161+
162+
@pytest.mark.random_dwi_data(50, (14, 16, 8), True)
163+
def test_dti_model(setup_random_dwi_data):
164+
(
165+
dwi_dataobj,
166+
affine,
167+
brainmask_dataobj,
168+
b0_dataobj,
169+
gradients,
170+
_,
171+
) = setup_random_dwi_data
172+
173+
dataset = DWI(
174+
dataobj=dwi_dataobj,
175+
affine=affine,
176+
brainmask=brainmask_dataobj,
177+
bzero=b0_dataobj,
178+
gradients=gradients
179+
)
180+
181+
dtimodel = model.DTIModel(dataset)
182+
predicted = dtimodel.fit_predict(4)
183+
184+
assert predicted.shape == dwi_dataobj.shape[:-1]
185+
186+
160187
def test_factory(datadir):
161188
"""Check that the two different initialisations result in the same models"""
162189

0 commit comments

Comments
 (0)