@@ -51,6 +51,7 @@ class BaseDWIModel(BaseModel):
5151 __slots__ = {
5252 "_model_class" : "Defining a model class, DIPY models are instantiated automagically" ,
5353 "_modelargs" : "Arguments acceptable by the underlying DIPY-like model." ,
54+ "_models" : "List with one or more (if parallel execution) model instances" ,
5455 }
5556
5657 def __init__ (self , dataset : DWI , ** kwargs ):
@@ -77,13 +78,21 @@ def __init__(self, dataset: DWI, **kwargs):
7778
7879 super ().__init__ (dataset , ** kwargs )
7980
80- def _fit (self , index , n_jobs = None , ** kwargs ):
81+ def _fit (self , index : int | None = None , n_jobs = None , ** kwargs ):
8182 """Fit the model chunk-by-chunk asynchronously"""
83+
8284 n_jobs = n_jobs or 1
8385
86+ if self ._locked_fit is not None :
87+ return n_jobs
88+
8489 brainmask = self ._dataset .brainmask
8590 idxmask = np .ones (len (self ._dataset ), dtype = bool )
86- idxmask [index ] = False
91+
92+ if index is not None :
93+ idxmask [index ] = False
94+ else :
95+ self ._locked_fit = True
8796
8897 data , _ , gtab = self ._dataset [idxmask ]
8998 # Select voxels within mask or just unravel 3D if no mask
@@ -96,14 +105,15 @@ def _fit(self, index, n_jobs=None, **kwargs):
96105
97106 if model_str :
98107 module_name , class_name = model_str .rsplit ("." , 1 )
99- self . _model = getattr (
108+ model = getattr (
100109 import_module (module_name ),
101110 class_name ,
102111 )(gtab , ** kwargs )
103112
104113 # One single CPU - linear execution (full model)
105114 if n_jobs == 1 :
106- self ._model , _ = _exec_fit (self ._model , data )
115+ _modelfit , _ = _exec_fit (model , data )
116+ self ._models = [_modelfit ]
107117 return 1
108118
109119 # Split data into chunks of group of slices
@@ -114,15 +124,14 @@ def _fit(self, index, n_jobs=None, **kwargs):
114124 # Parallelize process with joblib
115125 with Parallel (n_jobs = n_jobs ) as executor :
116126 results = executor (
117- delayed (_exec_fit )(self . _model , dchunk , i ) for i , dchunk in enumerate (data_chunks )
127+ delayed (_exec_fit )(model , dchunk , i ) for i , dchunk in enumerate (data_chunks )
118128 )
119129 for submodel , rindex in results :
120130 self ._models [rindex ] = submodel
121131
122- self ._model = None # Preempt further actions on the model
123132 return n_jobs
124133
125- def fit_predict (self , index : int , ** kwargs ):
134+ def fit_predict (self , index : int | None = None , ** kwargs ):
126135 """
127136 Predict asynchronously chunk-by-chunk the diffusion signal.
128137
@@ -133,8 +142,14 @@ def fit_predict(self, index: int, **kwargs):
133142
134143 """
135144
136- n_models = self ._fit (index , ** kwargs )
137- kwargs .pop ("n_jobs" )
145+ n_models = self ._fit (
146+ index ,
147+ n_jobs = kwargs .pop ("n_jobs" ),
148+ ** kwargs ,
149+ )
150+
151+ if index is None :
152+ return None
138153
139154 brainmask = self ._dataset .brainmask
140155 gradient = self ._dataset .gradients [:, index ]
@@ -149,9 +164,10 @@ def fit_predict(self, index: int, **kwargs):
149164 S0 = S0 [brainmask , ...] if brainmask is not None else S0 .reshape (- 1 )
150165
151166 if n_models == 1 :
152- predicted , _ = _exec_predict (self ._model , ** (kwargs | {"gtab" : gradient , "S0" : S0 }))
167+ predicted , _ = _exec_predict (
168+ self ._models [0 ], ** (kwargs | {"gtab" : gradient , "S0" : S0 })
169+ )
153170 else :
154- print (n_models , S0 )
155171 S0 = np .array_split (S0 , n_models ) if S0 is not None else np .full (n_models , None )
156172
157173 predicted = [None ] * n_models
@@ -221,9 +237,12 @@ def __init__(
221237 self ._th_high = th_high
222238 self ._detrend = detrend
223239
224- def fit_predict (self , index , * _ , ** kwargs ):
240+ def fit_predict (self , index : int | None = None , * _ , ** kwargs ):
225241 """Return the average map."""
226242
243+ if index is None :
244+ raise RuntimeError (f"Model { self .__class__ .__name__ } does not allow locking." )
245+
227246 bvalues = self ._dataset .gradients [:, - 1 ]
228247 bcenter = bvalues [index ]
229248
0 commit comments