diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index a1acb0d74..43d95a5be 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -95,8 +95,7 @@ jobs:
     continue-on-error: true
     strategy:
       matrix:
-        check: ['spellcheck']
-
+        check: ['spellcheck', 'typecheck']
     steps:
       - uses: actions/checkout@v4
       - name: Install the latest version of uv
diff --git a/docs/conf.py b/docs/conf.py
index 02488e10e..93397d99f 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -54,7 +54,6 @@
     "nipype",
     "nitime",
     "nitransforms",
-    "numpy",
     "pandas",
     "scipy",
     "seaborn",
@@ -154,20 +153,20 @@
 
 # -- Options for LaTeX output ------------------------------------------------
 
-latex_elements = {
-    # The paper size ('letterpaper' or 'a4paper').
-    #
-    # 'papersize': 'letterpaper',
-    # The font size ('10pt', '11pt' or '12pt').
-    #
-    # 'pointsize': '10pt',
-    # Additional stuff for the LaTeX preamble.
-    #
-    # 'preamble': '',
-    # Latex figure (float) alignment
-    #
-    # 'figure_align': 'htbp',
-}
+# latex_elements = {
+#     # The paper size ('letterpaper' or 'a4paper').
+#     #
+#     # 'papersize': 'letterpaper',
+#     # The font size ('10pt', '11pt' or '12pt').
+#     #
+#     # 'pointsize': '10pt',
+#     # Additional stuff for the LaTeX preamble.
+#     #
+#     # 'preamble': '',
+#     # Latex figure (float) alignment
+#     #
+#     # 'figure_align': 'htbp',
+# }
 
 # Grouping the document tree into LaTeX files. List of tuples
 # (source start file, target name, title,
diff --git a/pyproject.toml b/pyproject.toml
index 205ac6960..6252ce983 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -75,6 +75,15 @@ test = [
     "pytest-env",
     "pytest-xdist >= 1.28"
 ]
+types = [
+    "pandas-stubs",
+    "types-setuptools",
+    "scipy-stubs",
+    "types-PyYAML",
+    "types-tqdm",
+    "pytest",
+    "microsoft-python-type-stubs @ git+https://github.com/microsoft/python-type-stubs.git",
+]
 
 notebooks = [
     "jupyter",
@@ -138,6 +147,21 @@ version-file = "src/nifreeze/_version.py"
 #
 # Developer tool configurations
 #
+[[tool.mypy.overrides]]
+module = [
+    "nipype.*",
+    "nilearn.*",
+    "nireports.*",
+    "nitransforms.*",
+    "seaborn",
+    "dipy.*",
+    "smac.*",
+    "joblib",
+    "h5py",
+    "ConfigSpace",
+]
+ignore_missing_imports = true
+
 [tool.ruff]
 line-length = 99
 target-version = "py310"
diff --git a/scripts/dwi_gp_estimation_error_analysis.py b/scripts/dwi_gp_estimation_error_analysis.py
index eb393b973..3feef2c1e 100644
--- a/scripts/dwi_gp_estimation_error_analysis.py
+++ b/scripts/dwi_gp_estimation_error_analysis.py
@@ -49,7 +49,7 @@ def cross_validate(
     cv: int,
     n_repeats: int,
     gpr: DiffusionGPR,
-) -> dict[int, list[tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]]]:
+) -> np.ndarray:
     """
     Perform the experiment by estimating the dMRI signal using a Gaussian process model.
 
@@ -74,7 +74,14 @@ def cross_validate(
     """
 
     rkf = RepeatedKFold(n_splits=cv, n_repeats=n_repeats)
-    scores = cross_val_score(gpr, X, y, scoring="neg_root_mean_squared_error", cv=rkf)
+    # scikit-learn stubs do not recognize rkf as a BaseCrossValidator
+    scores = cross_val_score(
+        gpr,
+        X,
+        y,
+        scoring="neg_root_mean_squared_error",
+        cv=rkf,  # type: ignore[arg-type]
+    )
     return scores
 
 
@@ -204,10 +211,10 @@ def main() -> None:
 
     if args.kfold:
         # Use Scikit-learn cross validation
-        scores = defaultdict(list, {})
+        scores: dict[str, list] = defaultdict(list, {})
         for n in args.kfold:
             for i in range(args.repeats):
-                cv_scores = -1.0 * cross_validate(X, y.T, n, gpr)
+                cv_scores = -1.0 * cross_validate(X, y.T, n, 1, gpr)  # one repeat per iteration
                 scores["rmse"] += cv_scores.tolist()
                 scores["repeat"] += [i] * len(cv_scores)
                 scores["n_folds"] += [n] * len(cv_scores)
@@ -217,7 +224,7 @@ def main() -> None:
             print(f"Finished {n}-fold cross-validation")
 
     scores_df = pd.DataFrame(scores)
-    scores_df.to_csv(args.output_scores, sep="\t", index=None, na_rep="n/a")
+    scores_df.to_csv(args.output_scores, sep="\t", index=False, na_rep="n/a")
 
     grouped = scores_df.groupby(["n_folds"])
     print(grouped[["rmse"]].mean())
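Note: `cross_validate` above wraps scikit-learn's `cross_val_score` with a `RepeatedKFold` splitter and flips the sign of the negated error metric. A self-contained sketch of the same pattern, with a stock `GaussianProcessRegressor` standing in for `DiffusionGPR` and synthetic arrays that are illustrative only:

    import numpy as np
    from sklearn.gaussian_process import GaussianProcessRegressor
    from sklearn.model_selection import RepeatedKFold, cross_val_score

    rng = np.random.default_rng(1234)
    X = rng.standard_normal((60, 3))  # stand-in for unit gradient directions
    y = rng.standard_normal(60)       # stand-in for the dMRI signal

    rkf = RepeatedKFold(n_splits=5, n_repeats=1)
    scores = cross_val_score(
        GaussianProcessRegressor(),
        X,
        y,
        scoring="neg_root_mean_squared_error",
        cv=rkf,
    )
    rmse = -1.0 * scores  # scikit-learn negates error metrics so that greater is better
    print(rmse.mean(), rmse.std())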
diff --git a/scripts/dwi_gp_estimation_error_analysis_plot.py b/scripts/dwi_gp_estimation_error_analysis_plot.py
index 1656f6b6e..5b0d5729f 100644
--- a/scripts/dwi_gp_estimation_error_analysis_plot.py
+++ b/scripts/dwi_gp_estimation_error_analysis_plot.py
@@ -89,10 +89,18 @@ def main() -> None:
     df = pd.read_csv(args.error_data_fname, sep="\t", keep_default_na=False, na_values="n/a")
 
     # Plot the prediction error
-    kfolds = sorted(np.unique(df["n_folds"].values))
-    snr = np.unique(df["snr"].values).item()
-    bval = np.unique(df["bval"].values).item()
-    rmse_data = [df.groupby("n_folds").get_group(k)["rmse"].values for k in kfolds]
+    kfolds = sorted(pd.unique(df["n_folds"]))
+    snr = pd.unique(df["snr"])
+    if len(snr) == 1:
+        snr = snr[0]
+    else:
+        raise ValueError(f"More than one unique SNR value: {snr}")
+    bval = pd.unique(df["bval"])
+    if len(bval) == 1:
+        bval = bval[0]
+    else:
+        raise ValueError(f"More than one unique bval value: {bval}")
+    rmse_data = np.asarray([df.groupby("n_folds").get_group(k)["rmse"].values for k in kfolds])
     axis = 1
     mean = np.mean(rmse_data, axis=axis)
     std_dev = np.std(rmse_data, axis=axis)
diff --git a/scripts/dwi_gp_estimation_simulated_signal.py b/scripts/dwi_gp_estimation_simulated_signal.py
index fd22a81a6..3b2081e5a 100644
--- a/scripts/dwi_gp_estimation_simulated_signal.py
+++ b/scripts/dwi_gp_estimation_simulated_signal.py
@@ -132,11 +132,11 @@ def main() -> None:
 
     # Fit the Gaussian Process regressor and predict on an arbitrary number of
     # directions
-    a = 1.15
-    lambda_s = 120
+    beta_a = 1.15
+    beta_l = 120
     alpha = 100
     gpr = DiffusionGPR(
-        kernel=SphericalKriging(a=a, lambda_s=lambda_s),
+        kernel=SphericalKriging(beta_a=beta_a, beta_l=beta_l),
         alpha=alpha,
         optimizer=None,
     )
@@ -154,6 +154,8 @@ def main() -> None:
 
     X_test = np.vstack([gtab[~gtab.b0s_mask].bvecs, sph.vertices])
     predictions = gpr_fit.predict(X_test)
+    if isinstance(predictions, tuple):
+        predictions = predictions[0]
 
     # Save the predicted data
     testsims.serialize_dwi(predictions.T, args.dwi_pred_data_fname)
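Note: the `isinstance(predictions, tuple)` guard is needed because `GaussianProcessRegressor.predict` returns a bare array by default but a `(mean, std)` pair when called with `return_std=True`, so its declared return type is a union that type checkers cannot narrow automatically. A minimal sketch of that behavior:

    import numpy as np
    from sklearn.gaussian_process import GaussianProcessRegressor

    rng = np.random.default_rng(42)
    gpr = GaussianProcessRegressor().fit(rng.standard_normal((30, 3)), rng.standard_normal(30))

    X_new = rng.standard_normal((5, 3))
    mean = gpr.predict(X_new)                     # np.ndarray
    result = gpr.predict(X_new, return_std=True)  # tuple of (mean, std)
    if isinstance(result, tuple):
        mean, std = result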
diff --git a/scripts/optimize_registration.py b/scripts/optimize_registration.py
index d9f732ad7..abc05d6fb 100644
--- a/scripts/optimize_registration.py
+++ b/scripts/optimize_registration.py
@@ -127,12 +127,13 @@ async def train_coro(
         moving_path = tmp_folder / f"test-{index:04d}.nii.gz"
         (~xfm).apply(refnii, reference=refnii).to_filename(moving_path)
 
+        _kwargs = {"output_transform_prefix": f"conversion-{index:04d}", **align_kwargs}
+
         cmdline = erants.generate_command(
             fixed_path,
             moving_path,
             fixedmask_path=brainmask_path,
-            output_transform_prefix=f"conversion-{index:04d}",
-            **align_kwargs,
+            **_kwargs,
         )
 
         tasks.append(
diff --git a/src/nifreeze/cli/parser.py b/src/nifreeze/cli/parser.py
index 98655f9b3..117bcde7b 100644
--- a/src/nifreeze/cli/parser.py
+++ b/src/nifreeze/cli/parser.py
@@ -29,13 +29,13 @@
 import yaml
 
 
-def _parse_yaml_config(file_path: Path) -> dict:
+def _parse_yaml_config(file_path: str) -> dict:
     """
     Parse YAML configuration file.
 
     Parameters
     ----------
-    file_path : Path
+    file_path : str
         Path to the YAML configuration file.
 
     Returns
diff --git a/src/nifreeze/data/base.py b/src/nifreeze/data/base.py
index 9eae6edd8..4fee5a01e 100644
--- a/src/nifreeze/data/base.py
+++ b/src/nifreeze/data/base.py
@@ -27,17 +27,23 @@
 from collections import namedtuple
 from pathlib import Path
 from tempfile import mkdtemp
-from typing import Any
+from typing import Any, Generic, TypeVarTuple
 
 import attr
 import h5py
 import nibabel as nb
 import numpy as np
+from nibabel.spatialimages import SpatialHeader, SpatialImage
 from nitransforms.linear import Affine
 
+from nifreeze.utils.ndimage import load_api
+
 NFDH5_EXT = ".h5"
 
+Ts = TypeVarTuple("Ts")
+
+
 def _data_repr(value: np.ndarray | None) -> str:
     if value is None:
         return "None"
@@ -52,7 +58,7 @@ def _cmp(lh: Any, rh: Any) -> bool:
 
 
 @attr.s(slots=True)
-class BaseDataset:
+class BaseDataset(Generic[*Ts]):
     """
     Base dataset representation structure.
 
@@ -68,15 +74,15 @@ class BaseDataset:
 
     """
 
-    dataobj = attr.ib(default=None, repr=_data_repr, eq=attr.cmp_using(eq=_cmp))
+    dataobj: np.ndarray = attr.ib(default=None, repr=_data_repr, eq=attr.cmp_using(eq=_cmp))
     """A :obj:`~numpy.ndarray` object for the data array."""
-    affine = attr.ib(default=None, repr=_data_repr, eq=attr.cmp_using(eq=_cmp))
+    affine: np.ndarray = attr.ib(default=None, repr=_data_repr, eq=attr.cmp_using(eq=_cmp))
     """Best affine for RAS-to-voxel conversion of coordinates (NIfTI header)."""
-    brainmask = attr.ib(default=None, repr=_data_repr, eq=attr.cmp_using(eq=_cmp))
+    brainmask: np.ndarray = attr.ib(default=None, repr=_data_repr, eq=attr.cmp_using(eq=_cmp))
     """A boolean ndarray object containing a corresponding brainmask."""
-    motion_affines = attr.ib(default=None, eq=attr.cmp_using(eq=_cmp))
+    motion_affines: np.ndarray = attr.ib(default=None, eq=attr.cmp_using(eq=_cmp))
     """List of :obj:`~nitransforms.linear.Affine` realigning the dataset."""
-    datahdr = attr.ib(default=None)
+    datahdr: SpatialHeader = attr.ib(default=None)
     """A :obj:`~nibabel.spatialimages.SpatialHeader` header corresponding to the data."""
 
     _filepath = attr.ib(
@@ -93,9 +99,13 @@ class BaseDataset:
     def __len__(self) -> int:
         return self.dataobj.shape[-1]
 
+    def _getextra(self, idx: int | slice | tuple | np.ndarray) -> tuple[*Ts]:
+        # PY312: Default values for TypeVarTuples are not yet supported
+        return ()  # type: ignore[return-value]
+
     def __getitem__(
         self, idx: int | slice | tuple | np.ndarray
-    ) -> tuple[np.ndarray, np.ndarray | None]:
+    ) -> tuple[np.ndarray, np.ndarray | None, *Ts]:
         """
         Returns volume(s) and corresponding affine(s) through fancy indexing.
 
@@ -118,7 +128,7 @@ def __getitem__(
             raise ValueError("No data available (dataobj is None).")
 
         affine = self.motion_affines[idx] if self.motion_affines is not None else None
-        return self.dataobj[..., idx], affine
+        return self.dataobj[..., idx], affine, *self._getextra(idx)
 
     @classmethod
     def from_filename(cls, filename: Path | str) -> BaseDataset:
@@ -159,9 +169,8 @@ def set_transform(self, index: int, affine: np.ndarray, order: int = 3) -> None:
             The order of the spline interpolation.
 
         """
-        reference = namedtuple("ImageGrid", ("shape", "affine"))(
-            shape=self.dataobj.shape[:3], affine=self.affine
-        )
+        ImageGrid = namedtuple("ImageGrid", ("shape", "affine"))
+        reference = ImageGrid(shape=self.dataobj.shape[:3], affine=self.affine)
 
         xform = Affine(matrix=affine, reference=reference)
 
@@ -227,7 +236,7 @@
                 compression_opts=compression_opts,
             )
 
-    def to_nifti(self, filename: Path) -> None:
+    def to_nifti(self, filename: Path | str) -> None:
         """
         Write a NIfTI file to disk.
 
@@ -247,7 +256,7 @@ def load(
     filename: Path | str,
     brainmask_file: Path | str | None = None,
     motion_file: Path | str | None = None,
-) -> BaseDataset:
+) -> BaseDataset[()]:
     """
     Load 4D data from a filename or an HDF5 file.
 
@@ -279,11 +288,11 @@
     if filename.name.endswith(NFDH5_EXT):
         return BaseDataset.from_filename(filename)
 
-    img = nb.load(filename)
-    retval = BaseDataset(dataobj=img.dataobj, affine=img.affine)
+    img = load_api(filename, SpatialImage)
+    retval: BaseDataset[()] = BaseDataset(dataobj=np.asanyarray(img.dataobj), affine=img.affine)
 
     if brainmask_file:
-        mask = nb.load(brainmask_file)
+        mask = load_api(brainmask_file, SpatialImage)
         retval.brainmask = np.asanyarray(mask.dataobj)
 
     return retval
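Note: `BaseDataset` is now generic over a `TypeVarTuple` so that subclasses can append extra elements to the tuple returned by `__getitem__` (data, affine, *extras) by overriding `_getextra`, without duplicating the indexing logic. A standalone sketch of the pattern on Python 3.11+, using simplified, hypothetical class names rather than the nifreeze classes:

    from typing import Generic, TypeVarTuple

    import numpy as np

    Ts = TypeVarTuple("Ts")

    class Dataset(Generic[*Ts]):
        def __init__(self, data: np.ndarray) -> None:
            self.data = data

        def _getextra(self, idx: int) -> tuple[*Ts]:
            return ()  # type: ignore[return-value]

        def __getitem__(self, idx: int) -> tuple[np.ndarray, *Ts]:
            return (self.data[..., idx], *self._getextra(idx))

    class TimedDataset(Dataset[float]):
        """Appends a per-volume timestamp, as DWI/PET append gradients/frame times."""

        def __init__(self, data: np.ndarray, times: np.ndarray) -> None:
            super().__init__(data)
            self.times = times

        def _getextra(self, idx: int) -> tuple[float]:
            return (float(self.times[idx]),)

    vol, t = TimedDataset(np.zeros((4, 4, 4, 3)), np.array([0.0, 1.5, 3.0]))[1]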
diff --git a/src/nifreeze/data/dmri.py b/src/nifreeze/data/dmri.py
index 6a32496ca..69423c8c8 100644
--- a/src/nifreeze/data/dmri.py
+++ b/src/nifreeze/data/dmri.py
@@ -33,13 +33,15 @@
 import h5py
 import nibabel as nb
 import numpy as np
+from nibabel.spatialimages import SpatialImage
 from nitransforms.linear import Affine
 
 from nifreeze.data.base import BaseDataset, _cmp, _data_repr
+from nifreeze.utils.ndimage import load_api
 
 
 @attr.s(slots=True)
-class DWI(BaseDataset):
+class DWI(BaseDataset[np.ndarray | None]):
     """Data representation structure for dMRI data."""
 
     bzero = attr.ib(default=None, repr=_data_repr, eq=attr.cmp_using(eq=_cmp))
@@ -49,6 +51,10 @@ class DWI(BaseDataset):
     eddy_xfms = attr.ib(default=None)
     """List of transforms to correct for estimated eddy current distortions."""
 
+    def _getextra(self, idx: int | slice | tuple | np.ndarray) -> tuple[np.ndarray | None]:
+        return (self.gradients[..., idx] if self.gradients is not None else None,)
+
+    # For the sake of the docstring
     def __getitem__(
         self, idx: int | slice | tuple | np.ndarray
     ) -> tuple[np.ndarray, np.ndarray | None, np.ndarray | None]:
@@ -74,8 +80,7 @@ def __getitem__(
 
         """
 
-        data, affine = super().__getitem__(idx)
-        return data, affine, self.gradients[..., idx]
+        return super().__getitem__(idx)
 
     @classmethod
     def from_filename(cls, filename: Path | str) -> DWI:
@@ -276,7 +281,7 @@ def load(
         return DWI.from_filename(filename)
 
     # 2) Otherwise, load a NIfTI
-    img = nb.load(str(filename))
+    img = load_api(filename, SpatialImage)
     fulldata = img.get_fdata(dtype=np.float32)
     affine = img.affine
 
@@ -319,7 +324,7 @@ def load(
     # 6) b=0 volume (bzero)
     # If the user provided a b0_file, load it
     if b0_file:
-        b0img = nb.load(str(b0_file))
+        b0img = load_api(b0_file, SpatialImage)
        b0vol = np.asanyarray(b0img.dataobj)
         # We'll assume your DWI class has a bzero: np.ndarray | None attribute
         dwi_obj.bzero = b0vol
 
@@ -333,7 +338,7 @@ def load(
 
     # 7) If a brainmask_file was provided, load it
     if brainmask_file:
-        mask_img = nb.load(str(brainmask_file))
+        mask_img = load_api(brainmask_file, SpatialImage)
         dwi_obj.brainmask = np.asanyarray(mask_img.dataobj, dtype=bool)
 
     return dwi_obj
diff --git a/src/nifreeze/data/filtering.py b/src/nifreeze/data/filtering.py
index 0f14604a8..6f5b725e5 100644
--- a/src/nifreeze/data/filtering.py
+++ b/src/nifreeze/data/filtering.py
@@ -77,8 +77,12 @@ def advanced_clip(
     # Calculate stats on denoised version to avoid outlier bias
     denoised = median_filter(data, footprint=ball(3))
 
-    a_min = np.percentile(denoised[denoised >= 0] if nonnegative else denoised, p_min)
-    a_max = np.percentile(denoised[denoised >= 0] if nonnegative else denoised, p_max)
+    a_min = np.percentile(
+        np.asarray([denoised[denoised >= 0] if nonnegative else denoised]), p_min
+    )
+    a_max = np.percentile(
+        np.asarray([denoised[denoised >= 0] if nonnegative else denoised]), p_max
+    )
 
     # Clip and scale data
     data = np.clip(data, a_min=a_min, a_max=a_max)
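Note: `load_api`, imported from `nifreeze.utils.ndimage`, is not defined anywhere in this patch; presumably it is a typed wrapper around `nibabel.load` that verifies the loaded image against the expected nibabel class so the return type narrows for the checker. A plausible sketch under that assumption — not the actual implementation:

    from pathlib import Path
    from typing import TypeVar

    import nibabel as nb
    from nibabel.spatialimages import SpatialImage

    T = TypeVar("T", bound=SpatialImage)

    def load_api(path: str | Path, api: type[T]) -> T:
        # nb.load returns a FileBasedImage; the isinstance check narrows it to T
        img = nb.load(path)
        if not isinstance(img, api):
            raise TypeError(f"Expected {api.__name__}, got {type(img).__name__}")
        return img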
diff --git a/src/nifreeze/data/pet.py b/src/nifreeze/data/pet.py
index dd0e45437..747532552 100644
--- a/src/nifreeze/data/pet.py
+++ b/src/nifreeze/data/pet.py
@@ -29,14 +29,15 @@
 
 import attr
 import h5py
-import nibabel as nb
 import numpy as np
+from nibabel.spatialimages import SpatialImage
 
 from nifreeze.data.base import BaseDataset, _cmp, _data_repr
+from nifreeze.utils.ndimage import load_api
 
 
 @attr.s(slots=True)
-class PET(BaseDataset):
+class PET(BaseDataset[np.ndarray | None]):
     """Data representation structure for PET data."""
 
     frame_time: np.ndarray | None = attr.ib(
@@ -46,6 +47,10 @@ class PET(BaseDataset):
     total_duration: float | None = attr.ib(default=None, repr=True)
     """A float representing the total duration of the dataset."""
 
+    def _getextra(self, idx: int | slice | tuple | np.ndarray) -> tuple[np.ndarray | None]:
+        return (self.frame_time[idx] if self.frame_time is not None else None,)
+
+    # For the sake of the docstring
     def __getitem__(
         self, idx: int | slice | tuple | np.ndarray
     ) -> tuple[np.ndarray, np.ndarray | None, np.ndarray | None]:
@@ -69,9 +74,7 @@ def __getitem__(
             The frame time corresponding to the index(es).
 
         """
-
-        data, affine = super().__getitem__(idx)
-        return data, affine, self.frame_time[idx]
+        return super().__getitem__(idx)
 
     @classmethod
     def from_filename(cls, filename: Union[str, Path]) -> PET:
@@ -181,7 +184,7 @@ def load(
         pet_obj = PET.from_filename(filename)
     else:
         # Load from NIfTI
-        img = nb.load(str(filename))
+        img = load_api(filename, SpatialImage)
         data = img.get_fdata(dtype=np.float32)
         pet_obj = PET(
             dataobj=data,
@@ -204,11 +207,12 @@ def load(
 
     # If the user doesn't provide frame_duration, we derive it:
     if frame_duration is None:
-        frame_time_arr = pet_obj.frame_time
-        # If shape is e.g. (N,), then we can do
-        durations = np.diff(frame_time_arr)
-        if len(durations) == (len(frame_time_arr) - 1):
-            durations = np.append(durations, durations[-1])  # last frame same as second-last
+        if pet_obj.frame_time is not None:
+            frame_time_arr = pet_obj.frame_time
+            # If shape is e.g. (N,), then we can do
+            durations = np.diff(frame_time_arr)
+            if len(durations) == (len(frame_time_arr) - 1):
+                durations = np.append(durations, durations[-1])  # last frame same as second-last
     else:
         durations = np.array(frame_duration, dtype=np.float32)
 
@@ -218,7 +222,7 @@ def load(
 
     # If a brain mask is provided, load and attach
    if brainmask_file is not None:
-        mask_img = nb.load(str(brainmask_file))
+        mask_img = load_api(brainmask_file, SpatialImage)
         pet_obj.brainmask = np.asanyarray(mask_img.dataobj, dtype=bool)
 
     return pet_obj
diff --git a/src/nifreeze/model/_dipy.py b/src/nifreeze/model/_dipy.py
index b501c8b64..4f4b30209 100644
--- a/src/nifreeze/model/_dipy.py
+++ b/src/nifreeze/model/_dipy.py
@@ -25,6 +25,7 @@
 from __future__ import annotations
 
 import warnings
+from typing import Any
 
 import numpy as np
 from dipy.core.gradients import GradientTable
@@ -78,7 +79,9 @@ def gp_prediction(
         raise RuntimeError("Model is not yet fitted.")
 
     # Extract orientations from bvecs, and highly likely, the b-value too.
-    return model.predict(X, return_std=return_std)
+    orientations = model.predict(X, return_std=return_std)
+    assert isinstance(orientations, np.ndarray)
+    return orientations
 
 
 class GaussianProcessModel(ReconstModel):
@@ -87,6 +90,7 @@ class GaussianProcessModel(ReconstModel):
     __slots__ = (
         "kernel",
         "_modelfit",
+        "sigma_sq",
     )
 
     def __init__(
@@ -137,7 +141,7 @@ def fit(
         self,
         data: np.ndarray,
         gtab: GradientTable | np.ndarray,
-        mask: np.ndarray[bool] | None = None,
+        mask: np.ndarray[bool, Any] | None = None,
         random_state: int = 0,
     ) -> GPFit:
         """Fit method of the DTI model class
diff --git a/src/nifreeze/model/gpr.py b/src/nifreeze/model/gpr.py
index 19f88bf4d..a3c612bbb 100644
--- a/src/nifreeze/model/gpr.py
+++ b/src/nifreeze/model/gpr.py
@@ -25,11 +25,12 @@
 from __future__ import annotations
 
 from numbers import Integral, Real
-from typing import Callable, Mapping, Sequence
+from typing import Callable, ClassVar, Literal, Mapping, Optional, Sequence, Union
 
 import numpy as np
+import numpy.typing as npt
 from scipy import optimize
-from scipy.optimize._minimize import Bounds
+from scipy.optimize import Bounds
 from sklearn.gaussian_process import GaussianProcessRegressor
 from sklearn.gaussian_process.kernels import (
     Hyperparameter,
@@ -153,7 +154,9 @@ class DiffusionGPR(GaussianProcessRegressor):
 
     """
 
-    _parameter_constraints: dict = {
+    optimizer: Optional[Union[str, Callable]] = None
+
+    _parameter_constraints: ClassVar[dict] = {
         "kernel": [None, Kernel],
         "alpha": [Interval(Real, 0, None, closed="left"), np.ndarray],
         "optimizer": [StrOptions(SUPPORTED_OPTIMIZERS), callable, None],
@@ -169,7 +172,7 @@ def __init__(
         kernel: Kernel | None = None,
         *,
         alpha: float = 0.5,
-        optimizer: str | Callable | None = "fmin_l_bfgs_b",
+        optimizer: Literal["fmin_l_bfgs_b"] | Callable | None = "fmin_l_bfgs_b",
         n_restarts_optimizer: int = 0,
         copy_X_train: bool = True,
         normalize_y: bool = True,
@@ -212,7 +215,7 @@ def _constrained_optimization(
     ) -> tuple[float, float]:
         options = {}
         if self.optimizer == "fmin_l_bfgs_b":
-            from sklearn.utils.optimize import _check_optimize_result
+            from sklearn.utils.optimize import _check_optimize_result  # type: ignore
 
             for name in LBFGS_CONFIGURABLE_OPTIONS:
                 if (value := getattr(self, name, None)) is not None:
@@ -227,7 +230,7 @@ def _constrained_optimization(
                 options=options,
                 args=(self.eval_gradient,),
                 tol=self.tol,
-            )
+            )  # type: ignore[call-overload]
 
             _check_optimize_result("lbfgs", opt_res)
             return opt_res.x, opt_res.fun
@@ -332,7 +335,7 @@ def __call__(
 
         return self.beta_l * C_theta, K_gradient
 
-    def diag(self, X: np.ndarray) -> np.ndarray:
+    def diag(self, X: npt.ArrayLike) -> np.ndarray:
         """Returns the diagonal of the kernel k(X, X).
 
         The result of this method is identical to np.diag(self(X)); however,
@@ -349,7 +352,7 @@ def diag(self, X: np.ndarray) -> np.ndarray:
         K_diag : :obj:`~numpy.ndarray` of shape (n_samples_X,)
             Diagonal of kernel k(X, X)
         """
-        return self.beta_l * np.ones(X.shape[0])
+        return self.beta_l * np.ones(np.asanyarray(X).shape[0])
 
     def is_stationary(self) -> bool:
         """Returns whether the kernel is stationary."""
@@ -442,7 +445,7 @@ def __call__(
 
         return self.beta_l * C_theta, K_gradient
 
-    def diag(self, X: np.ndarray) -> np.ndarray:
+    def diag(self, X: npt.ArrayLike) -> np.ndarray:
         """Returns the diagonal of the kernel k(X, X).
 
         The result of this method is identical to np.diag(self(X)); however,
@@ -459,7 +462,7 @@ def diag(self, X: np.ndarray) -> np.ndarray:
         K_diag : :obj:`~numpy.ndarray` of shape (n_samples_X,)
             Diagonal of kernel k(X, X)
         """
-        return self.beta_l * np.ones(X.shape[0])
+        return self.beta_l * np.ones(np.asanyarray(X).shape[0])
 
     def is_stationary(self) -> bool:
         """Returns whether the kernel is stationary."""
diff --git a/src/nifreeze/registration/ants.py b/src/nifreeze/registration/ants.py
index f6bba237d..712a0874d 100644
--- a/src/nifreeze/registration/ants.py
+++ b/src/nifreeze/registration/ants.py
@@ -213,7 +213,7 @@ def generate_command(
     movingmask_path: str | Path | list[str] | None = None,
     init_affine: str | Path | None = None,
     default: str = "b0-to-b0_level0",
-    **kwargs: dict,
+    **kwargs,
 ) -> str:
     """
     Generate an ANTs' command line.
@@ -412,7 +412,7 @@ def _run_registration(
     i_iter: int,
     vol_idx: int,
     dirname: Path,
-    reg_target_type: str,
+    reg_target_type: str | tuple[str, str],
     align_kwargs: dict,
 ) -> nt.base.BaseTransform:
     """
@@ -440,7 +440,7 @@ def _run_registration(
         DWI frame index.
     dirname : :obj:`Path`
         Directory name where the transformation is saved.
-    reg_target_type : :obj:`str`
+    reg_target_type : :obj:`str` or tuple of :obj:`str`
         Target registration type.
     align_kwargs : :obj:`dict`
         Parameters to configure the image registration process.
diff --git a/src/nifreeze/utils/iterators.py b/src/nifreeze/utils/iterators.py
index 83886e5d3..0cf52d211 100644
--- a/src/nifreeze/utils/iterators.py
+++ b/src/nifreeze/utils/iterators.py
@@ -27,7 +27,7 @@
 from typing import Iterator
 
 
-def linear_iterator(size: int = None, **kwargs) -> Iterator[int]:
+def linear_iterator(size: int | None = None, **kwargs) -> Iterator[int]:
     """
     Traverse the dataset volumes in ascending order.
 
@@ -53,10 +53,10 @@ def linear_iterator(size: int = None, **kwargs) -> Iterator[int]:
     if size is None:
         raise TypeError("Cannot build iterator without size")
 
-    return range(size)
+    return iter(range(size))
 
 
-def random_iterator(size: int = None, **kwargs) -> Iterator[int]:
+def random_iterator(size: int | None = None, **kwargs) -> Iterator[int]:
     """
     Traverse the dataset volumes randomly.
 
@@ -105,7 +105,7 @@ def random_iterator(size: int = None, **kwargs) -> Iterator[int]:
 
     return (x for x in index_order)
 
 
-def bvalue_iterator(size: int = None, **kwargs) -> Iterator[int]:
+def bvalue_iterator(size: int | None = None, **kwargs) -> Iterator[int]:
     """
     Traverse the volumes in a DWI dataset by growing b-value.
 
@@ -132,7 +132,7 @@ def bvalue_iterator(size: int = None, **kwargs) -> Iterator[int]:
     return (index[1] for index in indexed_bvals)
 
 
-def centralsym_iterator(size: int = None, **kwargs) -> Iterator[int]:
+def centralsym_iterator(size: int | None = None, **kwargs) -> Iterator[int]:
     """
     Traverse the dataset starting from the center and alternatingly progressing to the sides.
 
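Note: the `iter(range(size))` changes are type-driven: a `range` object is an `Iterable[int]` (indeed a `Sequence[int]`), but it has no `__next__`, so it is not an `Iterator[int]` and cannot be returned from a function annotated as one; wrapping it in `iter()` yields a true iterator. For instance:

    from typing import Iterator

    def linear(size: int) -> Iterator[int]:
        return iter(range(size))  # a bare range(size) would fail type checking here

    assert not hasattr(range(3), "__next__")
    assert hasattr(iter(range(3)), "__next__")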
diff --git a/src/nifreeze/viz/signals.py b/src/nifreeze/viz/signals.py
index 37798e066..5c2db0b07 100644
--- a/src/nifreeze/viz/signals.py
+++ b/src/nifreeze/viz/signals.py
@@ -37,7 +37,7 @@ def plot_error(
     ylabel: str,
     title: str,
     color: str = "orange",
-    figsize: tuple[int, int] = (19.2, 10.8),
+    figsize: tuple[float, float] = (19.2, 10.8),
 ) -> plt.Figure:
     """
     Plot the error and standard deviation.
@@ -74,7 +74,7 @@ def plot_error(
     ax.set_xlabel(xlabel)
     ax.set_ylabel(ylabel)
     ax.set_xticks(kfolds)
-    ax.set_xticklabels(kfolds)
+    ax.set_xticklabels(map(str, kfolds))
     ax.set_title(title)
     fig.tight_layout()
     return fig
diff --git a/test/test_data_base.py b/test/test_data_base.py
index 7f2c79562..4afa64464 100644
--- a/test/test_data_base.py
+++ b/test/test_data_base.py
@@ -53,7 +53,7 @@ def test_len(random_dataset: BaseDataset):
     assert len(random_dataset) == 5  # last dimension is 5 volumes
 
 
-def test_getitem_volume_index(random_dataset: BaseDataset):
+def test_getitem_volume_index(random_dataset: BaseDataset[()]):
     """
     Test that __getitem__ returns the correct (volume, affine) tuple.
 
@@ -71,7 +71,7 @@ def test_getitem_volume_index(random_dataset: BaseDataset):
     assert aff_slice is None
 
 
-def test_set_transform(random_dataset: BaseDataset):
+def test_set_transform(random_dataset: BaseDataset[()]):
     """
     Test that calling set_transform changes the data and motion_affines.
     For simplicity, we'll apply an identity transform and check that motion_affines is updated.
@@ -90,6 +90,7 @@ def test_set_transform(random_dataset: BaseDataset):
     assert random_dataset.motion_affines is not None
     np.testing.assert_array_equal(random_dataset.motion_affines[idx], affine)
     # The returned affine from __getitem__ should be the same.
+    assert aff0 is not None
     np.testing.assert_array_equal(aff0, affine)
 
 
@@ -121,7 +122,7 @@
     assert nifti_file.is_file()
 
     # Load the saved file with nibabel
-    img = nb.load(nifti_file)
+    img = nb.Nifti1Image.from_filename(nifti_file)
     data = img.get_fdata(dtype=np.float32)
     assert data.shape == (32, 32, 32, 5)
     assert np.allclose(data, random_dataset.dataobj)
diff --git a/test/test_gpr.py b/test/test_gpr.py
index 8d1997488..b714678af 100644
--- a/test/test_gpr.py
+++ b/test/test_gpr.py
@@ -20,7 +20,6 @@
 #
 # https://www.nipreps.org/community/licensing/
 #
-from collections import namedtuple
 
 import numpy as np
 import pytest
@@ -28,9 +27,6 @@
 
 from nifreeze.model import gpr
 
-GradientTablePatch = namedtuple("gtab", ["bvals", "bvecs"])
-
-
 THETAS = np.linspace(0, np.pi / 2, num=50)
 EXPECTED_EXPONENTIAL = [
     1.0,
diff --git a/tox.ini b/tox.ini
index e3672d1fa..089d30f82 100644
--- a/tox.ini
+++ b/tox.ini
@@ -76,6 +76,15 @@ extras = doc
 commands =
   make -C docs/ SPHINXOPTS="-W -v" BUILDDIR="$HOME/docs" OUTDIR="${CURBRANCH:-html}" html
 
+[testenv:typecheck]
+description = Run mypy type checking
+labels = check
+deps =
+  mypy
+extras = types
+commands =
+  mypy .
+
 [testenv:spellcheck]
 description = Check spelling
 labels = check
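With the `types` extra, the mypy overrides, and this tox environment in place, the new check runs locally as `tox -e typecheck` (which installs the `types` extra and invokes `mypy .`), and in CI as an additional entry in the existing check matrix, where `continue-on-error: true` keeps it non-blocking.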