-
Notifications
You must be signed in to change notification settings - Fork 5
MAINT: Add static type checking #28
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
f20fe5f
734862c
12de4f9
b06eff8
3eb17ab
781104a
d426757
bc48dcf
d3591b4
3b214c3
590d235
ecbd574
73afbf7
db1a1e7
87f7e90
b8601d9
24e4ec1
cc735d1
052cf36
374fe48
478c860
0037efd
9f3fdd5
05237a6
4eb80dd
e240ac6
b450b8d
2310945
02ff9c4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -27,17 +27,23 @@ | |
| from collections import namedtuple | ||
| from pathlib import Path | ||
| from tempfile import mkdtemp | ||
| from typing import Any | ||
| from typing import Any, Generic, TypeVarTuple | ||
|
|
||
| import attr | ||
| import h5py | ||
| import nibabel as nb | ||
| import numpy as np | ||
| from nibabel.spatialimages import SpatialHeader, SpatialImage | ||
| from nitransforms.linear import Affine | ||
|
|
||
| from nifreeze.utils.ndimage import load_api | ||
|
|
||
| NFDH5_EXT = ".h5" | ||
|
|
||
|
|
||
| Ts = TypeVarTuple("Ts") | ||
|
|
||
|
|
||
| def _data_repr(value: np.ndarray | None) -> str: | ||
| if value is None: | ||
| return "None" | ||
|
|
@@ -52,7 +58,7 @@ | |
|
|
||
|
|
||
| @attr.s(slots=True) | ||
| class BaseDataset: | ||
| class BaseDataset(Generic[*Ts]): | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is the main new thing I introduced. This class BaseDataset[*Ts=Unpack[tuple[()]]]:
...Which, while not particularly pretty, will allow us not to have to type |
||
| """ | ||
| Base dataset representation structure. | ||
|
|
@@ -68,15 +74,15 @@ | |
| """ | ||
|
|
||
| dataobj = attr.ib(default=None, repr=_data_repr, eq=attr.cmp_using(eq=_cmp)) | ||
| dataobj: np.ndarray = attr.ib(default=None, repr=_data_repr, eq=attr.cmp_using(eq=_cmp)) | ||
| """A :obj:`~numpy.ndarray` object for the data array.""" | ||
| affine = attr.ib(default=None, repr=_data_repr, eq=attr.cmp_using(eq=_cmp)) | ||
| affine: np.ndarray = attr.ib(default=None, repr=_data_repr, eq=attr.cmp_using(eq=_cmp)) | ||
| """Best affine for RAS-to-voxel conversion of coordinates (NIfTI header).""" | ||
| brainmask = attr.ib(default=None, repr=_data_repr, eq=attr.cmp_using(eq=_cmp)) | ||
| brainmask: np.ndarray = attr.ib(default=None, repr=_data_repr, eq=attr.cmp_using(eq=_cmp)) | ||
| """A boolean ndarray object containing a corresponding brainmask.""" | ||
| motion_affines = attr.ib(default=None, eq=attr.cmp_using(eq=_cmp)) | ||
| motion_affines: np.ndarray = attr.ib(default=None, eq=attr.cmp_using(eq=_cmp)) | ||
| """List of :obj:`~nitransforms.linear.Affine` realigning the dataset.""" | ||
| datahdr = attr.ib(default=None) | ||
| datahdr: SpatialHeader = attr.ib(default=None) | ||
| """A :obj:`~nibabel.spatialimages.SpatialHeader` header corresponding to the data.""" | ||
|
|
||
| _filepath = attr.ib( | ||
|
|
@@ -93,9 +99,13 @@ | |
|
|
||
| return self.dataobj.shape[-1] | ||
|
|
||
| def _getextra(self, idx: int | slice | tuple | np.ndarray) -> tuple[*Ts]: | ||
| # PY312: Default values for TypeVarTuples are not yet supported | ||
| return () # type: ignore[return-value] | ||
|
|
||
| def __getitem__( | ||
| self, idx: int | slice | tuple | np.ndarray | ||
| ) -> tuple[np.ndarray, np.ndarray | None]: | ||
| ) -> tuple[np.ndarray, np.ndarray | None, *Ts]: | ||
| """ | ||
| Returns volume(s) and corresponding affine(s) through fancy indexing. | ||
|
|
@@ -118,7 +128,7 @@ | |
| raise ValueError("No data available (dataobj is None).") | ||
|
|
||
| affine = self.motion_affines[idx] if self.motion_affines is not None else None | ||
| return self.dataobj[..., idx], affine | ||
| return self.dataobj[..., idx], affine, *self._getextra(idx) | ||
|
|
||
| @classmethod | ||
| def from_filename(cls, filename: Path | str) -> BaseDataset: | ||
|
|
@@ -159,9 +169,8 @@ | |
| The order of the spline interpolation. | ||
| """ | ||
| reference = namedtuple("ImageGrid", ("shape", "affine"))( | ||
| shape=self.dataobj.shape[:3], affine=self.affine | ||
| ) | ||
| ImageGrid = namedtuple("ImageGrid", ("shape", "affine")) | ||
| reference = ImageGrid(shape=self.dataobj.shape[:3], affine=self.affine) | ||
|
|
||
| xform = Affine(matrix=affine, reference=reference) | ||
|
|
||
|
|
@@ -227,7 +236,7 @@ | |
| compression_opts=compression_opts, | ||
| ) | ||
|
|
||
| def to_nifti(self, filename: Path) -> None: | ||
| def to_nifti(self, filename: Path | str) -> None: | ||
| """ | ||
| Write a NIfTI file to disk. | ||
|
|
@@ -247,7 +256,7 @@ | |
| filename: Path | str, | ||
| brainmask_file: Path | str | None = None, | ||
| motion_file: Path | str | None = None, | ||
| ) -> BaseDataset: | ||
| ) -> BaseDataset[()]: | ||
| """ | ||
| Load 4D data from a filename or an HDF5 file. | ||
|
|
@@ -279,11 +288,11 @@ | |
| if filename.name.endswith(NFDH5_EXT): | ||
| return BaseDataset.from_filename(filename) | ||
|
|
||
| img = nb.load(filename) | ||
| retval = BaseDataset(dataobj=img.dataobj, affine=img.affine) | ||
| img = load_api(filename, SpatialImage) | ||
| retval: BaseDataset[()] = BaseDataset(dataobj=np.asanyarray(img.dataobj), affine=img.affine) | ||
|
|
||
| if brainmask_file: | ||
| mask = nb.load(brainmask_file) | ||
| mask = load_api(brainmask_file, SpatialImage) | ||
| retval.brainmask = np.asanyarray(mask.dataobj) | ||
|
|
||
| return retval | ||
Uh oh!
There was an error while loading. Please reload this page.