131
131
"""
132
132
from __future__ import annotations
133
133
134
- from typing import Type
134
+ import io
135
+ import typing as ty
136
+ from typing import Literal , Sequence
135
137
136
138
import numpy as np
137
139
140
+ from .arrayproxy import ArrayLike
138
141
from .dataobj_images import DataobjImage
139
- from .filebasedimages import ImageFileError # noqa
140
- from .filebasedimages import FileBasedHeader
142
+ from .filebasedimages import FileBasedHeader , FileBasedImage , FileMap
141
143
from .fileslice import canonical_slicers
142
144
from .orientations import apply_orientation , inv_ornt_aff
143
145
from .viewers import OrthoSlicer3D
148
150
except ImportError : # PY38
149
151
from functools import lru_cache as cache
150
152
153
+ if ty .TYPE_CHECKING : # pragma: no cover
154
+ import numpy .typing as npt
155
+
156
+ SpatialImgT = ty .TypeVar ('SpatialImgT' , bound = 'SpatialImage' )
157
+ SpatialHdrT = ty .TypeVar ('SpatialHdrT' , bound = 'SpatialHeader' )
158
+
159
+
160
+ class HasDtype (ty .Protocol ):
161
+ def get_data_dtype (self ) -> np .dtype :
162
+ ... # pragma: no cover
163
+
164
+ def set_data_dtype (self , dtype : npt .DTypeLike ) -> None :
165
+ ... # pragma: no cover
166
+
167
+
168
+ @ty .runtime_checkable
169
+ class SpatialProtocol (ty .Protocol ):
170
+ def get_data_dtype (self ) -> np .dtype :
171
+ ... # pragma: no cover
172
+
173
+ def get_data_shape (self ) -> ty .Tuple [int , ...]:
174
+ ... # pragma: no cover
175
+
176
+ def get_zooms (self ) -> ty .Tuple [float , ...]:
177
+ ... # pragma: no cover
178
+
151
179
152
180
class HeaderDataError (Exception ):
153
181
"""Class to indicate error in getting or setting header data"""
@@ -157,21 +185,33 @@ class HeaderTypeError(Exception):
157
185
"""Class to indicate error in parameters into header functions"""
158
186
159
187
160
- class SpatialHeader (FileBasedHeader ):
188
+ class SpatialHeader (FileBasedHeader , SpatialProtocol ):
161
189
"""Template class to implement header protocol"""
162
190
163
- default_x_flip = True
164
- data_layout = 'F'
191
+ default_x_flip : bool = True
192
+ data_layout : Literal [ 'F' , 'C' ] = 'F'
165
193
166
- def __init__ (self , data_dtype = np .float32 , shape = (0 ,), zooms = None ):
194
+ _dtype : np .dtype
195
+ _shape : tuple [int , ...]
196
+ _zooms : tuple [float , ...]
197
+
198
+ def __init__ (
199
+ self ,
200
+ data_dtype : npt .DTypeLike = np .float32 ,
201
+ shape : Sequence [int ] = (0 ,),
202
+ zooms : Sequence [float ] | None = None ,
203
+ ):
167
204
self .set_data_dtype (data_dtype )
168
205
self ._zooms = ()
169
206
self .set_data_shape (shape )
170
207
if zooms is not None :
171
208
self .set_zooms (zooms )
172
209
173
210
@classmethod
174
- def from_header (klass , header = None ):
211
+ def from_header (
212
+ klass : type [SpatialHdrT ],
213
+ header : SpatialProtocol | FileBasedHeader | ty .Mapping | None = None ,
214
+ ) -> SpatialHdrT :
175
215
if header is None :
176
216
return klass ()
177
217
# I can't do isinstance here because it is not necessarily true
@@ -180,74 +220,68 @@ def from_header(klass, header=None):
180
220
# different field names
181
221
if type (header ) == klass :
182
222
return header .copy ()
183
- return klass (header .get_data_dtype (), header .get_data_shape (), header .get_zooms ())
184
-
185
- @classmethod
186
- def from_fileobj (klass , fileobj ):
187
- raise NotImplementedError
188
-
189
- def write_to (self , fileobj ):
190
- raise NotImplementedError
191
-
192
- def __eq__ (self , other ):
193
- return (self .get_data_dtype (), self .get_data_shape (), self .get_zooms ()) == (
194
- other .get_data_dtype (),
195
- other .get_data_shape (),
196
- other .get_zooms (),
197
- )
198
-
199
- def __ne__ (self , other ):
200
- return not self == other
223
+ if isinstance (header , SpatialProtocol ):
224
+ return klass (header .get_data_dtype (), header .get_data_shape (), header .get_zooms ())
225
+ return super ().from_header (header )
226
+
227
+ def __eq__ (self , other : object ) -> bool :
228
+ if isinstance (other , SpatialHeader ):
229
+ return (self .get_data_dtype (), self .get_data_shape (), self .get_zooms ()) == (
230
+ other .get_data_dtype (),
231
+ other .get_data_shape (),
232
+ other .get_zooms (),
233
+ )
234
+ return NotImplemented
201
235
202
- def copy (self ) :
236
+ def copy (self : SpatialHdrT ) -> SpatialHdrT :
203
237
"""Copy object to independent representation
204
238
205
239
The copy should not be affected by any changes to the original
206
240
object.
207
241
"""
208
242
return self .__class__ (self ._dtype , self ._shape , self ._zooms )
209
243
210
- def get_data_dtype (self ):
244
+ def get_data_dtype (self ) -> np . dtype :
211
245
return self ._dtype
212
246
213
- def set_data_dtype (self , dtype ) :
247
+ def set_data_dtype (self , dtype : npt . DTypeLike ) -> None :
214
248
self ._dtype = np .dtype (dtype )
215
249
216
- def get_data_shape (self ):
250
+ def get_data_shape (self ) -> tuple [ int , ...] :
217
251
return self ._shape
218
252
219
- def set_data_shape (self , shape ) :
253
+ def set_data_shape (self , shape : Sequence [ int ]) -> None :
220
254
ndim = len (shape )
221
255
if ndim == 0 :
222
256
self ._shape = (0 ,)
223
257
self ._zooms = (1.0 ,)
224
258
return
225
- self ._shape = tuple ([ int (s ) for s in shape ] )
259
+ self ._shape = tuple (int (s ) for s in shape )
226
260
# set any unset zooms to 1.0
227
261
nzs = min (len (self ._zooms ), ndim )
228
262
self ._zooms = self ._zooms [:nzs ] + (1.0 ,) * (ndim - nzs )
229
263
230
- def get_zooms (self ):
264
+ def get_zooms (self ) -> tuple [ float , ...] :
231
265
return self ._zooms
232
266
233
- def set_zooms (self , zooms ) :
234
- zooms = tuple ([ float (z ) for z in zooms ] )
267
+ def set_zooms (self , zooms : Sequence [ float ]) -> None :
268
+ zooms = tuple (float (z ) for z in zooms )
235
269
shape = self .get_data_shape ()
236
270
ndim = len (shape )
237
271
if len (zooms ) != ndim :
238
272
raise HeaderDataError ('Expecting %d zoom values for ndim %d' % (ndim , ndim ))
239
- if len ([ z for z in zooms if z < 0 ] ):
273
+ if any ( z < 0 for z in zooms ):
240
274
raise HeaderDataError ('zooms must be positive' )
241
275
self ._zooms = zooms
242
276
243
- def get_base_affine (self ):
277
+ def get_base_affine (self ) -> np . ndarray :
244
278
shape = self .get_data_shape ()
245
279
zooms = self .get_zooms ()
246
280
return shape_zoom_affine (shape , zooms , self .default_x_flip )
247
281
248
282
get_best_affine = get_base_affine
249
283
250
- def data_to_fileobj (self , data , fileobj , rescale = True ):
284
+ def data_to_fileobj (self , data : npt . ArrayLike , fileobj : io . IOBase , rescale : bool = True ):
251
285
"""Write array data `data` as binary to `fileobj`
252
286
253
287
Parameters
@@ -264,7 +298,7 @@ def data_to_fileobj(self, data, fileobj, rescale=True):
264
298
dtype = self .get_data_dtype ()
265
299
fileobj .write (data .astype (dtype ).tobytes (order = self .data_layout ))
266
300
267
- def data_from_fileobj (self , fileobj ) :
301
+ def data_from_fileobj (self , fileobj : io . IOBase ) -> np . ndarray :
268
302
"""Read binary image data from `fileobj`"""
269
303
dtype = self .get_data_dtype ()
270
304
shape = self .get_data_shape ()
@@ -274,7 +308,7 @@ def data_from_fileobj(self, fileobj):
274
308
275
309
276
310
@cache
277
- def _supported_np_types (klass ) :
311
+ def _supported_np_types (klass : type [ HasDtype ]) -> set [ type [ np . generic ]] :
278
312
"""Numpy data types that instances of ``klass`` support
279
313
280
314
Parameters
@@ -308,7 +342,7 @@ def _supported_np_types(klass):
308
342
return supported
309
343
310
344
311
- def supported_np_types (obj ) :
345
+ def supported_np_types (obj : HasDtype ) -> set [ type [ np . generic ]] :
312
346
"""Numpy data types that instance `obj` supports
313
347
314
348
Parameters
@@ -330,13 +364,15 @@ class ImageDataError(Exception):
330
364
pass
331
365
332
366
333
- class SpatialFirstSlicer :
367
+ class SpatialFirstSlicer ( ty . Generic [ SpatialImgT ]) :
334
368
"""Slicing interface that returns a new image with an updated affine
335
369
336
370
Checks that an image's first three axes are spatial
337
371
"""
338
372
339
- def __init__ (self , img ):
373
+ img : SpatialImgT
374
+
375
+ def __init__ (self , img : SpatialImgT ):
340
376
# Local import to avoid circular import on module load
341
377
from .imageclasses import spatial_axes_first
342
378
@@ -346,7 +382,7 @@ def __init__(self, img):
346
382
)
347
383
self .img = img
348
384
349
- def __getitem__ (self , slicer ) :
385
+ def __getitem__ (self , slicer : object ) -> SpatialImgT :
350
386
try :
351
387
slicer = self .check_slicing (slicer )
352
388
except ValueError as err :
@@ -359,7 +395,7 @@ def __getitem__(self, slicer):
359
395
affine = self .slice_affine (slicer )
360
396
return self .img .__class__ (dataobj .copy (), affine , self .img .header )
361
397
362
- def check_slicing (self , slicer , return_spatial = False ):
398
+ def check_slicing (self , slicer : object , return_spatial : bool = False ) -> tuple [ slice , ...] :
363
399
"""Canonicalize slicers and check for scalar indices in spatial dims
364
400
365
401
Parameters
@@ -376,21 +412,21 @@ def check_slicing(self, slicer, return_spatial=False):
376
412
Validated slicer object that will slice image's `dataobj`
377
413
without collapsing spatial dimensions
378
414
"""
379
- slicer = canonical_slicers (slicer , self .img .shape )
415
+ canonical = canonical_slicers (slicer , self .img .shape )
380
416
# We can get away with this because we've checked the image's
381
417
# first three axes are spatial.
382
418
# More general slicers will need to be smarter, here.
383
- spatial_slices = slicer [:3 ]
419
+ spatial_slices = canonical [:3 ]
384
420
for subslicer in spatial_slices :
385
421
if subslicer is None :
386
422
raise IndexError ('New axis not permitted in spatial dimensions' )
387
423
elif isinstance (subslicer , int ):
388
424
raise IndexError (
389
425
'Scalar indices disallowed in spatial dimensions; Use `[x]` or `x:x+1`.'
390
426
)
391
- return spatial_slices if return_spatial else slicer
427
+ return spatial_slices if return_spatial else canonical
392
428
393
- def slice_affine (self , slicer ) :
429
+ def slice_affine (self , slicer : tuple [ slice , ...]) -> np . ndarray :
394
430
"""Retrieve affine for current image, if sliced by a given index
395
431
396
432
Applies scaling if down-sampling is applied, and adjusts the intercept
@@ -430,10 +466,19 @@ def slice_affine(self, slicer):
430
466
class SpatialImage (DataobjImage ):
431
467
"""Template class for volumetric (3D/4D) images"""
432
468
433
- header_class : Type [SpatialHeader ] = SpatialHeader
434
- ImageSlicer = SpatialFirstSlicer
469
+ header_class : type [SpatialHeader ] = SpatialHeader
470
+ ImageSlicer : type [SpatialFirstSlicer ] = SpatialFirstSlicer
471
+
472
+ _header : SpatialHeader
435
473
436
- def __init__ (self , dataobj , affine , header = None , extra = None , file_map = None ):
474
+ def __init__ (
475
+ self ,
476
+ dataobj : ArrayLike ,
477
+ affine : np .ndarray ,
478
+ header : FileBasedHeader | ty .Mapping | None = None ,
479
+ extra : ty .Mapping | None = None ,
480
+ file_map : FileMap | None = None ,
481
+ ):
437
482
"""Initialize image
438
483
439
484
The image is a combination of (array-like, affine matrix, header), with
@@ -483,7 +528,7 @@ def __init__(self, dataobj, affine, header=None, extra=None, file_map=None):
483
528
def affine (self ):
484
529
return self ._affine
485
530
486
- def update_header (self ):
531
+ def update_header (self ) -> None :
487
532
"""Harmonize header with image data and affine
488
533
489
534
>>> data = np.zeros((2,3,4))
@@ -512,7 +557,7 @@ def update_header(self):
512
557
return
513
558
self ._affine2header ()
514
559
515
- def _affine2header (self ):
560
+ def _affine2header (self ) -> None :
516
561
"""Unconditionally set affine into the header"""
517
562
RZS = self ._affine [:3 , :3 ]
518
563
vox = np .sqrt (np .sum (RZS * RZS , axis = 0 ))
@@ -522,7 +567,7 @@ def _affine2header(self):
522
567
zooms [:n_to_set ] = vox [:n_to_set ]
523
568
hdr .set_zooms (zooms )
524
569
525
- def __str__ (self ):
570
+ def __str__ (self ) -> str :
526
571
shape = self .shape
527
572
affine = self .affine
528
573
return f"""
@@ -534,14 +579,14 @@ def __str__(self):
534
579
{ self ._header }
535
580
"""
536
581
537
- def get_data_dtype (self ):
582
+ def get_data_dtype (self ) -> np . dtype :
538
583
return self ._header .get_data_dtype ()
539
584
540
- def set_data_dtype (self , dtype ) :
585
+ def set_data_dtype (self , dtype : npt . DTypeLike ) -> None :
541
586
self ._header .set_data_dtype (dtype )
542
587
543
588
@classmethod
544
- def from_image (klass , img ) :
589
+ def from_image (klass : type [ SpatialImgT ] , img : SpatialImage | FileBasedImage ) -> SpatialImgT :
545
590
"""Class method to create new instance of own class from `img`
546
591
547
592
Parameters
@@ -555,15 +600,17 @@ def from_image(klass, img):
555
600
cimg : ``spatialimage`` instance
556
601
Image, of our own class
557
602
"""
558
- return klass (
559
- img .dataobj ,
560
- img .affine ,
561
- klass .header_class .from_header (img .header ),
562
- extra = img .extra .copy (),
563
- )
603
+ if isinstance (img , SpatialImage ):
604
+ return klass (
605
+ img .dataobj ,
606
+ img .affine ,
607
+ klass .header_class .from_header (img .header ),
608
+ extra = img .extra .copy (),
609
+ )
610
+ return super ().from_image (img )
564
611
565
612
@property
566
- def slicer (self ) :
613
+ def slicer (self : SpatialImgT ) -> SpatialFirstSlicer [ SpatialImgT ] :
567
614
"""Slicer object that returns cropped and subsampled images
568
615
569
616
The image is resliced in the current orientation; no rotation or
@@ -582,7 +629,7 @@ def slicer(self):
582
629
"""
583
630
return self .ImageSlicer (self )
584
631
585
- def __getitem__ (self , idx ) :
632
+ def __getitem__ (self , idx : object ) -> None :
586
633
"""No slicing or dictionary interface for images
587
634
588
635
Use the slicer attribute to perform cropping and subsampling at your
@@ -595,7 +642,7 @@ def __getitem__(self, idx):
595
642
'`img.get_fdata()[slice]`'
596
643
)
597
644
598
- def orthoview (self ):
645
+ def orthoview (self ) -> OrthoSlicer3D :
599
646
"""Plot the image using OrthoSlicer3D
600
647
601
648
Returns
@@ -611,7 +658,7 @@ def orthoview(self):
611
658
"""
612
659
return OrthoSlicer3D (self .dataobj , self .affine , title = self .get_filename ())
613
660
614
- def as_reoriented (self , ornt ) :
661
+ def as_reoriented (self : SpatialImgT , ornt : Sequence [ Sequence [ int ]]) -> SpatialImgT :
615
662
"""Apply an orientation change and return a new image
616
663
617
664
If ornt is identity transform, return the original image, unchanged
0 commit comments