2222#
2323"""Models for nuclear imaging."""
2424
25+ from os import cpu_count
26+
2527import numpy as np
2628from joblib import Parallel , delayed
2729
28- from nifreeze .exceptions import ModelNotFittedError
2930from nifreeze .model .base import BaseModel
3031
3132DEFAULT_TIMEFRAME_MIDPOINT_TOL = 1e-2
@@ -77,21 +78,15 @@ def __init__(self, timepoints=None, xlim=None, n_ctrl=None, order=3, **kwargs):
7778
7879 self ._coeff = None
7980
80- @property
81- def is_fitted (self ):
82- return self ._coeff is not None
83-
84- def fit (self , data , ** kwargs ):
81+ def _fit (self , n_jobs = None , ** kwargs ):
8582 """Fit the model."""
8683 from scipy .interpolate import BSpline
8784 from scipy .sparse .linalg import cg
8885
89- n_jobs = kwargs .pop ("n_jobs" , None ) or 1
90-
9186 timepoints = kwargs .get ("timepoints" , None ) or self ._x
9287 x = (np .array (timepoints , dtype = "float32" ) / self ._xlim ) * self ._n_ctrl
9388
94- self . _datashape = data . shape [: 3 ]
89+ data = self . _dataset . dataobj
9590
9691 # Convert data into V (voxels) x T (timepoints)
9792 data = data .reshape ((- 1 , data .shape [- 1 ])) if self ._mask is None else data [self ._mask ]
@@ -101,26 +96,22 @@ def fit(self, data, **kwargs):
10196 AT = A .T
10297 ATdotA = AT @ A
10398
104- # One single CPU - linear execution (full model)
105- if n_jobs == 1 :
106- self ._coeff = np .array ([cg (ATdotA , AT @ v )[0 ] for v in data ])
107- return
108-
10999 # Parallelize process with joblib
110- with Parallel (n_jobs = n_jobs ) as executor :
100+ with Parallel (n_jobs = n_jobs or min ( cpu_count () or 1 , 8 ) ) as executor :
111101 results = executor (delayed (cg )(ATdotA , AT @ v ) for v in data )
112102
113103 self ._coeff = np .array ([r [0 ] for r in results ])
114104
115- def predict (self , index = None , ** kwargs ):
105+ def fit_predict (self , index : int | None = None , ** kwargs ):
116106 """Return the corrected volume using B-spline interpolation."""
117107 from scipy .interpolate import BSpline
118108
119- if index is None :
120- raise ValueError ("A timepoint index to be simulated must be provided." )
109+ # Fit the BSpline basis on all data
110+ if self ._coeff is None :
111+ self ._fit (n_jobs = kwargs .pop ("n_jobs" , None ))
121112
122- if not self . _is_fitted :
123- raise ModelNotFittedError ( f" { type ( self ). __name__ } must be fitted before predicting" )
113+ if index is None : # If no index, just fit the data.
114+ return None
124115
125116 # Project sample timing into B-Spline coordinates
126117 x = (index / self ._xlim ) * self ._n_ctrl
@@ -130,9 +121,5 @@ def predict(self, index=None, **kwargs):
130121 # self._coeff is V (num. voxels) x K - 4
131122 predicted = np .squeeze (A @ self ._coeff .T )
132123
133- if self ._mask is None :
134- return predicted .reshape (self ._datashape )
135-
136- retval = np .zeros (self ._datashape , dtype = "float32" )
137- retval [self ._mask ] = predicted
138- return retval
124+ datashape = self ._dataset .dataobj .shape [:3 ]
125+ return predicted .reshape (datashape )
0 commit comments