Skip to content

Commit 3dd2b65

Browse files
committed
enh: allow "locking" of models with first fit
This allows having "frozen" models that are fit only once before entering the leave-one-out loop. These models use all the data in the fitting and then always return the same "prediction". This feature was lost with the refactor of the estimator. By moving it into the models, we can use them in a more flexible way.
1 parent 5ee55de commit 3dd2b65

File tree

3 files changed

+76
-32
lines changed

3 files changed

+76
-32
lines changed

src/nifreeze/model/base.py

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -87,46 +87,59 @@ class BaseModel:
8787
8888
"""
8989

90-
__slots__ = ("_dataset",)
90+
__slots__ = ("_dataset", "_locked_fit")
9191

9292
def __init__(self, dataset, **kwargs):
9393
"""Base initialization."""
9494

95+
self._locked_fit = None
9596
self._dataset = dataset
9697
# Warn if mask not present
9798
if dataset.brainmask is None:
9899
warn(mask_absence_warn_msg, stacklevel=2)
99100

100101
@abstractmethod
101-
def fit_predict(self, index, **kwargs) -> np.ndarray:
102-
"""Fit and predict the indicate index of the dataset (abstract signature)."""
102+
def fit_predict(self, index: int | None = None, **kwargs) -> np.ndarray:
103+
"""
104+
Fit and predict the indicated index of the dataset (abstract signature).
105+
106+
If ``index`` is ``None``, then the model is executed in *single-fit mode*, meaning
107+
that it will be run only once on all the data available.
108+
Please note that all the predictions of this model will suffer from data leakage
109+
from the original volume.
110+
111+
Parameters
112+
----------
113+
index : :obj:`int` or ``None``
114+
The index to predict.
115+
If ``None``, no prediction will be executed.
116+
117+
"""
103118
raise NotImplementedError("Cannot call fit_predict() on a BaseModel instance.")
104119

105120

106121
class TrivialModel(BaseModel):
107122
"""A trivial model that returns a given map always."""
108123

109-
__slots__ = ("_predicted",)
110-
111124
def __init__(self, dataset, predicted=None, **kwargs):
112125
"""Implement object initialization."""
113126

114127
super().__init__(dataset, **kwargs)
115-
self._predicted = (
128+
self._locked_fit = (
116129
predicted
117130
if predicted is not None
118131
# Infer from dataset if not provided at initialization
119132
else getattr(dataset, "reference", getattr(dataset, "bzero", None))
120133
)
121134

122-
if self._predicted is None:
135+
if self._locked_fit is None:
123136
raise TypeError("This model requires the predicted map at initialization")
124137

125138
def fit_predict(self, *_, **kwargs):
126139
"""Return the reference map."""
127140

128141
# No need to check fit (if not fitted, has raised already)
129-
return self._predicted
142+
return self._locked_fit
130143

131144

132145
class ExpectationModel(BaseModel):
@@ -139,7 +152,7 @@ def __init__(self, dataset, stat="median", **kwargs):
139152
super().__init__(dataset, **kwargs)
140153
self._stat = stat
141154

142-
def fit_predict(self, index: int, **kwargs):
155+
def fit_predict(self, index: int | None = None, **kwargs):
143156
"""
144157
Return the expectation map.
145158
@@ -149,12 +162,20 @@ def fit_predict(self, index: int, **kwargs):
149162
The volume index that is left-out in fitting, and then predicted.
150163
151164
"""
165+
166+
if self._locked_fit is not None:
167+
return self._locked_fit
168+
152169
# Select the summary statistic
153170
avg_func = getattr(np, kwargs.pop("stat", self._stat))
154171

155172
# Create index mask
156173
index_mask = np.ones(len(self._dataset), dtype=bool)
157-
index_mask[index] = False
158174

159-
# Calculate the average
160-
return avg_func(self._dataset[index_mask][0], axis=-1)
175+
if index is not None:
176+
index_mask[index] = False
177+
# Calculate the average
178+
return avg_func(self._dataset[index_mask][0], axis=-1)
179+
180+
self._locked_fit = avg_func(self._dataset[index_mask][0], axis=-1)
181+
return self._locked_fit

src/nifreeze/model/dmri.py

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ class BaseDWIModel(BaseModel):
5151
__slots__ = {
5252
"_model_class": "Defining a model class, DIPY models are instantiated automagically",
5353
"_modelargs": "Arguments acceptable by the underlying DIPY-like model.",
54+
"_models": "List with one or more (if parallel execution) model instances",
5455
}
5556

5657
def __init__(self, dataset: DWI, **kwargs):
@@ -77,13 +78,21 @@ def __init__(self, dataset: DWI, **kwargs):
7778

7879
super().__init__(dataset, **kwargs)
7980

80-
def _fit(self, index, n_jobs=None, **kwargs):
81+
def _fit(self, index: int | None = None, n_jobs=None, **kwargs):
8182
"""Fit the model chunk-by-chunk asynchronously"""
83+
8284
n_jobs = n_jobs or 1
8385

86+
if self._locked_fit is not None:
87+
return n_jobs
88+
8489
brainmask = self._dataset.brainmask
8590
idxmask = np.ones(len(self._dataset), dtype=bool)
86-
idxmask[index] = False
91+
92+
if index is not None:
93+
idxmask[index] = False
94+
else:
95+
self._locked_fit = True
8796

8897
data, _, gtab = self._dataset[idxmask]
8998
# Select voxels within mask or just unravel 3D if no mask
@@ -96,14 +105,15 @@ def _fit(self, index, n_jobs=None, **kwargs):
96105

97106
if model_str:
98107
module_name, class_name = model_str.rsplit(".", 1)
99-
self._model = getattr(
108+
model = getattr(
100109
import_module(module_name),
101110
class_name,
102111
)(gtab, **kwargs)
103112

104113
# One single CPU - linear execution (full model)
105114
if n_jobs == 1:
106-
self._model, _ = _exec_fit(self._model, data)
115+
_modelfit, _ = _exec_fit(model, data)
116+
self._models = [_modelfit]
107117
return 1
108118

109119
# Split data into chunks of group of slices
@@ -114,15 +124,14 @@ def _fit(self, index, n_jobs=None, **kwargs):
114124
# Parallelize process with joblib
115125
with Parallel(n_jobs=n_jobs) as executor:
116126
results = executor(
117-
delayed(_exec_fit)(self._model, dchunk, i) for i, dchunk in enumerate(data_chunks)
127+
delayed(_exec_fit)(model, dchunk, i) for i, dchunk in enumerate(data_chunks)
118128
)
119129
for submodel, rindex in results:
120130
self._models[rindex] = submodel
121131

122-
self._model = None # Preempt further actions on the model
123132
return n_jobs
124133

125-
def fit_predict(self, index: int, **kwargs):
134+
def fit_predict(self, index: int | None = None, **kwargs):
126135
"""
127136
Predict asynchronously chunk-by-chunk the diffusion signal.
128137
@@ -133,8 +142,14 @@ def fit_predict(self, index: int, **kwargs):
133142
134143
"""
135144

136-
n_models = self._fit(index, **kwargs)
137-
kwargs.pop("n_jobs")
145+
n_models = self._fit(
146+
index,
147+
n_jobs=kwargs.pop("n_jobs"),
148+
**kwargs,
149+
)
150+
151+
if index is None:
152+
return None
138153

139154
brainmask = self._dataset.brainmask
140155
gradient = self._dataset.gradients[:, index]
@@ -149,9 +164,10 @@ def fit_predict(self, index: int, **kwargs):
149164
S0 = S0[brainmask, ...] if brainmask is not None else S0.reshape(-1)
150165

151166
if n_models == 1:
152-
predicted, _ = _exec_predict(self._model, **(kwargs | {"gtab": gradient, "S0": S0}))
167+
predicted, _ = _exec_predict(
168+
self._models[0], **(kwargs | {"gtab": gradient, "S0": S0})
169+
)
153170
else:
154-
print(n_models, S0)
155171
S0 = np.array_split(S0, n_models) if S0 is not None else np.full(n_models, None)
156172

157173
predicted = [None] * n_models
@@ -221,9 +237,12 @@ def __init__(
221237
self._th_high = th_high
222238
self._detrend = detrend
223239

224-
def fit_predict(self, index, *_, **kwargs):
240+
def fit_predict(self, index: int | None = None, *_, **kwargs):
225241
"""Return the average map."""
226242

243+
if index is None:
244+
raise RuntimeError(f"Model {self.__class__.__name__} does not allow locking.")
245+
227246
bvalues = self._dataset.gradients[:, -1]
228247
bcenter = bvalues[index]
229248

src/nifreeze/model/pet.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
class PETModel(BaseModel):
3737
"""A PET imaging realignment model based on B-Spline approximation."""
3838

39-
__slots__ = ("_t", "_x", "_xlim", "_order", "_coeff", "_n_ctrl")
39+
__slots__ = ("_t", "_x", "_xlim", "_order", "_n_ctrl")
4040

4141
def __init__(self, timepoints=None, xlim=None, n_ctrl=None, order=3, **kwargs):
4242
"""
@@ -76,13 +76,17 @@ def __init__(self, timepoints=None, xlim=None, n_ctrl=None, order=3, **kwargs):
7676
# B-Spline knots
7777
self._t = np.arange(-3, float(self._n_ctrl) + 4, dtype="float32")
7878

