Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions Lib/compression/zstd/_zstdfile.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import io
from os import PathLike
from _zstd import (ZstdCompressor, ZstdDecompressor, ZstdError,
ZSTD_DStreamOutSize)
from _zstd import ZstdCompressor, ZstdDecompressor, ZSTD_DStreamOutSize
from compression._common import _streams

__all__ = ('ZstdFile', 'open')
Expand Down
142 changes: 98 additions & 44 deletions Lib/test/test_zstd.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@

SUPPORT_MULTITHREADING = False

C_INT_MIN = -(2**31)
C_INT_MAX = (2**31) - 1


def setUpModule():
global SUPPORT_MULTITHREADING
SUPPORT_MULTITHREADING = CompressionParameter.nb_workers.bounds() != (0, 0)
Expand Down Expand Up @@ -195,14 +199,21 @@ def test_simple_compress_bad_args(self):
self.assertRaises(TypeError, ZstdCompressor, zstd_dict=b"abcd1234")
self.assertRaises(TypeError, ZstdCompressor, zstd_dict={1: 2, 3: 4})

with self.assertRaises(ValueError):
ZstdCompressor(2**31)
with self.assertRaises(ValueError):
ZstdCompressor(options={2**31: 100})
# valid range for compression level is [-(1<<17), 22]
msg = '{} not in valid range -131072 <= compression level <= 22'
with self.assertRaisesRegex(ValueError, msg.format(C_INT_MAX)):
ZstdCompressor(C_INT_MAX)
with self.assertRaisesRegex(ValueError, msg.format(C_INT_MIN)):
ZstdCompressor(C_INT_MIN)
msg = 'compression level not in valid range -131072 <= level <= 22'
with self.assertRaisesRegex(ValueError, msg):
ZstdCompressor(level=-(2**1000))
with self.assertRaisesRegex(ValueError, msg):
ZstdCompressor(level=(2**1000))

with self.assertRaises(ZstdError):
with self.assertRaises(ValueError):
ZstdCompressor(options={CompressionParameter.window_log: 100})
with self.assertRaises(ZstdError):
with self.assertRaises(ValueError):
ZstdCompressor(options={3333: 100})

# Method bad arguments
Expand Down Expand Up @@ -253,18 +264,32 @@ def test_compress_parameters(self):
}
ZstdCompressor(options=d)

# larger than signed int, ValueError
d1 = d.copy()
d1[CompressionParameter.ldm_bucket_size_log] = 2**31
self.assertRaises(ValueError, ZstdCompressor, options=d1)
# larger than signed int
d1[CompressionParameter.ldm_bucket_size_log] = C_INT_MAX
with self.assertRaises(ValueError):
ZstdCompressor(options=d1)
# smaller than signed int
d1[CompressionParameter.ldm_bucket_size_log] = C_INT_MIN
with self.assertRaises(ValueError):
ZstdCompressor(options=d1)

# clamp compressionLevel
# out of bounds compression level
level_min, level_max = CompressionParameter.compression_level.bounds()
compress(b'', level_max+1)
compress(b'', level_min-1)

compress(b'', options={CompressionParameter.compression_level:level_max+1})
compress(b'', options={CompressionParameter.compression_level:level_min-1})
with self.assertRaises(ValueError):
compress(b'', level_max+1)
with self.assertRaises(ValueError):
compress(b'', level_min-1)
with self.assertRaises(ValueError):
compress(b'', 2**1000)
with self.assertRaises(ValueError):
compress(b'', -(2**1000))
with self.assertRaises(ValueError):
compress(b'', options={
CompressionParameter.compression_level: level_max+1})
with self.assertRaises(ValueError):
compress(b'', options={
CompressionParameter.compression_level: level_min-1})

# zstd lib doesn't support MT compression
if not SUPPORT_MULTITHREADING:
Expand All @@ -277,19 +302,19 @@ def test_compress_parameters(self):

# out of bounds error msg
option = {CompressionParameter.window_log:100}
with self.assertRaisesRegex(ZstdError,
(r'Error when setting zstd compression parameter "window_log", '
r'it should \d+ <= value <= \d+, provided value is 100\. '
r'\((?:32|64)-bit build\)')):
with self.assertRaisesRegex(
ValueError,
r"100 not in valid range \d+ <= value <= \d+ for compression "
r"parameter 'window_log'",
):
compress(b'', options=option)

