Skip to content

Commit 9de0beb

Browse files
authored
Merge pull request #1013 from jeromedockes/better_err_msg_for_files_with_wrong_gz_extension
ENH: Provide clear error message when files with zip extensions don't match file contents
2 parents 5c4f39c + 1c709bb commit 9de0beb

File tree

2 files changed

+90
-1
lines changed

2 files changed

+90
-1
lines changed

nibabel/loadsave.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,52 @@
2020
from .deprecated import deprecate_with_version
2121

2222

23+
def _signature_matches_extension(filename, sniff):
    """Tell whether a file's leading bytes agree with its compression suffix.

    Only the ``.gz`` and ``.bz2`` suffixes (compared case-insensitively) are
    checked; any other suffix is reported as a match without touching the
    file.

    Parameters
    ----------
    filename : str or os.PathLike
        Path of the file whose suffix and signature are compared.
    sniff : bytes or None
        Bytes already read from the start of the file. When it is not `None`
        and at least as long as the expected signature it is used as is;
        otherwise the signature is read from `filename`.

    Returns
    -------
    matches : bool
        `True` when the suffix is not one of the checked compression
        suffixes, or when the leading bytes start with the signature
        expected for that suffix; `False` when the file could not be read
        or the signature differs.
    error_message : str
        Explanation of the failure when `matches` is `False`; the empty
        string otherwise.
    """
    # Compression suffix -> (magic number, human-readable format name).
    known_formats = {
        ".gz": (b"\x1f\x8b", "gzip"),
        ".bz2": (b"BZh", "bzip2"),
    }
    filename = _stringify_path(filename)
    # Only the last component returned by splitext_addext is of interest.
    *_ignored, suffix = splitext_addext(filename)
    entry = known_formats.get(suffix.lower())
    if entry is None:
        return True, ""
    magic, format_name = entry
    sniff_is_usable = sniff is not None and len(sniff) >= len(magic)
    if not sniff_is_usable:
        # The caller gave us nothing (or too little): read the bytes directly.
        try:
            with open(filename, "rb") as stream:
                sniff = stream.read(len(magic))
        except OSError:
            return False, f"Could not read file: {filename}"
    if not sniff.startswith(magic):
        return False, f"File {filename} is not a {format_name} file"
    return True, ""
67+
68+
2369
def load(filename, **kwargs):
2470
r""" Load file given filename, guessing at file type
2571
@@ -52,6 +98,10 @@ def load(filename, **kwargs):
5298
img = image_klass.from_filename(filename, **kwargs)
5399
return img
54100

101+
matches, msg = _signature_matches_extension(filename, sniff)
102+
if not matches:
103+
raise ImageFileError(msg)
104+
55105
raise ImageFileError(f'Cannot work out file type of "{filename}"')
56106

57107

nibabel/tests/test_loadsave.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from .. import (Spm99AnalyzeImage, Spm2AnalyzeImage,
1111
Nifti1Pair, Nifti1Image,
1212
Nifti2Pair, Nifti2Image)
13-
from ..loadsave import load, read_img_data
13+
from ..loadsave import load, read_img_data, _signature_matches_extension
1414
from ..filebasedimages import ImageFileError
1515
from ..tmpdirs import InTemporaryDirectory, TemporaryDirectory
1616

@@ -76,6 +76,45 @@ def test_load_empty_image():
7676
assert str(err.value).startswith('Empty file: ')
7777

7878

79+
@pytest.mark.parametrize("extension", [".gz", ".bz2"])
def test_load_bad_compressed_extension(tmp_path, extension):
    """`load` must raise a clear error when the suffix lies about the content."""
    # The payload starts with neither the gzip nor the bzip2 signature.
    bogus_image = tmp_path / f"img.nii{extension}"
    bogus_image.write_bytes(b"bad")
    expected_error = pytest.raises(ImageFileError, match=".*is not a .* file")
    with expected_error:
        load(bogus_image)
85+
86+
87+
def test_signature_matches_extension(tmp_path):
    """Exercise `_signature_matches_extension` on good, bad and missing files."""
    gz_magic = b"\x1f\x8b"
    matching = tmp_path / "good.gz"
    matching.write_bytes(gz_magic)
    mismatched = tmp_path / "bad.gz"
    mismatched.write_bytes(b"bad")

    # A suffix that is not a checked compression suffix always "matches".
    ok, message = _signature_matches_extension(
        tmp_path / "uncompressed.nii", None)
    assert ok
    assert message == ""

    # A .gz path that does not exist cannot be read.
    ok, message = _signature_matches_extension(tmp_path / "missing.gz", None)
    assert not ok
    assert message.startswith("Could not read")

    # Without a sniff the file is read and the mismatch is detected.
    ok, message = _signature_matches_extension(mismatched, None)
    assert not ok
    assert "is not a" in message

    # Each of these (path, sniff) pairs is expected to be reported as a match:
    # a long-enough sniff is used instead of the file content, while a
    # too-short or missing sniff falls back to reading the file.
    passing_cases = (
        (mismatched, gz_magic + b"abc"),
        (matching, gz_magic + b"abc"),
        (matching, gz_magic[:1]),
        (matching, None),
    )
    for path, sniff in passing_cases:
        ok, message = _signature_matches_extension(path, sniff)
        assert ok
        assert message == ""
116+
117+
79118
def test_read_img_data_nifti():
80119
shape = (2, 3, 4)
81120
data = np.random.normal(size=shape)

0 commit comments

Comments
 (0)