Skip to content

Commit f475901

Browse files
committed
TYP: Annotate SpatialImage and SpatialHeader
1 parent 7d263bd commit f475901

File tree

1 file changed

+115
-68
lines changed

1 file changed

+115
-68
lines changed

nibabel/spatialimages.py

Lines changed: 115 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -131,13 +131,15 @@
131131
"""
132132
from __future__ import annotations
133133

134-
from typing import Type
134+
import io
135+
import typing as ty
136+
from typing import Literal, Sequence
135137

136138
import numpy as np
137139

140+
from .arrayproxy import ArrayLike
138141
from .dataobj_images import DataobjImage
139-
from .filebasedimages import ImageFileError # noqa
140-
from .filebasedimages import FileBasedHeader
142+
from .filebasedimages import FileBasedHeader, FileBasedImage, FileMap
141143
from .fileslice import canonical_slicers
142144
from .orientations import apply_orientation, inv_ornt_aff
143145
from .viewers import OrthoSlicer3D
@@ -148,6 +150,32 @@
148150
except ImportError: # PY38
149151
from functools import lru_cache as cache
150152

153+
if ty.TYPE_CHECKING: # pragma: no cover
154+
import numpy.typing as npt
155+
156+
SpatialImgT = ty.TypeVar('SpatialImgT', bound='SpatialImage')
157+
SpatialHdrT = ty.TypeVar('SpatialHdrT', bound='SpatialHeader')
158+
159+
160+
class HasDtype(ty.Protocol):
161+
def get_data_dtype(self) -> np.dtype:
162+
... # pragma: no cover
163+
164+
def set_data_dtype(self, dtype: npt.DTypeLike) -> None:
165+
... # pragma: no cover
166+
167+
168+
@ty.runtime_checkable
169+
class SpatialProtocol(ty.Protocol):
170+
def get_data_dtype(self) -> np.dtype:
171+
... # pragma: no cover
172+
173+
def get_data_shape(self) -> ty.Tuple[int, ...]:
174+
... # pragma: no cover
175+
176+
def get_zooms(self) -> ty.Tuple[float, ...]:
177+
... # pragma: no cover
178+
151179

152180
class HeaderDataError(Exception):
153181
"""Class to indicate error in getting or setting header data"""
@@ -157,21 +185,33 @@ class HeaderTypeError(Exception):
157185
"""Class to indicate error in parameters into header functions"""
158186

159187

160-
class SpatialHeader(FileBasedHeader):
188+
class SpatialHeader(FileBasedHeader, SpatialProtocol):
161189
"""Template class to implement header protocol"""
162190

163-
default_x_flip = True
164-
data_layout = 'F'
191+
default_x_flip: bool = True
192+
data_layout: Literal['F', 'C'] = 'F'
165193

166-
def __init__(self, data_dtype=np.float32, shape=(0,), zooms=None):
194+
_dtype: np.dtype
195+
_shape: tuple[int, ...]
196+
_zooms: tuple[float, ...]
197+
198+
def __init__(
199+
self,
200+
data_dtype: npt.DTypeLike = np.float32,
201+
shape: Sequence[int] = (0,),
202+
zooms: Sequence[float] | None = None,
203+
):
167204
self.set_data_dtype(data_dtype)
168205
self._zooms = ()
169206
self.set_data_shape(shape)
170207
if zooms is not None:
171208
self.set_zooms(zooms)
172209

173210
@classmethod
174-
def from_header(klass, header=None):
211+
def from_header(
212+
klass: type[SpatialHdrT],
213+
header: SpatialProtocol | FileBasedHeader | ty.Mapping | None = None,
214+
) -> SpatialHdrT:
175215
if header is None:
176216
return klass()
177217
# I can't do isinstance here because it is not necessarily true
@@ -180,74 +220,68 @@ def from_header(klass, header=None):
180220
# different field names
181221
if type(header) == klass:
182222
return header.copy()
183-
return klass(header.get_data_dtype(), header.get_data_shape(), header.get_zooms())
184-
185-
@classmethod
186-
def from_fileobj(klass, fileobj):
187-
raise NotImplementedError
188-
189-
def write_to(self, fileobj):
190-
raise NotImplementedError
191-
192-
def __eq__(self, other):
193-
return (self.get_data_dtype(), self.get_data_shape(), self.get_zooms()) == (
194-
other.get_data_dtype(),
195-
other.get_data_shape(),
196-
other.get_zooms(),
197-
)
198-
199-
def __ne__(self, other):
200-
return not self == other
223+
if isinstance(header, SpatialProtocol):
224+
return klass(header.get_data_dtype(), header.get_data_shape(), header.get_zooms())
225+
return super().from_header(header)
226+
227+
def __eq__(self, other: object) -> bool:
228+
if isinstance(other, SpatialHeader):
229+
return (self.get_data_dtype(), self.get_data_shape(), self.get_zooms()) == (
230+
other.get_data_dtype(),
231+
other.get_data_shape(),
232+
other.get_zooms(),
233+
)
234+
return NotImplemented
201235

202-
def copy(self):
236+
def copy(self: SpatialHdrT) -> SpatialHdrT:
203237
"""Copy object to independent representation
204238
205239
The copy should not be affected by any changes to the original
206240
object.
207241
"""
208242
return self.__class__(self._dtype, self._shape, self._zooms)
209243

210-
def get_data_dtype(self):
244+
def get_data_dtype(self) -> np.dtype:
211245
return self._dtype
212246

213-
def set_data_dtype(self, dtype):
247+
def set_data_dtype(self, dtype: npt.DTypeLike) -> None:
214248
self._dtype = np.dtype(dtype)
215249

216-
def get_data_shape(self):
250+
def get_data_shape(self) -> tuple[int, ...]:
217251
return self._shape
218252

219-
def set_data_shape(self, shape):
253+
def set_data_shape(self, shape: Sequence[int]) -> None:
220254
ndim = len(shape)
221255
if ndim == 0:
222256
self._shape = (0,)
223257
self._zooms = (1.0,)
224258
return
225-
self._shape = tuple([int(s) for s in shape])
259+
self._shape = tuple(int(s) for s in shape)
226260
# set any unset zooms to 1.0
227261
nzs = min(len(self._zooms), ndim)
228262
self._zooms = self._zooms[:nzs] + (1.0,) * (ndim - nzs)
229263

230-
def get_zooms(self):
264+
def get_zooms(self) -> tuple[float, ...]:
231265
return self._zooms
232266

233-
def set_zooms(self, zooms):
234-
zooms = tuple([float(z) for z in zooms])
267+
def set_zooms(self, zooms: Sequence[float]) -> None:
268+
zooms = tuple(float(z) for z in zooms)
235269
shape = self.get_data_shape()
236270
ndim = len(shape)
237271
if len(zooms) != ndim:
238272
raise HeaderDataError('Expecting %d zoom values for ndim %d' % (ndim, ndim))
239-
if len([z for z in zooms if z < 0]):
273+
if any(z < 0 for z in zooms):
240274
raise HeaderDataError('zooms must be positive')
241275
self._zooms = zooms
242276

243-
def get_base_affine(self):
277+
def get_base_affine(self) -> np.ndarray:
244278
shape = self.get_data_shape()
245279
zooms = self.get_zooms()
246280
return shape_zoom_affine(shape, zooms, self.default_x_flip)
247281

248282
get_best_affine = get_base_affine
249283

250-
def data_to_fileobj(self, data, fileobj, rescale=True):
284+
def data_to_fileobj(self, data: npt.ArrayLike, fileobj: io.IOBase, rescale: bool = True):
251285
"""Write array data `data` as binary to `fileobj`
252286
253287
Parameters
@@ -264,7 +298,7 @@ def data_to_fileobj(self, data, fileobj, rescale=True):
264298
dtype = self.get_data_dtype()
265299
fileobj.write(data.astype(dtype).tobytes(order=self.data_layout))
266300

