3434)
3535from nifreeze .model .base import BaseModel , ExpectationModel
3636
37+ S0_EPSILON = 1e-6
38+ B_MIN = 50
39+
3740
3841def _exec_fit (model , data , chunk = None ):
3942 retval = model .fit (data )
@@ -49,12 +52,15 @@ class BaseDWIModel(BaseModel):
4952 """Interface and default methods for DWI models."""
5053
5154 __slots__ = {
55+ "_max_b" : "The maximum b-value supported by the model" ,
56+ "_data_mask" : "A mask for the voxels that will be fitted and predicted" ,
57+ "_S0" : "The S0 (b=0 reference signal) that will be fed into DIPY models" ,
5258 "_model_class" : "Defining a model class, DIPY models are instantiated automagically" ,
5359 "_modelargs" : "Arguments acceptable by the underlying DIPY-like model." ,
5460 "_models" : "List with one or more (if parallel execution) model instances" ,
5561 }
5662
57- def __init__ (self , dataset : DWI , ** kwargs ):
63+ def __init__ (self , dataset : DWI , max_b : float | int | None = None , ** kwargs ):
5864 r"""Initialization.
5965
6066 Parameters
@@ -76,6 +82,26 @@ def __init__(self, dataset: DWI, **kwargs):
7682 f"DWI dataset is too small ({ dataset .gradients .shape [0 ]} directions)."
7783 )
7884
85+ if max_b is not None and max_b > B_MIN :
86+ self ._max_b = max_b
87+
88+ self ._data_mask = (
89+ dataset .brainmask
90+ if dataset .brainmask is not None
91+ else np .ones (dataset .dataobj .shape [:3 ], dtype = bool )
92+ )
93+
94+ # By default, set S0 to the 98% percentile of the DWI data within mask
95+ self ._S0 = np .full (
96+ self ._data_mask .sum (),
97+ np .round (np .percentile (dataset .dataobj [self ._data_mask , ...], 98 )),
98+ )
99+
100+ # If b=0 is present and not to be ignored, update brain mask and set
101+ if not kwargs .pop ("ignore_bzero" , False ) and dataset .bzero is not None :
102+ self ._data_mask [dataset .bzero < S0_EPSILON ] = False
103+ self ._S0 = dataset .bzero [self ._data_mask ]
104+
79105 super ().__init__ (dataset , ** kwargs )
80106
81107 def _fit (self , index : int | None = None , n_jobs = None , ** kwargs ):
@@ -151,26 +177,20 @@ def fit_predict(self, index: int | None = None, **kwargs):
151177 if index is None :
152178 return None
153179
154- brainmask = self ._dataset .brainmask
155180 gradient = self ._dataset .gradients [:, index ]
156181
157182 if "dipy" in getattr (self , "_model_class" , "" ):
158183 gradient = gradient_table_from_bvals_bvecs (
159184 gradient [np .newaxis , - 1 ], gradient [np .newaxis , :- 1 ]
160185 )
161186
162- S0 = self ._dataset .bzero
163- if S0 is not None :
164- S0 = S0 [brainmask , ...] if brainmask is not None else S0 .reshape (- 1 )
165-
166187 if n_models == 1 :
167188 predicted , _ = _exec_predict (
168- self ._models [0 ], ** (kwargs | {"gtab" : gradient , "S0" : S0 })
189+ self ._models [0 ], ** (kwargs | {"gtab" : gradient , "S0" : self . _S0 })
169190 )
170191 else :
171- S0 = np .array_split (S0 , n_models ) if S0 is not None else np .full (n_models , None )
172-
173192 predicted = [None ] * n_models
193+ S0 = np .array_split (self ._S0 , n_models )
174194
175195 # Parallelize process with joblib
176196 with Parallel (n_jobs = n_models ) as executor :
@@ -187,12 +207,8 @@ def fit_predict(self, index: int | None = None, **kwargs):
187207
188208 predicted = np .hstack (predicted )
189209
190- if brainmask is not None :
191- retval = np .zeros_like (brainmask , dtype = "float32" )
192- retval [brainmask , ...] = predicted
193- else :
194- retval = predicted .reshape (self ._dataset .dataobj .shape [:- 1 ])
195-
210+ retval = np .zeros_like (self ._data_mask , dtype = self ._dataset .dataobj .dtype )
211+ retval [self ._data_mask , ...] = predicted
196212 return retval
197213
198214
0 commit comments