REF: Refactor PET model #204
base: main
Changes from all commits
@@ -22,6 +22,7 @@
 #
 """Models for nuclear imaging."""

+from abc import ABC, ABCMeta, abstractmethod
 from os import cpu_count
 from typing import Union
@@ -39,56 +40,78 @@
 """Time frame tolerance in seconds."""


-class PETModel(BaseModel):
-    """A PET imaging realignment model based on B-Spline approximation."""
+def _exec_fit(model, data, chunk=None, **kwargs):
+    return model.fit(data, **kwargs), chunk

-    __slots__ = (
-        "_t",
-        "_x",
-        "_xlim",
-        "_order",
-        "_n_ctrl",
-        "_datashape",
-        "_mask",
-        "_smooth_fwhm",
-        "_thresh_pct",
-    )

+def _exec_predict(model, chunk=None, **kwargs):
+    """Propagate model parameters and call predict."""
+    return np.squeeze(model.predict(**kwargs)), chunk
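These module-level helpers return a `(result, chunk)` pair so that per-chunk results can be reassembled in order after a parallel map. Below is a minimal sketch of how such helpers pair with joblib; the `MeanModel` stand-in and the chunking scheme are illustrative assumptions, not part of this diff:

```python
import numpy as np
from joblib import Parallel, delayed


class MeanModel:
    """Toy stand-in for a real per-chunk model."""

    def fit(self, data):
        self.mean_ = data.mean(axis=-1)
        return self

    def predict(self):
        return self.mean_


def _exec_fit(model, data, chunk=None, **kwargs):
    return model.fit(data, **kwargs), chunk


# Split a V x T matrix into chunks and fit one model instance per chunk
voxels = np.random.default_rng(1).random((100, 18))
chunks = np.array_split(voxels, 4)
fitted = Parallel(n_jobs=2)(
    delayed(_exec_fit)(MeanModel(), c, chunk=i) for i, c in enumerate(chunks)
)
# Reassemble the models in chunk order, regardless of completion order
models = [model for model, _ in sorted(fitted, key=lambda r: r[1])]
```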
+class BasePETModel(BaseModel, ABC):
+    """Interface and default methods for PET models."""
+
+    __metaclass__ = ABCMeta
+
+    __slots__ = {
+        "_data_mask": "A mask for the voxels that will be fitted and predicted",
+        "_x": "",
+        "_xlim": "",
+        "_smooth_fwhm": "FWHM in mm over which to smooth",
+        "_thresh_pct": "Thresholding percentile for the signal",
+        "_model_class": "Defining a model class",
+        "_modelargs": "Arguments acceptable by the underlying model",
+        "_models": "List with one or more (if parallel execution) model instances",
+    }
     def __init__(
         self,
         dataset: PET,
         timepoints: list | np.ndarray | None = None,
-        xlim: float | None = None,
-        n_ctrl: int | None = None,
-        order: int = 3,
+        xlim: list | np.ndarray | None = None,
+        smooth_fwhm: float = 10.0,
+        thresh_pct: float = 20.0,
         **kwargs,
     ):
-        """
-        Create the B-Spline interpolating matrix.
+        """Initialization.

-        Parameters:
-        -----------
-        timepoints : :obj:`list`
+        Parameters
+        ----------
+        timepoints : :obj:`list` or :obj:`~np.ndarray`
             The timing (in sec) of each PET volume.
             E.g., ``[15., 45., 75., 105., 135., 165., 210., 270., 330.,
             420., 540., 750., 1050., 1350., 1650., 1950., 2250., 2550.]``

-        n_ctrl : :obj:`int`
-            Number of B-Spline control points. If `None`, then one control point every
-            six timepoints will be used. The less control points, the smoother is the
-            model.
+        xlim : :obj:`list` or :obj:`~np.ndarray`
+            Upper bound of the temporal domain (in sec), i.e., the scan's total duration.
+        smooth_fwhm : :obj:`float`
+            FWHM in mm over which to smooth the signal.
+        thresh_pct : :obj:`float`
+            Thresholding percentile for the signal.
         """

         super().__init__(dataset, **kwargs)

+        # Duck typing, instead of explicitly testing for PET type
+        if not hasattr(dataset, "total_duration"):
+            raise TypeError("Dataset MUST be a PET object.")
+
+        if not hasattr(dataset, "midframe"):
+            raise ValueError("Dataset MUST have a midframe.")
+
+        self._data_mask = (
+            dataset.brainmask
+            if dataset.brainmask is not None
+            else np.ones(dataset.dataobj.shape[:3], dtype=bool)
+        )
+
+        # ToDo
+        # Are timepoints and xlim features that all PET models require?
         if timepoints is None or xlim is None:
-            raise TypeError("timepoints must be provided in initialization")
+            raise ValueError("`timepoints` and `xlim` must be specified and have a nonzero value.")

-        self._order = order
-        self._x = np.array(timepoints, dtype="float32")
-        self._xlim = xlim
+        self._x = np.asarray(timepoints, dtype="float32")
+        self._xlim = np.asarray(xlim)
+        self._smooth_fwhm = smooth_fwhm
+        self._thresh_pct = thresh_pct
@@ -97,62 +120,114 @@ def __init__(
         if self._x[-1] > (self._xlim - DEFAULT_TIMEFRAME_MIDPOINT_TOL):
             raise ValueError("Last frame midpoint should not be equal or greater than duration")

-        # Calculate index coordinates in the B-Spline grid
-        self._n_ctrl = n_ctrl or (len(timepoints) // 4) + 1
+    def _preprocess_data(self) -> np.ndarray:
+        # ToDo
Contributor
What is the todo here?

Contributor (Author)
If I recall correctly, the idea there is to preprocess only the data that is required; e.g., if we are using only a subset of the volumes, then we shouldn't preprocess the entire dataset (compare nifreeze/src/nifreeze/model/dmri.py, lines 122 to 124 in 4cfaadb).
+        # data, _, gtab = self._dataset[idxmask]  ### This needs the PET data model to be changed
+        data = self._dataset.dataobj
+        brainmask = self._dataset.brainmask

-        # B-Spline knots
-        self._t = np.arange(-3, float(self._n_ctrl) + 4, dtype="float32")
+        # Preprocess the data
+        if self._smooth_fwhm > 0:
+            smoothed_img = smooth_image(
+                nb.Nifti1Image(data, self._dataset.affine), self._smooth_fwhm
+            )
+            data = smoothed_img.get_fdata()

-        self._datashape = None
-        self._mask = None
+        if self._thresh_pct > 0:
+            thresh_val = np.percentile(data, self._thresh_pct)
+            data[data < thresh_val] = 0

+        # Convert data into V (voxels) x T (timepoints)
+        return data.reshape((-1, data.shape[-1])) if brainmask is None else data[brainmask]
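As a reference for what `_preprocess_data` computes, here is a self-contained sketch on synthetic data; it assumes `smooth_image` is `nibabel.processing.smooth_image` (consistent with the `nb.Nifti1Image` call above) and uses the default `smooth_fwhm=10.0` and `thresh_pct=20.0`:

```python
import nibabel as nb
import numpy as np
from nibabel.processing import smooth_image

rng = np.random.default_rng(42)
data = rng.random((8, 8, 8, 18)).astype("float32")  # X x Y x Z x T frames
affine = np.eye(4)

# Gaussian smoothing with a 10 mm FWHM
data = smooth_image(nb.Nifti1Image(data, affine), 10.0).get_fdata()

# Zero out intensities below the 20th percentile
data[data < np.percentile(data, 20.0)] = 0

# Collapse the spatial axes into V (voxels) x T (timepoints);
# with a boolean brainmask, data[brainmask] yields the same layout
data = data.reshape((-1, data.shape[-1]))
print(data.shape)  # (512, 18)
```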
+    @property
+    def is_fitted(self) -> bool:
+        return self._locked_fit is not None
+
+    @abstractmethod
+    def fit_predict(self, index: int | None = None, **kwargs) -> Union[np.ndarray, None]:
+        """Predict the corrected volume."""
+        return None
+class BSplinePETModel(BasePETModel):
+    """A PET imaging realignment model based on B-Spline approximation."""
+
+    __slots__ = (
+        "_t",
+        "_order",
+        "_n_ctrl",
+    )
+
+    def __init__(
+        self,
+        dataset: PET,
+        n_ctrl: int | None = None,
+        order: int = 3,
+        **kwargs,
+    ):
+        """Create the B-Spline interpolating matrix.
+
+        Parameters
+        ----------
+        n_ctrl : :obj:`int`
+            Number of B-Spline control points. If `None`, then one control point every
+            four timepoints will be used. The fewer control points, the smoother the
+            model.
+        order : :obj:`int`
+            Order of the B-Spline approximation.
+        """
+
+        super().__init__(dataset, **kwargs)
+
+        self._order = order
+
+        # Calculate index coordinates in the B-Spline grid
+        self._n_ctrl = n_ctrl or (len(self._x) // 4) + 1
+
+        # B-Spline knots
+        self._t = np.arange(-3, self._n_ctrl + 4, dtype="float32")
     def _fit(self, index: int | None = None, n_jobs=None, **kwargs) -> int:
         """Fit the model."""

+        n_jobs = n_jobs or min(cpu_count() or 1, 8)
+
         if self._locked_fit is not None:
             return n_jobs

         if index is not None:
             raise NotImplementedError("Fitting with held-out data is not supported")
-        timepoints = kwargs.get("timepoints", None) or self._x
-        x = np.asarray((np.array(timepoints, dtype="float32") / self._xlim) * self._n_ctrl)

-        data = self._dataset.dataobj
-        brainmask = self._dataset.brainmask
-
-        if self._smooth_fwhm > 0:
-            smoothed_img = smooth_image(
-                nb.Nifti1Image(data, self._dataset.affine), self._smooth_fwhm
-            )
-            data = smoothed_img.get_fdata()
-
-        if self._thresh_pct > 0:
-            thresh_val = np.percentile(data, self._thresh_pct)
-            data[data < thresh_val] = 0
+        data = self._preprocess_data()
-        # Convert data into V (voxels) x T (timepoints)
-        data = data.reshape((-1, data.shape[-1])) if brainmask is None else data[brainmask]
+        # ToDo
Contributor
Allowing `_fit` to override `timepoints` through kwargs duplicates information that was already validated and stored on `self._x` during initialization. Dropping that extra kwarg simplifies the API and avoids inconsistent inputs.

Contributor (Author)
Not sure I understand: maybe the reply is answering a question elsewhere?
+        # It does not make sense for timepoints to be a kwarg if it is provided as a named parameter to __init__
+        timepoints = kwargs.get("timepoints", None) or self._x
Contributor
This should be removed.
+        x = np.asarray(timepoints, dtype="float32") / self._xlim * self._n_ctrl

Contributor
This can be replaced with `x = self._x / self._xlim * self._n_ctrl`.
         # A.shape = (T, K - 4); T = n. timepoints, K = n. knots (with padding)
         A = BSpline.design_matrix(x, self._t, k=self._order)
         AT = A.T
         ATdotA = AT @ A

         # Parallelize process with joblib
-        with Parallel(n_jobs=n_jobs or min(cpu_count() or 1, 8)) as executor:
+        with Parallel(n_jobs=n_jobs) as executor:
             results = executor(delayed(cg)(ATdotA, AT @ v) for v in data)

-        self._locked_fit = np.array([r[0] for r in results])
+        self._locked_fit = np.asarray([r[0] for r in results])

         return n_jobs
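For clarity, a runnable sketch of the fitting strategy above: map the frame midpoints onto the control-point grid, build the sparse B-Spline design matrix, and solve the normal equations (AᵀA)c = Aᵀy once per voxel with conjugate gradients. The total duration (`xlim = 3000.0`) and the random data are assumptions for illustration:

```python
import numpy as np
from scipy.interpolate import BSpline
from scipy.sparse.linalg import cg

timepoints = np.array(
    [15., 45., 75., 105., 135., 165., 210., 270., 330.,
     420., 540., 750., 1050., 1350., 1650., 1950., 2250., 2550.],
    dtype="float32",
)
xlim = 3000.0  # total scan duration in seconds (assumed)
order = 3
n_ctrl = len(timepoints) // 4 + 1  # default from this diff: 5 control points

# Index coordinates in the B-Spline grid and the padded knot vector
x = timepoints / xlim * n_ctrl
t = np.arange(-3, n_ctrl + 4, dtype="float32")

# A.shape = (T, K - 4); T = n. timepoints, K = n. knots (with padding)
A = BSpline.design_matrix(x, t, k=order)
AT = A.T
ATdotA = AT @ A

# One conjugate-gradient solve per voxel of the V x T data matrix
data = np.random.default_rng(0).random((10, len(timepoints)))
coeffs = np.asarray([cg(ATdotA, AT @ v)[0] for v in data])
print(coeffs.shape)  # (10, 8), i.e. V x (n_ctrl + 3)
```

With 18 frames and the `len(timepoints) // 4 + 1` default, each voxel's time series is thus compressed into 8 spline coefficients.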
     def fit_predict(self, index: int | None = None, **kwargs) -> Union[np.ndarray, None]:
         """Return the corrected volume using B-spline interpolation."""

+        # ToDo
+        # Does the below apply to PET? Martin has the return None statement
+        # if index is None:
+        #     raise RuntimeError(
+        #         f"Model {self.__class__.__name__} does not allow locking.")

Contributor (Author)
Another question for @mnoergaard.

+        # Fit the BSpline basis on all data
         if self._locked_fit is None:
             self._fit(index, n_jobs=kwargs.pop("n_jobs", None), **kwargs)
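For completeness, a sketch of the interpolation step that `fit_predict` implies: evaluating the fitted B-Spline at one frame's grid coordinate. The coefficient matrix below is a random stand-in for `_locked_fit`, and the 270 s midpoint is an assumed example:

```python
import numpy as np
from scipy.interpolate import BSpline

# Same grid as in the fitting sketch above
order, n_ctrl, xlim = 3, 5, 3000.0
t = np.arange(-3, n_ctrl + 4, dtype="float32")
coeffs = np.random.default_rng(0).random((10, n_ctrl + 3))  # stand-in for _locked_fit

# Grid coordinate of the frame midpoint to reconstruct (e.g., 270 s)
x_pred = np.atleast_1d(270.0 / xlim * n_ctrl)
basis = BSpline.design_matrix(x_pred, t, k=order).toarray()  # (1, n_ctrl + 3)

predicted = np.squeeze(coeffs @ basis.T)  # one interpolated value per voxel
print(predicted.shape)  # (10,)
```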
Contributor
@mnoergaard question for you.
Contributor
I would say so, since we are dealing with motion correction, and hence would need temporal information. Currently, the PET model needs both the sampling midpoints and the scan’s total duration to work, so enforcing their presence in the base class keeps the API consistent. When new PET models that do not need temporal information are introduced, we can revisit this requirement; until then, providing dataset.midframe and dataset.total_duration (as done in the unit tests) is the intended usage.
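To make that contract concrete, here is a hedged sketch of the duck typing the base class performs; `FakePET` is a hypothetical stand-in exposing only the attributes `BasePETModel.__init__` checks for:

```python
from dataclasses import dataclass, field

import numpy as np


@dataclass
class FakePET:
    """Hypothetical minimal dataset satisfying the model's duck typing."""

    dataobj: np.ndarray
    affine: np.ndarray = field(default_factory=lambda: np.eye(4))
    brainmask: np.ndarray | None = None
    midframe: np.ndarray | None = None  # frame midpoints, in seconds
    total_duration: float = 0.0  # scan duration, in seconds


def validate(dataset) -> None:
    # Mirrors the checks in BasePETModel.__init__
    if not hasattr(dataset, "total_duration"):
        raise TypeError("Dataset MUST be a PET object.")
    if not hasattr(dataset, "midframe"):
        raise ValueError("Dataset MUST have a midframe.")


validate(FakePET(dataobj=np.zeros((2, 2, 2, 3))))  # passes: both attributes exist
```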
Contributor
OK, I see that this is related to #204 (review). I think the naming should then be made consistent across the PET data class and the model class. Also, I do not see why the model is given the timepoints if these are held by the PET data class, i.e., per #204 (review), we should not be providing the model with a copy of them. If we make the parallel with the DWI class, the gradients are obtained from the dataset. Also, can the `xlim` name be made somehow more descriptive, or is it a name that is commonly used within the PET domain?