Skip to content

Commit 70c84d7

Browse files
committed
enh: use DIPY's parallelization
1 parent d16c97c commit 70c84d7

File tree

2 files changed

+15
-56
lines changed

2 files changed

+15
-56
lines changed

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,14 @@ license = "Apache-2.0"
2121
requires-python = ">=3.10"
2222
dependencies = [
2323
"attrs",
24-
"dipy>=1.5.0",
24+
"dipy>=1.10.0",
2525
"joblib",
2626
"nipype>= 1.5.1,<2.0",
2727
"nitransforms>=22.0.0,<24",
2828
"nireports",
2929
"numpy>=1.21.3",
3030
"nest-asyncio>=1.5.1",
31+
"ray",
3132
"scikit-image>=0.15.0",
3233
"scikit_learn>=1.3.0",
3334
"scipy>=1.8.0",

src/nifreeze/model/dmri.py

Lines changed: 13 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525

2626
import numpy as np
2727
from dipy.core.gradients import gradient_table_from_bvals_bvecs
28-
from joblib import Parallel, delayed
2928

3029
from nifreeze.data.dmri import (
3130
DEFAULT_CLIP_PERCENTILE,
@@ -38,16 +37,6 @@
3837
B_MIN = 50
3938

4039

41-
def _exec_fit(model, data, chunk=None):
42-
retval = model.fit(data)
43-
return retval, chunk
44-
45-
46-
def _exec_predict(model, chunk=None, **kwargs):
47-
"""Propagate model parameters and call predict."""
48-
return np.squeeze(model.predict(**kwargs)), chunk
49-
50-
5140
class BaseDWIModel(BaseModel):
5241
"""Interface and default methods for DWI models."""
5342

@@ -57,7 +46,7 @@ class BaseDWIModel(BaseModel):
5746
"_S0": "The S0 (b=0 reference signal) that will be fed into DIPY models",
5847
"_model_class": "Defining a model class, DIPY models are instantiated automagically",
5948
"_modelargs": "Arguments acceptable by the underlying DIPY-like model.",
60-
"_models": "List with one or more (if parallel execution) model instances",
49+
"_model_fit": "Fitted model",
6150
}
6251

6352
def __init__(self, dataset: DWI, max_b: float | int | None = None, **kwargs):
@@ -107,8 +96,6 @@ def __init__(self, dataset: DWI, max_b: float | int | None = None, **kwargs):
10796
def _fit(self, index: int | None = None, n_jobs=None, **kwargs):
10897
"""Fit the model chunk-by-chunk asynchronously"""
10998

110-
n_jobs = n_jobs or 1
111-
11299
if self._locked_fit is not None:
113100
return n_jobs
114101

@@ -136,25 +123,11 @@ def _fit(self, index: int | None = None, n_jobs=None, **kwargs):
136123
class_name,
137124
)(gtab, **kwargs)
138125

139-
# One single CPU - linear execution (full model)
140-
if n_jobs == 1:
141-
_modelfit, _ = _exec_fit(model, data)
142-
self._models = [_modelfit]
143-
return 1
144-
145-
# Split data into chunks of group of slices
146-
data_chunks = np.array_split(data, n_jobs)
147-
148-
self._models = [None] * n_jobs
149-
150-
# Parallelize process with joblib
151-
with Parallel(n_jobs=n_jobs) as executor:
152-
results = executor(
153-
delayed(_exec_fit)(model, dchunk, i) for i, dchunk in enumerate(data_chunks)
154-
)
155-
for submodel, rindex in results:
156-
self._models[rindex] = submodel
157-
126+
self._model_fit = model.fit(
127+
data,
128+
engine="serial" if n_jobs == 1 else "joblib",
129+
n_jobs=n_jobs,
130+
)
158131
return n_jobs
159132

160133
def fit_predict(self, index: int | None = None, **kwargs):
@@ -168,13 +141,14 @@ def fit_predict(self, index: int | None = None, **kwargs):
168141
169142
"""
170143

171-
n_models = self._fit(
144+
self._fit(
172145
index,
173146
n_jobs=kwargs.pop("n_jobs"),
174147
**kwargs,
175148
)
176149

177150
if index is None:
151+
self._locked_fit = True
178152
return None
179153

180154
gradient = self._dataset.gradients[:, index]
@@ -184,28 +158,12 @@ def fit_predict(self, index: int | None = None, **kwargs):
184158
gradient[np.newaxis, -1], gradient[np.newaxis, :-1]
185159
)
186160

187-
if n_models == 1:
188-
predicted, _ = _exec_predict(
189-
self._models[0], **(kwargs | {"gtab": gradient, "S0": self._S0})
161+
predicted = np.squeeze(
162+
self._model_fit.predict(
163+
gtab=gradient,
164+
S0=self._S0,
190165
)
191-
else:
192-
predicted = [None] * n_models
193-
S0 = np.array_split(self._S0, n_models)
194-
195-
# Parallelize process with joblib
196-
with Parallel(n_jobs=n_models) as executor:
197-
results = executor(
198-
delayed(_exec_predict)(
199-
model,
200-
chunk=i,
201-
**(kwargs | {"gtab": gradient, "S0": S0[i]}),
202-
)
203-
for i, model in enumerate(self._models)
204-
)
205-
for subprediction, index in results:
206-
predicted[index] = subprediction
207-
208-
predicted = np.hstack(predicted)
166+
)
209167

210168
retval = np.zeros_like(self._data_mask, dtype=self._dataset.dataobj.dtype)
211169
retval[self._data_mask, ...] = predicted

0 commit comments

Comments (0)