267-
def data_from_fileobj(self, fileobj):
301+
def data_from_fileobj(self, fileobj: io.IOBase) -> np.ndarray:
268302
"""Read binary image data from `fileobj`"""
269303
dtype = self.get_data_dtype()
270304
shape = self.get_data_shape()
@@ -274,7 +308,7 @@ def data_from_fileobj(self, fileobj):
274308

275309

276310
@cache
277-
def _supported_np_types(klass):
311+
def _supported_np_types(klass: type[HasDtype]) -> set[type[np.generic]]:
278312
"""Numpy data types that instances of ``klass`` support
279313
280314
Parameters
@@ -308,7 +342,7 @@ def _supported_np_types(klass):
308342
return supported
309343

310344

311-
def supported_np_types(obj):
345+
def supported_np_types(obj: HasDtype) -> set[type[np.generic]]:
312346
"""Numpy data types that instance `obj` supports
313347
314348
Parameters
@@ -330,13 +364,15 @@ class ImageDataError(Exception):
330364
pass
331365

332366

333-
class SpatialFirstSlicer:
367+
class SpatialFirstSlicer(ty.Generic[SpatialImgT]):
334368
"""Slicing interface that returns a new image with an updated affine
335369
336370
Checks that an image's first three axes are spatial
337371
"""
338372

