Skip to content

Commit 7d519b7

Browse files
committed
fix: revise interface and implementation of PET model
1 parent 3913b3c commit 7d519b7

File tree

1 file changed

+13
-26
lines changed

1 file changed

+13
-26
lines changed

src/nifreeze/model/pet.py

Lines changed: 13 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,11 @@
2222
#
2323
"""Models for nuclear imaging."""
2424

25+
from os import cpu_count
26+
2527
import numpy as np
2628
from joblib import Parallel, delayed
2729

28-
from nifreeze.exceptions import ModelNotFittedError
2930
from nifreeze.model.base import BaseModel
3031

3132
DEFAULT_TIMEFRAME_MIDPOINT_TOL = 1e-2
@@ -77,21 +78,15 @@ def __init__(self, timepoints=None, xlim=None, n_ctrl=None, order=3, **kwargs):
7778

7879
self._coeff = None
7980

80-
@property
81-
def is_fitted(self):
82-
return self._coeff is not None
83-
84-
def fit(self, data, **kwargs):
81+
def _fit(self, n_jobs=None, **kwargs):
8582
"""Fit the model."""
8683
from scipy.interpolate import BSpline
8784
from scipy.sparse.linalg import cg
8885

89-
n_jobs = kwargs.pop("n_jobs", None) or 1
90-
9186
timepoints = kwargs.get("timepoints", None) or self._x
9287
x = (np.array(timepoints, dtype="float32") / self._xlim) * self._n_ctrl
9388

94-
self._datashape = data.shape[:3]
89+
data = self._dataset.dataobj
9590

9691
# Convert data into V (voxels) x T (timepoints)
9792
data = data.reshape((-1, data.shape[-1])) if self._mask is None else data[self._mask]
@@ -101,26 +96,22 @@ def fit(self, data, **kwargs):
10196
AT = A.T
10297
ATdotA = AT @ A
10398

104-
# One single CPU - linear execution (full model)
105-
if n_jobs == 1:
106-
self._coeff = np.array([cg(ATdotA, AT @ v)[0] for v in data])
107-
return
108-
10999
# Parallelize process with joblib
110-
with Parallel(n_jobs=n_jobs) as executor:
100+
with Parallel(n_jobs=n_jobs or min(cpu_count() or 1, 8)) as executor:
111101
results = executor(delayed(cg)(ATdotA, AT @ v) for v in data)
112102

113103
self._coeff = np.array([r[0] for r in results])
114104

115-
def predict(self, index=None, **kwargs):
105+
def fit_predict(self, index: int | None = None, **kwargs):
116106
"""Return the corrected volume using B-spline interpolation."""
117107
from scipy.interpolate import BSpline
118108

119-
if index is None:
120-
raise ValueError("A timepoint index to be simulated must be provided.")
109+
# Fit the BSpline basis on all data
110+
if self._coeff is None:
111+
self._fit(n_jobs=kwargs.pop("n_jobs", None))
121112

122-
if not self._is_fitted:
123-
raise ModelNotFittedError(f"{type(self).__name__} must be fitted before predicting")
113+
if index is None: # If no index, just fit the data.
114+
return None
124115

125116
# Project sample timing into B-Spline coordinates
126117
x = (index / self._xlim) * self._n_ctrl
@@ -130,9 +121,5 @@ def predict(self, index=None, **kwargs):
130121
# self._coeff is V (num. voxels) x K - 4
131122
predicted = np.squeeze(A @ self._coeff.T)
132123

133-
if self._mask is None:
134-
return predicted.reshape(self._datashape)
135-
136-
retval = np.zeros(self._datashape, dtype="float32")
137-
retval[self._mask] = predicted
138-
return retval
124+
datashape = self._dataset.dataobj.shape[:3]
125+
return predicted.reshape(datashape)

0 commit comments

Comments
 (0)