Skip to content

Commit 7cd34ff

Browse files
committed
TYP: Annotate loadsave
1 parent 40e31e8 commit 7cd34ff

File tree

2 files changed

+32
-20
lines changed

2 files changed

+32
-20
lines changed

nibabel/imageclasses.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,13 @@
77
#
88
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
99
"""Define supported image classes and names"""
10+
from __future__ import annotations
11+
1012
from .analyze import AnalyzeImage
1113
from .brikhead import AFNIImage
1214
from .cifti2 import Cifti2Image
15+
from .dataobj_images import DataobjImage
16+
from .filebasedimages import FileBasedImage
1317
from .freesurfer import MGHImage
1418
from .gifti import GiftiImage
1519
from .minc1 import Minc1Image
@@ -21,7 +25,7 @@
2125
from .spm99analyze import Spm99AnalyzeImage
2226

2327
# Ordered by the load/save priority.
24-
all_image_classes = [
28+
all_image_classes: list[type[FileBasedImage]] = [
2529
Nifti1Pair,
2630
Nifti1Image,
2731
Nifti2Pair,
@@ -41,7 +45,7 @@
4145
# Image classes known to require spatial axes to be first in index ordering.
4246
# When adding an image class, consider whether the new class should be listed
4347
# here.
44-
KNOWN_SPATIAL_FIRST = (
48+
KNOWN_SPATIAL_FIRST: tuple[type[FileBasedImage], ...] = (
4549
Nifti1Pair,
4650
Nifti1Image,
4751
Nifti2Pair,
@@ -55,7 +59,7 @@
5559
)
5660

5761

58-
def spatial_axes_first(img):
62+
def spatial_axes_first(img: DataobjImage) -> bool:
5963
"""True if spatial image axes for `img` always precede other axes
6064
6165
Parameters

nibabel/loadsave.py

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@
88
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
99
# module imports
1010
"""Utilities to load and save image objects"""
11+
from __future__ import annotations
12+
1113
import os
14+
import typing as ty
1215

1316
import numpy as np
1417

@@ -22,7 +25,18 @@
2225
_compressed_suffixes = ('.gz', '.bz2', '.zst')
2326

2427

25-
def _signature_matches_extension(filename):
28+
if ty.TYPE_CHECKING: # pragma: no cover
29+
from .filebasedimages import FileBasedImage
30+
from .filename_parser import FileSpec
31+
32+
P = ty.ParamSpec('P')
33+
34+
class Signature(ty.TypedDict):
35+
signature: bytes
36+
format_name: str
37+
38+
39+
def _signature_matches_extension(filename: FileSpec) -> tuple[bool, str]:
2640
"""Check if signature aka magic number matches filename extension.
2741
2842
Parameters
@@ -42,7 +56,7 @@ def _signature_matches_extension(filename):
4256
the empty string otherwise.
4357
4458
"""
45-
signatures = {
59+
signatures: dict[str, Signature] = {
4660
'.gz': {'signature': b'\x1f\x8b', 'format_name': 'gzip'},
4761
'.bz2': {'signature': b'BZh', 'format_name': 'bzip2'},
4862
'.zst': {'signature': b'\x28\xb5\x2f\xfd', 'format_name': 'ztsd'},
@@ -64,7 +78,7 @@ def _signature_matches_extension(filename):
6478
return False, f'File {filename} is not a {format_name} file'
6579

6680

67-
def load(filename, **kwargs):
81+
def load(filename: FileSpec, **kwargs) -> FileBasedImage:
6882
r"""Load file given filename, guessing at file type
6983
7084
Parameters
@@ -126,7 +140,7 @@ def guessed_image_type(filename):
126140
raise ImageFileError(f'Cannot work out file type of "{filename}"')
127141

128142

129-
def save(img, filename, **kwargs):
143+
def save(img: FileBasedImage, filename: FileSpec, **kwargs) -> None:
130144
r"""Save an image to file adapting format to `filename`
131145
132146
Parameters
@@ -161,19 +175,17 @@ def save(img, filename, **kwargs):
161175
from .nifti1 import Nifti1Image, Nifti1Pair
162176
from .nifti2 import Nifti2Image, Nifti2Pair
163177

164-
klass = None
165-
converted = None
166-
178+
converted: FileBasedImage
167179
if type(img) == Nifti1Image and lext in ('.img', '.hdr'):
168-
klass = Nifti1Pair
180+
converted = Nifti1Pair.from_image(img)
169181
elif type(img) == Nifti2Image and lext in ('.img', '.hdr'):
170-
klass = Nifti2Pair
182+
converted = Nifti2Pair.from_image(img)
171183
elif type(img) == Nifti1Pair and lext == '.nii':
172-
klass = Nifti1Image
184+
converted = Nifti1Image.from_image(img)
173185
elif type(img) == Nifti2Pair and lext == '.nii':
174-
klass = Nifti2Image
186+
converted = Nifti2Image.from_image(img)
175187
else: # arbitrary conversion
176-
valid_klasses = [klass for klass in all_image_classes if ext in klass.valid_exts]
188+
valid_klasses = [klass for klass in all_image_classes if lext in klass.valid_exts]
177189
if not valid_klasses: # if list is empty
178190
raise ImageFileError(f'Cannot work out file type of "{filename}"')
179191

@@ -186,13 +198,9 @@ def save(img, filename, **kwargs):
186198
break
187199
except Exception as e:
188200
err = e
189-
# ... and if none of them work, raise an error.
190-
if converted is None:
201+
else:
191202
raise err
192203

193-
# Here, we either have a klass or a converted image.
194-
if converted is None:
195-
converted = klass.from_image(img)
196204
converted.to_filename(filename, **kwargs)
197205

198206

0 commit comments

Comments
 (0)