339-
def __init__(self, img):
373+
img: SpatialImgT
374+
375+
def __init__(self, img: SpatialImgT):
340376
# Local import to avoid circular import on module load
341377
from .imageclasses import spatial_axes_first
342378

@@ -346,7 +382,7 @@ def __init__(self, img):
346382
)
347383
self.img = img
348384

349-
def __getitem__(self, slicer):
385+
def __getitem__(self, slicer: object) -> SpatialImgT:
350386
try:
351387
slicer = self.check_slicing(slicer)
352388
except ValueError as err:
@@ -359,7 +395,7 @@ def __getitem__(self, slicer):
359395
affine = self.slice_affine(slicer)
360396
return self.img.__class__(dataobj.copy(), affine, self.img.header)
361397

362-
def check_slicing(self, slicer, return_spatial=False):
398+
def check_slicing(self, slicer: object, return_spatial: bool = False) -> tuple[slice, ...]:
363399
"""Canonicalize slicers and check for scalar indices in spatial dims
364400
365401
Parameters
@@ -376,21 +412,21 @@ def check_slicing(self, slicer, return_spatial=False):
376412
Validated slicer object that will slice image's `dataobj`
377413
without collapsing spatial dimensions
378414
"""
379-
slicer = canonical_slicers(slicer, self.img.shape)
415+
canonical = canonical_slicers(slicer, self.img.shape)
380416
# We can get away with this because we've checked the image's
381417
# first three axes are spatial.
382418
# More general slicers will need to be smarter, here.
383-
spatial_slices = slicer[:3]
419+
spatial_slices = canonical[:3]
384420
for subslicer in spatial_slices:
385421
if subslicer is None:
386422
raise IndexError('New axis not permitted in spatial dimensions')
387423
elif isinstance(subslicer, int):
388424
raise IndexError(
389425
'Scalar indices disallowed in spatial dimensions; Use `[x]` or `x:x+1`.'
390426
)
391-
return spatial_slices if return_spatial else slicer
427+
return spatial_slices if return_spatial else canonical
392428

393-
def slice_affine(self, slicer):
429+
def slice_affine(self, slicer: tuple[slice, ...]) -> np.ndarray:
394430
"""Retrieve affine for current image, if sliced by a given index
395431
396432
Applies scaling if down-sampling is applied, and adjusts the intercept
@@ -430,10 +466,19 @@ def slice_affine(self, slicer):
430466
class SpatialImage(DataobjImage):
431467
"""Template class for volumetric (3D/4D) images"""
432468

433-
header_class: Type[SpatialHeader] = SpatialHeader
434-
ImageSlicer = SpatialFirstSlicer
469+
header_class: type[SpatialHeader] = SpatialHeader
470+
ImageSlicer: type[SpatialFirstSlicer] = SpatialFirstSlicer
471+
472+
_header: SpatialHeader
435473

436-
def __init__(self, dataobj, affine, header=None, extra=None, file_map=None):
474+
def __init__(
475+
self,
476+
dataobj: ArrayLike,
477+
affine: np.ndarray,
478+
header: FileBasedHeader | ty.Mapping | None = None,
479+
extra: ty.Mapping | None = None,
480+
file_map: FileMap | None = None,
481+
):
437482
"""Initialize image
438483
439484
The image is a combination of (array-like, affine matrix, header), with
@@ -483,7 +528,7 @@ def __init__(self, dataobj, affine, header=None, extra=None, file_map=None):
483528
def affine(self):
484529
return self._affine
485530

