Skip to content

Commit 266feb2

Browse files
committed
avoid extra read from image file
1 parent 81655fb commit 266feb2

File tree

2 files changed

+47
-15
lines changed

2 files changed

+47
-15
lines changed

nibabel/loadsave.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,18 @@
2020
from .deprecated import deprecate_with_version
2121

2222

23-
def _signature_matches_extension(filename):
23+
def _signature_matches_extension(filename, sniff):
2424
"""Check if signature aka magic number matches filename extension.
2525
2626
Parameters
2727
----------
2828
filename : str or os.PathLike
2929
Path to the file to check
3030
31+
sniff : bytes or None
32+
First bytes of the file. If not `None` and long enough to contain the
33+
signature, avoids having to read the start of the file.
34+
3135
Returns
3236
-------
3337
matches : bool
@@ -50,11 +54,14 @@ def _signature_matches_extension(filename):
5054
if ext not in signatures:
5155
return True, ""
5256
expected_signature = signatures[ext]["signature"]
53-
try:
54-
with open(filename, "rb") as fh:
55-
found_signature = fh.read(len(expected_signature))
56-
except OSError:
57-
return False, f"Could not read file: {filename}"
57+
if sniff is not None and len(sniff) >= len(expected_signature):
58+
found_signature = sniff[:len(expected_signature)]
59+
else:
60+
try:
61+
with open(filename, "rb") as fh:
62+
found_signature = fh.read(len(expected_signature))
63+
except OSError:
64+
return False, f"Could not read file: {filename}"
5865
if found_signature == expected_signature:
5966
return True, ""
6067
format_name = signatures[ext]["format_name"]
@@ -85,9 +92,6 @@ def load(filename, **kwargs):
8592
raise FileNotFoundError(f"No such file or no access: '{filename}'")
8693
if stat_result.st_size <= 0:
8794
raise ImageFileError(f"Empty file: '{filename}'")
88-
matches, msg = _signature_matches_extension(filename)
89-
if not matches:
90-
raise ImageFileError(msg)
9195

9296
sniff = None
9397
for image_klass in all_image_classes:
@@ -96,6 +100,10 @@ def load(filename, **kwargs):
96100
img = image_klass.from_filename(filename, **kwargs)
97101
return img
98102

103+
matches, msg = _signature_matches_extension(filename, sniff)
104+
if not matches:
105+
raise ImageFileError(msg)
106+
99107
raise ImageFileError(f'Cannot work out file type of "{filename}"')
100108

101109

nibabel/tests/test_loadsave.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from .. import (Spm99AnalyzeImage, Spm2AnalyzeImage,
1111
Nifti1Pair, Nifti1Image,
1212
Nifti2Pair, Nifti2Image)
13-
from ..loadsave import load, read_img_data
13+
from ..loadsave import load, read_img_data, _signature_matches_extension
1414
from ..filebasedimages import ImageFileError
1515
from ..tmpdirs import InTemporaryDirectory, TemporaryDirectory
1616

@@ -84,11 +84,35 @@ def test_load_bad_compressed_extension(tmp_path, extension):
8484
load(file_path)
8585

8686

87-
def test_load_file_that_cannot_be_read(tmp_path):
88-
subdir = tmp_path / "img.nii.gz"
89-
subdir.mkdir()
90-
with pytest.raises(ImageFileError, match="Could not read"):
91-
load(subdir)
87+
def test_signature_matches_extension(tmp_path):
88+
gz_signature = b"\x1f\x8b"
89+
good_file = tmp_path / "good.gz"
90+
good_file.write_bytes(gz_signature)
91+
bad_file = tmp_path / "bad.gz"
92+
bad_file.write_bytes(b"bad")
93+
matches, msg = _signature_matches_extension(
94+
tmp_path / "uncompressed.nii", None)
95+
assert matches
96+
assert msg == ""
97+
matches, msg = _signature_matches_extension(tmp_path / "missing.gz", None)
98+
assert not matches
99+
assert msg.startswith("Could not read")
100+
matches, msg = _signature_matches_extension(bad_file, None)
101+
assert not matches
102+
assert "is not a" in msg
103+
matches, msg = _signature_matches_extension(bad_file, gz_signature + b"abc")
104+
assert matches
105+
assert msg == ""
106+
matches, msg = _signature_matches_extension(
107+
good_file, gz_signature + b"abc")
108+
assert matches
109+
assert msg == ""
110+
matches, msg = _signature_matches_extension(good_file, gz_signature[:1])
111+
assert matches
112+
assert msg == ""
113+
matches, msg = _signature_matches_extension(good_file, None)
114+
assert matches
115+
assert msg == ""
92116

93117

94118
def test_read_img_data_nifti():

0 commit comments

Comments
 (0)