 
 import numpy as np
 from dipy.core.gradients import gradient_table_from_bvals_bvecs
-from joblib import Parallel, delayed
 
 from nifreeze.data.dmri import (
     DEFAULT_CLIP_PERCENTILE,
 B_MIN = 50
 
 
-def _exec_fit(model, data, chunk=None):
-    retval = model.fit(data)
-    return retval, chunk
-
-
-def _exec_predict(model, chunk=None, **kwargs):
-    """Propagate model parameters and call predict."""
-    return np.squeeze(model.predict(**kwargs)), chunk
-
-
 class BaseDWIModel(BaseModel):
     """Interface and default methods for DWI models."""
 
@@ -57,7 +46,7 @@ class BaseDWIModel(BaseModel):
5746 "_S0" : "The S0 (b=0 reference signal) that will be fed into DIPY models" ,
5847 "_model_class" : "Defining a model class, DIPY models are instantiated automagically" ,
5948 "_modelargs" : "Arguments acceptable by the underlying DIPY-like model." ,
60- "_models " : "List with one or more (if parallel execution) model instances " ,
49+ "_model_fit " : "Fitted model" ,
6150 }
6251
6352 def __init__ (self , dataset : DWI , max_b : float | int | None = None , ** kwargs ):
@@ -107,8 +96,6 @@ def __init__(self, dataset: DWI, max_b: float | int | None = None, **kwargs):
     def _fit(self, index: int | None = None, n_jobs=None, **kwargs):
         """Fit the model chunk-by-chunk asynchronously"""
 
-        n_jobs = n_jobs or 1
-
         if self._locked_fit is not None:
             return n_jobs
 
@@ -136,25 +123,11 @@ def _fit(self, index: int | None = None, n_jobs=None, **kwargs):
             class_name,
         )(gtab, **kwargs)
 
-        # One single CPU - linear execution (full model)
-        if n_jobs == 1:
-            _modelfit, _ = _exec_fit(model, data)
-            self._models = [_modelfit]
-            return 1
-
-        # Split data into chunks of group of slices
-        data_chunks = np.array_split(data, n_jobs)
-
-        self._models = [None] * n_jobs
-
-        # Parallelize process with joblib
-        with Parallel(n_jobs=n_jobs) as executor:
-            results = executor(
-                delayed(_exec_fit)(model, dchunk, i) for i, dchunk in enumerate(data_chunks)
-            )
-            for submodel, rindex in results:
-                self._models[rindex] = submodel
-
+        self._model_fit = model.fit(
+            data,
+            engine="serial" if n_jobs == 1 else "joblib",
+            n_jobs=n_jobs,
+        )
         return n_jobs
 
     def fit_predict(self, index: int | None = None, **kwargs):
@@ -168,13 +141,14 @@ def fit_predict(self, index: int | None = None, **kwargs):
 
         """
 
-        n_models = self._fit(
+        self._fit(
             index,
             n_jobs=kwargs.pop("n_jobs"),
             **kwargs,
         )
 
         if index is None:
+            self._locked_fit = True
             return None
 
         gradient = self._dataset.gradients[:, index]
@@ -184,28 +158,12 @@ def fit_predict(self, index: int | None = None, **kwargs):
             gradient[np.newaxis, -1], gradient[np.newaxis, :-1]
         )
 
-        if n_models == 1:
-            predicted, _ = _exec_predict(
-                self._models[0], **(kwargs | {"gtab": gradient, "S0": self._S0})
+        predicted = np.squeeze(
+            self._model_fit.predict(
+                gtab=gradient,
+                S0=self._S0,
             )
-        else:
-            predicted = [None] * n_models
-            S0 = np.array_split(self._S0, n_models)
-
-            # Parallelize process with joblib
-            with Parallel(n_jobs=n_models) as executor:
-                results = executor(
-                    delayed(_exec_predict)(
-                        model,
-                        chunk=i,
-                        **(kwargs | {"gtab": gradient, "S0": S0[i]}),
-                    )
-                    for i, model in enumerate(self._models)
-                )
-                for subprediction, index in results:
-                    predicted[index] = subprediction
-
-            predicted = np.hstack(predicted)
+        )
 
         retval = np.zeros_like(self._data_mask, dtype=self._dataset.dataobj.dtype)
         retval[self._data_mask, ...] = predicted
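
Note: the refactored _fit above assumes that the DIPY-like model's fit() accepts engine and n_jobs keyword arguments and takes over the data chunking that the removed joblib helpers used to perform. A minimal sketch of that engine-aware dispatch pattern follows; ChunkedModel and _fit_chunk are toy names for illustration, not nifreeze or DIPY API.

import numpy as np
from joblib import Parallel, delayed


class ChunkedModel:
    """Toy stand-in for a DIPY-like model whose fit() owns the chunking."""

    def _fit_chunk(self, chunk):
        # Placeholder per-chunk computation (a real model fits parameters here).
        return chunk.mean(axis=-1)

    def fit(self, data, engine="serial", n_jobs=1):
        # Split the voxel axis into one chunk per job, then fit either
        # serially or through joblib, mirroring the engine/n_jobs contract.
        chunks = np.array_split(data, max(n_jobs or 1, 1))
        if engine == "serial":
            results = [self._fit_chunk(chunk) for chunk in chunks]
        else:  # engine == "joblib"
            results = Parallel(n_jobs=n_jobs)(
                delayed(self._fit_chunk)(chunk) for chunk in chunks
            )
        return np.concatenate(results)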