def test_unknown_compression_parameter(self):
KEY = 100001234
option = {CompressionParameter.compression_level: 10,
KEY: 200000000}
pattern = (r'Invalid zstd compression parameter.*?'
fr'"unknown parameter \(key {KEY}\)"')
with self.assertRaisesRegex(ZstdError, pattern):
pattern = rf"invalid compression parameter 'unknown parameter \(key {KEY}\)'"
with self.assertRaisesRegex(ValueError, pattern):
ZstdCompressor(options=option)

@unittest.skipIf(not SUPPORT_MULTITHREADING,
Expand Down Expand Up @@ -384,12 +409,22 @@ def test_simple_decompress_bad_args(self):
self.assertRaises(TypeError, ZstdDecompressor, options=b'abc')

with self.assertRaises(ValueError):
ZstdDecompressor(options={2**31 : 100})
ZstdDecompressor(options={C_INT_MAX: 100})
with self.assertRaises(ValueError):
ZstdDecompressor(options={C_INT_MIN: 100})
with self.assertRaises(ValueError):
ZstdDecompressor(options={0: C_INT_MAX})
with self.assertRaises(OverflowError):
ZstdDecompressor(options={2**1000: 100})
with self.assertRaises(OverflowError):
ZstdDecompressor(options={-(2**1000): 100})
with self.assertRaises(OverflowError):
ZstdDecompressor(options={0: -(2**1000)})

with self.assertRaises(ZstdError):
ZstdDecompressor(options={DecompressionParameter.window_log_max:100})
with self.assertRaises(ZstdError):
ZstdDecompressor(options={3333 : 100})
with self.assertRaises(ValueError):
ZstdDecompressor(options={DecompressionParameter.window_log_max: 100})
with self.assertRaises(ValueError):
ZstdDecompressor(options={3333: 100})

empty = compress(b'')
lzd = ZstdDecompressor()
Expand All @@ -402,26 +437,45 @@ def test_decompress_parameters(self):
d = {DecompressionParameter.window_log_max : 15}
ZstdDecompressor(options=d)

# larger than signed int, ValueError
d1 = d.copy()
d1[DecompressionParameter.window_log_max] = 2**31
self.assertRaises(ValueError, ZstdDecompressor, None, d1)
# larger than signed int
d1[DecompressionParameter.window_log_max] = C_INT_MAX
with self.assertRaises(ValueError):
ZstdDecompressor(None, d1)
# smaller than signed int
d1[DecompressionParameter.window_log_max] = C_INT_MIN
with self.assertRaises(ValueError):
ZstdDecompressor(None, d1)

# out of bounds error msg
options = {DecompressionParameter.window_log_max:100}
with self.assertRaisesRegex(ZstdError,
(r'Error when setting zstd decompression parameter "window_log_max", '
r'it should \d+ <= value <= \d+, provided value is 100\. '
r'\((?:32|64)-bit build\)')):
with self.assertRaisesRegex(
ValueError,
r"100 not in valid range \d+ <= value <= \d+ for decompression "
r"parameter 'window_log_max'",
):
decompress(b'', options=options)

# out of bounds deecompression parameter
options[DecompressionParameter.window_log_max] = C_INT_MAX
with self.assertRaises(ValueError):
decompress(b'', options=options)
options[DecompressionParameter.window_log_max] = C_INT_MIN
with self.assertRaises(ValueError):
decompress(b'', options=options)
options[DecompressionParameter.window_log_max] = 2**1000
with self.assertRaises(OverflowError):
decompress(b'', options=options)
options[DecompressionParameter.window_log_max] = -(2**1000)
with self.assertRaises(OverflowError):
decompress(b'', options=options)

def test_unknown_decompression_parameter(self):
KEY = 100001234
options = {DecompressionParameter.window_log_max: DecompressionParameter.window_log_max.bounds()[1],
KEY: 200000000}
pattern = (r'Invalid zstd decompression parameter.*?'
fr'"unknown parameter \(key {KEY}\)"')
with self.assertRaisesRegex(ZstdError, pattern):
pattern = rf"invalid decompression parameter 'unknown parameter \(key {KEY}\)'"
with self.assertRaisesRegex(ValueError, pattern):
ZstdDecompressor(options=options)

def test_decompress_epilogue_flags(self):
Expand Down Expand Up @@ -1424,11 +1478,11 @@ def test_init_bad_mode(self):
ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "rw")