79-
self._coeff = None
80-
81-
def _fit(self, n_jobs=None, **kwargs):
79+
def _fit(self, index: int | None = None, n_jobs=None, **kwargs):
8280
"""Fit the model."""
8381
from scipy.interpolate import BSpline
8482
from scipy.sparse.linalg import cg
8583

84+
if self._locked_fit is not None:
85+
return n_jobs
86+
87+
if index is not None:
88+
raise NotImplementedError("Fitting with held-out data is not supported")
89+
8690
timepoints = kwargs.get("timepoints", None) or self._x
8791
x = (np.array(timepoints, dtype="float32") / self._xlim) * self._n_ctrl
8892

@@ -101,15 +105,15 @@ def _fit(self, n_jobs=None, **kwargs):
101105
with Parallel(n_jobs=n_jobs or min(cpu_count() or 1, 8)) as executor:
102106
results = executor(delayed(cg)(ATdotA, AT @ v) for v in data)
103107

104-
self._coeff = np.array([r[0] for r in results])
108+
self._locked_fit = np.array([r[0] for r in results])
105109

106110
def fit_predict(self, index: int | None = None, **kwargs):
107111
"""Return the corrected volume using B-spline interpolation."""
108112
from scipy.interpolate import BSpline
109113

110114
# Fit the BSpline basis on all data
111-
if self._coeff is None:
112-
self._fit(n_jobs=kwargs.pop("n_jobs", None))
115+
if self._locked_fit is None:
116+
self._fit(index, n_jobs=kwargs.pop("n_jobs", None), **kwargs)
113117

114118
if index is None: # If no index, just fit the data.
115119
return None
@@ -120,7 +124,7 @@ def fit_predict(self, index: int | None = None, **kwargs):
120124

121125
# A is 1 (num. timepoints) x C (num. coeff)
122126
# self._locked_fit (the fitted coefficients) is V (num. voxels) x K - 4
123-
predicted = np.squeeze(A @ self._coeff.T)
127+
predicted = np.squeeze(A @ self._locked_fit.T)
124128

125129
brainmask = self._dataset.brainmask
126130
datashape = self._dataset.dataobj.shape[:3]

0 commit comments

Comments
 (0)