Skip to content

Commit d16c97c

Browse files
authored
Merge pull request #153 from nipreps/enh/dmri-models-bzero
ENH: Improved handling of the *b=0* and a max *b* in dMRI datasets
2 parents 9d656d5 + 298516c commit d16c97c

File tree

3 files changed

+43
-15
lines changed

3 files changed

+43
-15
lines changed

src/nifreeze/cli/parser.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,12 @@ def build_parser() -> ArgumentParser:
143143
help="A NIfTI file containing the b-zero reference",
144144
)
145145

146+
g_dmri.add_argument(
147+
"--ignore-b0",
148+
action="store_true",
149+
help="Ignore the low-b reference and use the robust signal maximum",
150+
)
151+
146152
g_pet = parser.add_argument_group("Options for PET inputs")
147153
g_pet.add_argument(
148154
"--timing-file",

src/nifreeze/cli/run.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,13 +67,19 @@ def main(argv=None) -> None:
6767
**extra_kwargs,
6868
)
6969

70+
model_kwargs = {}
71+
72+
if args.ignore_b0:
73+
model_kwargs["ignore_bzero"] = True
74+
7075
prev_model: Estimator | None = None
7176
for _model in args.models:
7277
single_fit = _model.lower().startswith("single")
7378
estimator: Estimator = Estimator(
7479
_model.lower().replace("single", ""),
7580
prev=prev_model,
7681
single_fit=single_fit,
82+
model_kwargs=model_kwargs,
7783
)
7884
prev_model = estimator
7985

src/nifreeze/model/dmri.py

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@
3434
)
3535
from nifreeze.model.base import BaseModel, ExpectationModel
3636

37+
S0_EPSILON = 1e-6
38+
B_MIN = 50
39+
3740

3841
def _exec_fit(model, data, chunk=None):
3942
retval = model.fit(data)
@@ -49,12 +52,15 @@ class BaseDWIModel(BaseModel):
4952
"""Interface and default methods for DWI models."""
5053

5154
__slots__ = {
55+
"_max_b": "The maximum b-value supported by the model",
56+
"_data_mask": "A mask for the voxels that will be fitted and predicted",
57+
"_S0": "The S0 (b=0 reference signal) that will be fed into DIPY models",
5258
"_model_class": "Defining a model class, DIPY models are instantiated automagically",
5359
"_modelargs": "Arguments acceptable by the underlying DIPY-like model.",
5460
"_models": "List with one or more (if parallel execution) model instances",
5561
}
5662

57-
def __init__(self, dataset: DWI, **kwargs):
63+
def __init__(self, dataset: DWI, max_b: float | int | None = None, **kwargs):
5864
r"""Initialization.
5965
6066
Parameters
@@ -76,6 +82,26 @@ def __init__(self, dataset: DWI, **kwargs):
7682
f"DWI dataset is too small ({dataset.gradients.shape[0]} directions)."
7783
)
7884

85+
if max_b is not None and max_b > B_MIN:
86+
self._max_b = max_b
87+
88+
self._data_mask = (
89+
dataset.brainmask
90+
if dataset.brainmask is not None
91+
else np.ones(dataset.dataobj.shape[:3], dtype=bool)
92+
)
93+
94+
# By default, set S0 to the 98% percentile of the DWI data within mask
95+
self._S0 = np.full(
96+
self._data_mask.sum(),
97+
np.round(np.percentile(dataset.dataobj[self._data_mask, ...], 98)),
98+
)
99+
100+
# If b=0 is present and not to be ignored, update brain mask and set
101+
if not kwargs.pop("ignore_bzero", False) and dataset.bzero is not None:
102+
self._data_mask[dataset.bzero < S0_EPSILON] = False
103+
self._S0 = dataset.bzero[self._data_mask]
104+
79105
super().__init__(dataset, **kwargs)
80106

81107
def _fit(self, index: int | None = None, n_jobs=None, **kwargs):
@@ -151,26 +177,20 @@ def fit_predict(self, index: int | None = None, **kwargs):
151177
if index is None:
152178
return None
153179

154-
brainmask = self._dataset.brainmask
155180
gradient = self._dataset.gradients[:, index]
156181

157182
if "dipy" in getattr(self, "_model_class", ""):
158183
gradient = gradient_table_from_bvals_bvecs(
159184
gradient[np.newaxis, -1], gradient[np.newaxis, :-1]
160185
)
161186

162-
S0 = self._dataset.bzero
163-
if S0 is not None:
164-
S0 = S0[brainmask, ...] if brainmask is not None else S0.reshape(-1)
165-
166187
if n_models == 1:
167188
predicted, _ = _exec_predict(
168-
self._models[0], **(kwargs | {"gtab": gradient, "S0": S0})
189+
self._models[0], **(kwargs | {"gtab": gradient, "S0": self._S0})
169190
)
170191
else:
171-
S0 = np.array_split(S0, n_models) if S0 is not None else np.full(n_models, None)
172-
173192
predicted = [None] * n_models
193+
S0 = np.array_split(self._S0, n_models)
174194

175195
# Parallelize process with joblib
176196
with Parallel(n_jobs=n_models) as executor:
@@ -187,12 +207,8 @@ def fit_predict(self, index: int | None = None, **kwargs):
187207

188208
predicted = np.hstack(predicted)
189209

190-
if brainmask is not None:
191-
retval = np.zeros_like(brainmask, dtype="float32")
192-
retval[brainmask, ...] = predicted
193-
else:
194-
retval = predicted.reshape(self._dataset.dataobj.shape[:-1])
195-
210+
retval = np.zeros_like(self._data_mask, dtype=self._dataset.dataobj.dtype)
211+
retval[self._data_mask, ...] = predicted
196212
return retval
197213

198214

0 commit comments

Comments
 (0)