with self.assertRaisesRegex(TypeError,
r"NOT be a CompressionParameter"):
r"not be a CompressionParameter"):
ZstdFile(io.BytesIO(), 'rb',
options={CompressionParameter.compression_level:5})
with self.assertRaisesRegex(TypeError,
r"NOT be a DecompressionParameter"):
r"not be a DecompressionParameter"):
ZstdFile(io.BytesIO(), 'wb',
options={DecompressionParameter.window_log_max:21})

Expand All @@ -1439,19 +1493,19 @@ def test_init_bad_check(self):
with self.assertRaises(TypeError):
ZstdFile(io.BytesIO(), "w", level='asd')
# CHECK_UNKNOWN and anything above CHECK_ID_MAX should be invalid.
with self.assertRaises(ZstdError):
with self.assertRaises(ValueError):
ZstdFile(io.BytesIO(), "w", options={999:9999})
with self.assertRaises(ZstdError):
with self.assertRaises(ValueError):
ZstdFile(io.BytesIO(), "w", options={CompressionParameter.window_log:99})

with self.assertRaises(TypeError):
ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "r", options=33)

with self.assertRaises(ValueError):
with self.assertRaises(OverflowError):
ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB),
options={DecompressionParameter.window_log_max:2**31})

with self.assertRaises(ZstdError):
with self.assertRaises(ValueError):
ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB),
options={444:333})

Expand All @@ -1467,7 +1521,7 @@ def test_init_close_fp(self):
tmp_f.write(DAT_130K_C)
filename = tmp_f.name

with self.assertRaises(ValueError):
with self.assertRaises(TypeError):
ZstdFile(filename, options={'a':'b'})

# for PyPy
Expand Down
24 changes: 8 additions & 16 deletions Modules/_zstd/_zstdmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -103,16 +103,13 @@ static const ParameterInfo dp_list[] = {
};

void
set_parameter_error(const _zstd_state* const state, int is_compress,
int key_v, int value_v)
set_parameter_error(int is_compress, int key_v, int value_v)
{
ParameterInfo const *list;
int list_size;
char const *name;
char *type;
ZSTD_bounds bounds;
int i;
char pos_msg[128];
char pos_msg[64];

if (is_compress) {
list = cp_list;
Expand All @@ -126,8 +123,8 @@ set_parameter_error(const _zstd_state* const state, int is_compress,
}

/* Find parameter's name */
name = NULL;
for (i = 0; i < list_size; i++) {
char const *name = NULL;
for (int i = 0; i < list_size; i++) {
if (key_v == (list+i)->parameter) {
name = (list+i)->parameter_name;
break;
Expand All @@ -149,20 +146,15 @@ set_parameter_error(const _zstd_state* const state, int is_compress,
bounds = ZSTD_dParam_getBounds(key_v);
}
if (ZSTD_isError(bounds.error)) {
PyErr_Format(state->ZstdError,
"Invalid zstd %s parameter \"%s\".",
PyErr_Format(PyExc_ValueError, "invalid %s parameter '%s'",
type, name);
return;
}

/* Error message */
PyErr_Format(state->ZstdError,
"Error when setting zstd %s parameter \"%s\", it "
"should %d <= value <= %d, provided value is %d. "
"(%d-bit build)",
type, name,
bounds.lowerBound, bounds.upperBound, value_v,
8*(int)sizeof(Py_ssize_t));
PyErr_Format(PyExc_ValueError,
"%d not in valid range %d <= value <= %d for %s parameter '%s'",
value_v, bounds.lowerBound, bounds.upperBound, type, name);
}

static inline _zstd_state*
Expand Down
3 changes: 1 addition & 2 deletions Modules/_zstd/_zstdmodule.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ set_zstd_error(const _zstd_state* const state,
const error_type type, size_t zstd_ret);

extern void
set_parameter_error(const _zstd_state* const state, int is_compress,
int key_v, int value_v);
set_parameter_error(int is_compress, int key_v, int value_v);

#endif // !ZSTD_MODULE_H
Loading
Loading