486-
def update_header(self):
531+
def update_header(self) -> None:
487532
"""Harmonize header with image data and affine
488533
489534
>>> data = np.zeros((2,3,4))
@@ -512,7 +557,7 @@ def update_header(self):
512557
return
513558
self._affine2header()
514559

515-
def _affine2header(self):
560+
def _affine2header(self) -> None:
516561
"""Unconditionally set affine into the header"""
517562
RZS = self._affine[:3, :3]
518563
vox = np.sqrt(np.sum(RZS * RZS, axis=0))
@@ -522,7 +567,7 @@ def _affine2header(self):
522567
zooms[:n_to_set] = vox[:n_to_set]
523568
hdr.set_zooms(zooms)
524569

525-
def __str__(self):
570+
def __str__(self) -> str:
526571
shape = self.shape
527572
affine = self.affine
528573
return f"""
@@ -534,14 +579,14 @@ def __str__(self):
534579
{self._header}
535580
"""
536581

537-
def get_data_dtype(self):
582+
def get_data_dtype(self) -> np.dtype:
538583
return self._header.get_data_dtype()
539584

540-
def set_data_dtype(self, dtype):
585+
def set_data_dtype(self, dtype: npt.DTypeLike) -> None:
541586
self._header.set_data_dtype(dtype)
542587

543588
@classmethod
544-
def from_image(klass, img):
589+
def from_image(klass: type[SpatialImgT], img: SpatialImage | FileBasedImage) -> SpatialImgT:
545590
"""Class method to create new instance of own class from `img`
546591
547592
Parameters
@@ -555,15 +600,17 @@ def from_image(klass, img):
555600
cimg : ``spatialimage`` instance
556601
Image, of our own class
557602
"""
558-
return klass(
559-
img.dataobj,
560-
img.affine,
561-
klass.header_class.from_header(img.header),
562-
extra=img.extra.copy(),
563-
)
603+
if isinstance(img, SpatialImage):
604+
return klass(
605+
img.dataobj,
606+
img.affine,
607+
klass.header_class.from_header(img.header),
608+
extra=img.extra.copy(),
609+
)
610+
return super().from_image(img)
564611

565612
@property
566-
def slicer(self):
613+
def slicer(self: SpatialImgT) -> SpatialFirstSlicer[SpatialImgT]:
567614
"""Slicer object that returns cropped and subsampled images
568615
569616
The image is resliced in the current orientation; no rotation or
@@ -582,7 +629,7 @@ def slicer(self):
582629
"""
583630
return self.ImageSlicer(self)
584631

585-
def __getitem__(self, idx):
632+
def __getitem__(self, idx: object) -> None:
586633
"""No slicing or dictionary interface for images
587634
588635
Use the slicer attribute to perform cropping and subsampling at your
@@ -595,7 +642,7 @@ def __getitem__(self, idx):
595642
'`img.get_fdata()[slice]`'
596643
)
597644

598-
def orthoview(self):
645+
def orthoview(self) -> OrthoSlicer3D:
599646
"""Plot the image using OrthoSlicer3D
600647
601648
Returns
@@ -611,7 +658,7 @@ def orthoview(self):
611658
"""
612659
return OrthoSlicer3D(self.dataobj, self.affine, title=self.get_filename())
613660

614-
def as_reoriented(self, ornt):
661+
def as_reoriented(self: SpatialImgT, ornt: Sequence[Sequence[int]]) -> SpatialImgT:
615662
"""Apply an orientation change and return a new image
616663
617664
If ornt is identity transform, return the original image, unchanged

0 commit comments

Comments
 (0)