Skip to content

Commit 5bb1ad2

Browse files
committed
type: Annotate SpatialImage subclasses with affine information
1 parent dc96ff0 commit 5bb1ad2

File tree

10 files changed

+108
-52
lines changed

10 files changed

+108
-52
lines changed

nibabel/analyze.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,13 +84,15 @@
8484

8585
from __future__ import annotations
8686

87+
import typing as ty
88+
8789
import numpy as np
8890

8991
from .arrayproxy import ArrayProxy
9092
from .arraywriters import ArrayWriter, WriterError, get_slope_inter, make_array_writer
9193
from .batteryrunners import Report
9294
from .fileholders import copy_file_map
93-
from .spatialimages import HeaderDataError, HeaderTypeError, SpatialHeader, SpatialImage
95+
from .spatialimages import AffT, HeaderDataError, HeaderTypeError, SpatialHeader, SpatialImage
9496
from .volumeutils import (
9597
apply_read_scaling,
9698
array_from_file,
@@ -102,6 +104,13 @@
102104
)
103105
from .wrapstruct import LabeledWrapStruct
104106

107+
if ty.TYPE_CHECKING:
108+
from collections.abc import Mapping
109+
110+
from .arrayproxy import ArrayLike
111+
from .filebasedimages import FileBasedHeader
112+
from .fileholders import FileMap
113+
105114
# Sub-parts of standard analyze header from
106115
# Mayo dbh.h file
107116
header_key_dtd = [
@@ -893,11 +902,12 @@ def may_contain_header(klass, binaryblock):
893902
return 348 in (hdr_struct['sizeof_hdr'], bs_hdr_struct['sizeof_hdr'])
894903

895904

896-
class AnalyzeImage(SpatialImage):
905+
class AnalyzeImage(SpatialImage[AffT]):
897906
"""Class for basic Analyze format image"""
898907

899908
header_class: type[AnalyzeHeader] = AnalyzeHeader
900909
header: AnalyzeHeader
910+
_header: AnalyzeHeader
901911
_meta_sniff_len = header_class.sizeof_hdr
902912
files_types: tuple[tuple[str, str], ...] = (('image', '.img'), ('header', '.hdr'))
903913
valid_exts: tuple[str, ...] = ('.img', '.hdr')
@@ -908,7 +918,15 @@ class AnalyzeImage(SpatialImage):
908918

909919
ImageArrayProxy = ArrayProxy
910920

911-
def __init__(self, dataobj, affine, header=None, extra=None, file_map=None, dtype=None):
921+
def __init__(
922+
self,
923+
dataobj: ArrayLike,
924+
affine: AffT,
925+
header: FileBasedHeader | Mapping | None = None,
926+
extra: Mapping | None = None,
927+
file_map: FileMap | None = None,
928+
dtype=None,
929+
) -> None:
912930
super().__init__(dataobj, affine, header, extra, file_map)
913931
# Reset consumable values
914932
self._header.set_data_offset(0)

nibabel/brikhead.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535

3636
from .arrayproxy import ArrayProxy
3737
from .fileslice import strided_scalar
38-
from .spatialimages import HeaderDataError, ImageDataError, SpatialHeader, SpatialImage
38+
from .spatialimages import Affine, HeaderDataError, ImageDataError, SpatialHeader, SpatialImage
3939
from .volumeutils import Recoder
4040

4141
# used for doc-tests
@@ -453,7 +453,7 @@ def get_volume_labels(self):
453453
return labels
454454

455455

456-
class AFNIImage(SpatialImage):
456+
class AFNIImage(SpatialImage[Affine]):
457457
"""
458458
AFNI Image file
459459

nibabel/ecat.py

Lines changed: 47 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -43,17 +43,30 @@
4343
below). It's not clear what the licenses are for these files.
4444
"""
4545

46+
from __future__ import annotations
47+
4648
import warnings
4749
from numbers import Integral
50+
from typing import TYPE_CHECKING
4851

4952
import numpy as np
5053

5154
from .arraywriters import make_array_writer
5255
from .fileslice import canonical_slicers, predict_shape, slice2outax
53-
from .spatialimages import SpatialHeader, SpatialImage
56+
from .spatialimages import Affine, AffT, SpatialHeader, SpatialImage
5457
from .volumeutils import array_from_file, make_dt_codes, native_code, swapped_code
5558
from .wrapstruct import WrapStruct
5659

60+
if TYPE_CHECKING:
61+
from collections.abc import Mapping
62+
from typing import Literal as L
63+
64+
import numpy.typing as npt
65+
66+
from .arrayproxy import ArrayLike
67+
from .filebasedimages import FileBasedHeader
68+
from .fileholders import FileMap
69+
5770
BLOCK_SIZE = 512
5871

5972
main_header_dtd = [
@@ -743,7 +756,7 @@ def __getitem__(self, sliceobj):
743756
return out_data
744757

745758

746-
class EcatImage(SpatialImage):
759+
class EcatImage(SpatialImage[AffT]):
747760
"""Class returns a list of Ecat images, with one image(hdr/data) per frame"""
748761

749762
header_class = EcatHeader
@@ -756,7 +769,16 @@ class EcatImage(SpatialImage):
756769

757770
ImageArrayProxy = EcatImageArrayProxy
758771

759-
def __init__(self, dataobj, affine, header, subheader, mlist, extra=None, file_map=None):
772+
def __init__(
773+
self,
774+
dataobj: ArrayLike,
775+
affine: AffT,
776+
header: FileBasedHeader | Mapping | None,
777+
subheader: EcatSubHeader,
778+
mlist: npt.NDArray[np.integer],
779+
extra: Mapping | None = None,
780+
file_map: FileMap | None = None,
781+
) -> None:
760782
"""Initialize Image
761783
762784
The image is a combination of
@@ -798,40 +820,38 @@ def __init__(self, dataobj, affine, header, subheader, mlist, extra=None, file_m
798820
>>> data4d.shape == (10, 10, 3, 1)
799821
True
800822
"""
823+
super().__init__(
824+
dataobj=dataobj,
825+
affine=affine,
826+
header=header,
827+
extra=extra,
828+
file_map=file_map,
829+
)
801830
self._subheader = subheader
802831
self._mlist = mlist
803-
self._dataobj = dataobj
804-
if affine is not None:
805-
# Check that affine is array-like 4,4. Maybe this is too strict at
806-
# this abstract level, but so far I think all image formats we know
807-
# do need 4,4.
808-
affine = np.array(affine, dtype=np.float64, copy=True)
809-
if not affine.shape == (4, 4):
810-
raise ValueError('Affine should be shape 4,4')
811-
self._affine = affine
812-
if extra is None:
813-
extra = {}
814-
self.extra = extra
815-
self._header = header
816-
if file_map is None:
817-
file_map = self.__class__.make_file_map()
818-
self.file_map = file_map
819-
self._data_cache = None
820-
self._fdata_cache = None
832+
833+
# Override SpatialImage default, which attempts to set the
834+
# affine in the header.
835+
def update_header(self) -> None:
836+
"""Does nothing"""
821837

822838
@property
823-
def affine(self):
839+
def affine(self) -> AffT:
824840
if not self._subheader._check_affines():
825841
warnings.warn(
826842
'Affines different across frames, loading affine from FIRST frame', UserWarning
827843
)
828844
return self._affine
829845

830-
def get_frame_affine(self, frame):
846+
def get_frame_affine(self, frame: int) -> Affine:
831847
"""returns 4X4 affine"""
832848
return self._subheader.get_frame_affine(frame=frame)
833849

834-
def get_frame(self, frame, orientation=None):
850+
def get_frame(
851+
self,
852+
frame: int,
853+
orientation: L['neurological', 'radiological'] | None = None,
854+
) -> np.ndarray:
835855
"""
836856
Get full volume for a time frame
837857
@@ -847,16 +867,16 @@ def get_data_dtype(self, frame):
847867
return dt
848868

849869
@property
850-
def shape(self):
870+
def shape(self) -> tuple[int, int, int, int]:
851871
x, y, z = self._subheader.get_shape()
852872
nframes = self._subheader.get_nframes()
853873
return (x, y, z, nframes)
854874

855-
def get_mlist(self):
875+
def get_mlist(self) -> npt.NDArray[np.integer]:
856876
"""get access to the mlist"""
857877
return self._mlist
858878

859-
def get_subheaders(self):
879+
def get_subheaders(self) -> EcatSubHeader:
860880
"""get access to subheaders"""
861881
return self._subheader
862882

nibabel/freesurfer/mghformat.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from ..fileholders import FileHolder
2323
from ..filename_parser import _stringify_path
2424
from ..openers import ImageOpener
25-
from ..spatialimages import HeaderDataError, SpatialHeader, SpatialImage
25+
from ..spatialimages import Affine, HeaderDataError, SpatialHeader, SpatialImage
2626
from ..volumeutils import Recoder, array_from_file, array_to_file, endian_codes
2727
from ..wrapstruct import LabeledWrapStruct
2828

@@ -459,7 +459,7 @@ def diagnose_binaryblock(klass, binaryblock, endianness=None):
459459
return '\n'.join([report.message for report in reports if report.message])
460460

461461

462-
class MGHImage(SpatialImage, SerializableImage):
462+
class MGHImage(SpatialImage[Affine], SerializableImage):
463463
"""Class for MGH format image"""
464464

465465
header_class = MGHHeader

nibabel/minc1.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
from .externals.netcdf import netcdf_file
1818
from .fileslice import canonical_slicers
19-
from .spatialimages import SpatialHeader, SpatialImage
19+
from .spatialimages import Affine, SpatialHeader, SpatialImage
2020

2121
_dt_dict = {
2222
('b', 'unsigned'): np.uint8,
@@ -299,7 +299,7 @@ def may_contain_header(klass, binaryblock):
299299
return binaryblock[:4] == b'CDF\x01'
300300

301301

302-
class Minc1Image(SpatialImage):
302+
class Minc1Image(SpatialImage[Affine]):
303303
"""Class for MINC1 format images
304304
305305
The MINC1 image class uses the default header type, rather than a specific

nibabel/nifti1.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,19 @@
3535
from .filebasedimages import ImageFileError, SerializableImage
3636
from .optpkg import optional_package
3737
from .quaternions import fillpositive, mat2quat, quat2mat
38-
from .spatialimages import HeaderDataError
38+
from .spatialimages import AffT, HeaderDataError
3939
from .spm99analyze import SpmAnalyzeHeader
4040
from .volumeutils import Recoder, endian_codes, make_dt_codes
4141

4242
if ty.TYPE_CHECKING:
43+
from collections.abc import Mapping
44+
4345
import pydicom as pdcm
4446

47+
from .arrayproxy import ArrayLike
48+
from .filebasedimages import FileBasedHeader
49+
from .fileholders import FileMap
50+
4551
have_dicom = True
4652
DicomDataset = pdcm.Dataset
4753
else:
@@ -1971,19 +1977,28 @@ class Nifti1PairHeader(Nifti1Header):
19711977
is_single = False
19721978

19731979

1974-
class Nifti1Pair(analyze.AnalyzeImage):
1980+
class Nifti1Pair(analyze.AnalyzeImage[AffT]):
19751981
"""Class for NIfTI1 format image, header pair"""
19761982

19771983
header_class: type[Nifti1Header] = Nifti1PairHeader
19781984
header: Nifti1Header
1985+
_header: Nifti1Header
19791986
_meta_sniff_len = header_class.sizeof_hdr
19801987
rw = True
19811988

19821989
# If a _dtype_alias has been set, it can only be resolved by inspecting
19831990
# the data at serialization time
19841991
_dtype_alias = None
19851992

1986-
def __init__(self, dataobj, affine, header=None, extra=None, file_map=None, dtype=None):
1993+
def __init__(
1994+
self,
1995+
dataobj: ArrayLike,
1996+
affine: AffT,
1997+
header: FileBasedHeader | Mapping | None = None,
1998+
extra: Mapping | None = None,
1999+
file_map: FileMap | None = None,
2000+
dtype=None,
2001+
) -> None:
19872002
# Special carve-out for 64 bit integers
19882003
# See GitHub issues
19892004
# * https://github.com/nipy/nibabel/issues/1046
@@ -1994,7 +2009,7 @@ def __init__(self, dataobj, affine, header=None, extra=None, file_map=None, dtyp
19942009
danger_dts = (np.dtype('int64'), np.dtype('uint64'))
19952010
if header is None and dtype is None and get_obj_dtype(dataobj) in danger_dts:
19962011
alert_future_error(
1997-
f'Image data has type {dataobj.dtype}, which may cause '
2012+
f'Image data has type {get_obj_dtype(dataobj)}, which may cause '
19982013
'incompatibilities with other tools.',
19992014
'5.0',
20002015
warning_rec='This warning can be silenced by passing the dtype argument'
@@ -2410,7 +2425,7 @@ def as_reoriented(self, ornt):
24102425
return img
24112426

24122427

2413-
class Nifti1Image(Nifti1Pair, SerializableImage):
2428+
class Nifti1Image(Nifti1Pair[AffT], SerializableImage):
24142429
"""Class for single file NIfTI1 format image"""
24152430

24162431
header_class = Nifti1Header

nibabel/nifti2.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from .batteryrunners import Report
2020
from .filebasedimages import ImageFileError
2121
from .nifti1 import Nifti1Header, Nifti1Image, Nifti1Pair
22-
from .spatialimages import HeaderDataError
22+
from .spatialimages import AffT, HeaderDataError
2323

2424
r"""
2525
Header struct from : https://www.nitrc.org/forum/message.php?msg_id=3738
@@ -240,17 +240,19 @@ class Nifti2PairHeader(Nifti2Header):
240240
is_single = False
241241

242242

243-
class Nifti2Pair(Nifti1Pair):
243+
class Nifti2Pair(Nifti1Pair[AffT]):
244244
"""Class for NIfTI2 format image, header pair"""
245245

246-
header_class = Nifti2PairHeader
246+
header_class: type[Nifti2Header] = Nifti2PairHeader
247+
header: Nifti2Header
247248
_meta_sniff_len = header_class.sizeof_hdr
248249

249250

250-
class Nifti2Image(Nifti1Image):
251+
class Nifti2Image(Nifti1Image[AffT]):
251252
"""Class for single file NIfTI2 format image"""
252253

253-
header_class = Nifti2Header
254+
header_class: type[Nifti2Header] = Nifti2Header
255+
header: Nifti2Header
254256
_meta_sniff_len = header_class.sizeof_hdr
255257

256258

nibabel/parrec.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@
134134
from .fileslice import fileslice, strided_scalar
135135
from .nifti1 import unit_codes
136136
from .openers import ImageOpener
137-
from .spatialimages import SpatialHeader, SpatialImage
137+
from .spatialimages import Affine, SpatialHeader, SpatialImage
138138
from .volumeutils import Recoder, array_from_file
139139

140140
# PSL to RAS affine
@@ -1248,7 +1248,7 @@ def get_volume_labels(self):
12481248
return sort_info
12491249

12501250

1251-
class PARRECImage(SpatialImage):
1251+
class PARRECImage(SpatialImage[Affine]):
12521252
"""PAR/REC image"""
12531253

12541254
header_class = PARRECHeader

nibabel/spm2analyze.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010

1111
import numpy as np
1212

13-
from . import spm99analyze as spm99 # module import
13+
from . import spm99analyze as spm99
14+
from .spatialimages import AffT
1415

1516
image_dimension_dtd = spm99.image_dimension_dtd.copy()
1617
image_dimension_dtd[image_dimension_dtd.index(('funused2', 'f4'))] = ('scl_inter', 'f4')
@@ -125,7 +126,7 @@ def may_contain_header(klass, binaryblock):
125126
)
126127

127128

128-
class Spm2AnalyzeImage(spm99.Spm99AnalyzeImage):
129+
class Spm2AnalyzeImage(spm99.Spm99AnalyzeImage[AffT]):
129130
"""Class for SPM2 variant of basic Analyze image"""
130131

131132
header_class = Spm2AnalyzeHeader

0 commit comments

Comments
 (0)