2525
2626import numpy as np
2727from dipy .core .gradients import gradient_table_from_bvals_bvecs
28- from joblib import Parallel , delayed
2928
3029from nifreeze .data .dmri import (
3130 DEFAULT_CLIP_PERCENTILE ,
3534from nifreeze .model .base import BaseModel , ExpectationModel
3635
3736
38- def _exec_fit (model , data , chunk = None ):
39- retval = model .fit (data )
40- return retval , chunk
41-
42-
43- def _exec_predict (model , chunk = None , ** kwargs ):
44- """Propagate model parameters and call predict."""
45- return np .squeeze (model .predict (** kwargs )), chunk
46-
47-
4837class BaseDWIModel (BaseModel ):
4938 """Interface and default methods for DWI models."""
5039
5140 __slots__ = {
5241 "_model_class" : "Defining a model class, DIPY models are instantiated automagically" ,
5342 "_modelargs" : "Arguments acceptable by the underlying DIPY-like model." ,
54- "_models " : "List with one or more (if parallel execution) model instances " ,
43+ "_model_fit " : "Fitted model" ,
5544 }
5645
5746 def __init__ (self , dataset : DWI , ** kwargs ):
@@ -81,8 +70,6 @@ def __init__(self, dataset: DWI, **kwargs):
8170 def _fit (self , index : int | None = None , n_jobs = None , ** kwargs ):
8271 """Fit the model chunk-by-chunk asynchronously"""
8372
84- n_jobs = n_jobs or 1
85-
8673 if self ._locked_fit is not None :
8774 return n_jobs
8875
@@ -110,25 +97,11 @@ def _fit(self, index: int | None = None, n_jobs=None, **kwargs):
11097 class_name ,
11198 )(gtab , ** kwargs )
11299
113- # One single CPU - linear execution (full model)
114- if n_jobs == 1 :
115- _modelfit , _ = _exec_fit (model , data )
116- self ._models = [_modelfit ]
117- return 1
118-
119- # Split data into chunks of group of slices
120- data_chunks = np .array_split (data , n_jobs )
121-
122- self ._models = [None ] * n_jobs
123-
124- # Parallelize process with joblib
125- with Parallel (n_jobs = n_jobs ) as executor :
126- results = executor (
127- delayed (_exec_fit )(model , dchunk , i ) for i , dchunk in enumerate (data_chunks )
128- )
129- for submodel , rindex in results :
130- self ._models [rindex ] = submodel
131-
100+ self ._model_fit = model .fit (
101+ data ,
102+ engine = "serial" if n_jobs == 1 else "joblib" ,
103+ n_jobs = n_jobs ,
104+ )
132105 return n_jobs
133106
134107 def fit_predict (self , index : int | None = None , ** kwargs ):
@@ -142,13 +115,14 @@ def fit_predict(self, index: int | None = None, **kwargs):
142115
143116 """
144117
145- n_models = self ._fit (
118+ self ._fit (
146119 index ,
147120 n_jobs = kwargs .pop ("n_jobs" ),
148121 ** kwargs ,
149122 )
150123
151124 if index is None :
125+ self ._locked_fit = True
152126 return None
153127
154128 brainmask = self ._dataset .brainmask
@@ -163,29 +137,12 @@ def fit_predict(self, index: int | None = None, **kwargs):
163137 if S0 is not None :
164138 S0 = S0 [brainmask , ...] if brainmask is not None else S0 .reshape (- 1 )
165139
166- if n_models == 1 :
167- predicted , _ = _exec_predict (
168- self ._models [0 ], ** (kwargs | {"gtab" : gradient , "S0" : S0 })
140+ predicted = np .squeeze (
141+ self ._model_fit .predict (
142+ gtab = gradient ,
143+ S0 = S0 ,
169144 )
170- else :
171- S0 = np .array_split (S0 , n_models ) if S0 is not None else np .full (n_models , None )
172-
173- predicted = [None ] * n_models
174-
175- # Parallelize process with joblib
176- with Parallel (n_jobs = n_models ) as executor :
177- results = executor (
178- delayed (_exec_predict )(
179- model ,
180- chunk = i ,
181- ** (kwargs | {"gtab" : gradient , "S0" : S0 [i ]}),
182- )
183- for i , model in enumerate (self ._models )
184- )
185- for subprediction , index in results :
186- predicted [index ] = subprediction
187-
188- predicted = np .hstack (predicted )
145+ )
189146
190147 if brainmask is not None :
191148 retval = np .zeros_like (brainmask , dtype = "float32" )
0 commit comments