Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 48 additions & 21 deletions Lib/test/test_zstd.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,14 +196,31 @@ def test_simple_compress_bad_args(self):
self.assertRaises(TypeError, ZstdCompressor, zstd_dict={1: 2, 3: 4})

# valid compression level range is [-(1<<17), 22]
with self.assertRaises(ValueError):
with self.assertRaises(ValueError) as cm:
ZstdCompressor(23)
with self.assertRaises(ValueError):
self.assertEqual(
str(cm.exception),
'23 not in valid range -131072 <= compression level <= 22.',
)
with self.assertRaises(ValueError) as cm:
ZstdCompressor(-(1<<17)-1)
with self.assertRaises(ValueError):
self.assertEqual(-(1<<17)-1, -131073)
self.assertEqual(
str(cm.exception),
'-131073 not in valid range -131072 <= compression level <= 22.',
)
with self.assertRaises(ValueError) as cm:
ZstdCompressor(2**31)
with self.assertRaises(ValueError):
self.assertEqual(
str(cm.exception),
'compression level not in valid range -131072 <= level <= 22.',
)
with self.assertRaises(ValueError) as cm:
ZstdCompressor(level=-(2**1000))
self.assertEqual(
str(cm.exception),
'compression level not in valid range -131072 <= level <= 22.',
)
with self.assertRaises(ValueError):
ZstdCompressor(level=(2**1000))

Expand Down Expand Up @@ -260,10 +277,15 @@ def test_compress_parameters(self):
}
ZstdCompressor(options=d)

# larger than signed int, ValueError
d1 = d.copy()
# larger than signed int
d1[CompressionParameter.ldm_bucket_size_log] = 2**31
self.assertRaises(ValueError, ZstdCompressor, options=d1)
with self.assertRaises(OverflowError):
ZstdCompressor(options=d1)
# smaller than signed int
d1[CompressionParameter.ldm_bucket_size_log] = -(2**31)-1
with self.assertRaises(OverflowError):
ZstdCompressor(options=d1)

# out of bounds compression level
level_min, level_max = CompressionParameter.compression_level.bounds()
Expand Down Expand Up @@ -399,17 +421,17 @@ def test_simple_decompress_bad_args(self):
self.assertRaises(TypeError, ZstdDecompressor, options='abc')
self.assertRaises(TypeError, ZstdDecompressor, options=b'abc')

with self.assertRaises(ValueError):
with self.assertRaises(OverflowError):
ZstdDecompressor(options={2**31: 100})
with self.assertRaises(ValueError):
with self.assertRaises(OverflowError):
ZstdDecompressor(options={2**1000: 100})
with self.assertRaises(ValueError):
with self.assertRaises(OverflowError):
ZstdDecompressor(options={-(2**31)-1: 100})
with self.assertRaises(ValueError):
with self.assertRaises(OverflowError):
ZstdDecompressor(options={-(2**1000): 100})
with self.assertRaises(ValueError):
ZstdDecompressor(options={0: 2**32})
with self.assertRaises(ValueError):
with self.assertRaises(OverflowError):
ZstdDecompressor(options={0: 2**31})
with self.assertRaises(OverflowError):
ZstdDecompressor(options={0: -(2**1000)})

with self.assertRaises(ZstdError):
Expand All @@ -428,10 +450,15 @@ def test_decompress_parameters(self):
d = {DecompressionParameter.window_log_max : 15}
ZstdDecompressor(options=d)

# larger than signed int, ValueError
d1 = d.copy()
# larger than signed int
d1[DecompressionParameter.window_log_max] = 2**31
self.assertRaises(ValueError, ZstdDecompressor, None, d1)
with self.assertRaises(OverflowError):
ZstdDecompressor(None, d1)
# smaller than signed int
d1[DecompressionParameter.window_log_max] = -(2**31)-1
with self.assertRaises(OverflowError):
ZstdDecompressor(None, d1)

# out of bounds error msg
options = {DecompressionParameter.window_log_max:100}
Expand All @@ -443,16 +470,16 @@ def test_decompress_parameters(self):

# out of bounds deecompression parameter
options[DecompressionParameter.window_log_max] = 2**31
with self.assertRaises(ValueError):
with self.assertRaises(OverflowError):
decompress(b'', options=options)
options[DecompressionParameter.window_log_max] = -(2**32)-1
with self.assertRaises(ValueError):
options[DecompressionParameter.window_log_max] = -(2**31)-1
with self.assertRaises(OverflowError):
decompress(b'', options=options)
options[DecompressionParameter.window_log_max] = 2**1000
with self.assertRaises(ValueError):
with self.assertRaises(OverflowError):
decompress(b'', options=options)
options[DecompressionParameter.window_log_max] = -(2**1000)
with self.assertRaises(ValueError):
with self.assertRaises(OverflowError):
decompress(b'', options=options)

def test_unknown_decompression_parameter(self):
Expand Down Expand Up @@ -1487,7 +1514,7 @@ def test_init_bad_check(self):
with self.assertRaises(TypeError):
ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "r", options=33)

with self.assertRaises(ValueError):
with self.assertRaises(OverflowError):
ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB),
options={DecompressionParameter.window_log_max:2**31})

Expand Down
27 changes: 12 additions & 15 deletions Modules/_zstd/compressor.c
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,14 @@ typedef struct {
#include "clinic/compressor.c.h"

static int
_zstd_set_c_level(ZstdCompressor *self, const int level)
_zstd_set_c_level(ZstdCompressor *self, int level)
{
/* Set integer compression level */
int min_level = ZSTD_minCLevel();
int max_level = ZSTD_maxCLevel();
if (level < min_level || level > max_level) {
PyErr_Format(PyExc_ValueError,
"%zd not in valid range %d <= compression level <= %d.",
"%d not in valid range %d <= compression level <= %d.",
level, min_level, max_level);
return -1;
}
Expand Down Expand Up @@ -90,7 +90,8 @@ _zstd_set_c_parameters(ZstdCompressor *self, PyObject *options)

if (!PyDict_Check(options)) {
PyErr_Format(PyExc_TypeError,
"invalid type for options, expected dict");
"ZstdCompressor() argument 'options' must be dict, not %T",
options);
return -1;
}

Expand All @@ -106,22 +107,16 @@ _zstd_set_c_parameters(ZstdCompressor *self, PyObject *options)
}

Py_INCREF(key);
Py_INCREF(value);
int key_v = PyLong_AsInt(key);
Py_DECREF(key);
if (key_v == -1 && PyErr_Occurred()) {
if (PyErr_ExceptionMatches(PyExc_OverflowError)) {
PyErr_SetString(PyExc_ValueError,
"dictionary key must be less than 2**31");
}
return -1;
}

Py_INCREF(value);
int value_v = PyLong_AsInt(value);
Py_DECREF(value);
if (value_v == -1 && PyErr_Occurred()) {
if (PyErr_ExceptionMatches(PyExc_OverflowError)) {
PyErr_SetString(PyExc_ValueError,
"dictionary value must be less than 2**31");
}
return -1;
}

Expand All @@ -145,6 +140,8 @@ _zstd_set_c_parameters(ZstdCompressor *self, PyObject *options)

/* Set parameter to compression context */
size_t zstd_ret = ZSTD_CCtx_setParameter(self->cctx, key_v, value_v);

/* Check error */
if (ZSTD_isError(zstd_ret)) {
set_parameter_error(mod_state, 1, key_v, value_v);
return -1;
Expand Down Expand Up @@ -371,7 +368,7 @@ _zstd_ZstdCompressor_new_impl(PyTypeObject *type, PyObject *level,
self->last_mode = ZSTD_e_end;

if (level != Py_None && options != Py_None) {
PyErr_SetString(PyExc_RuntimeError,
PyErr_SetString(PyExc_TypeError,
"Only one of level or options should be used.");
goto error;
}
Expand All @@ -387,8 +384,8 @@ _zstd_ZstdCompressor_new_impl(PyTypeObject *type, PyObject *level,
if (level_v == -1 && PyErr_Occurred()) {
if (PyErr_ExceptionMatches(PyExc_OverflowError)) {
PyErr_Format(PyExc_ValueError,
"%zd not in valid range %d <= compression level <= %d.",
level, ZSTD_minCLevel(), ZSTD_maxCLevel());
"compression level not in valid range %d <= level <= %d.",
ZSTD_minCLevel(), ZSTD_maxCLevel());
}
goto error;
}
Expand Down
17 changes: 6 additions & 11 deletions Modules/_zstd/decompressor.c
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,9 @@ _zstd_set_d_parameters(ZstdDecompressor *self, PyObject *options)
}

if (!PyDict_Check(options)) {
PyErr_SetString(PyExc_TypeError,
"invalid type for options, expected dict");
PyErr_Format(PyExc_TypeError,
"ZstdDecompressor() argument 'options' must be dict, not %T",
options);
return -1;
}

Expand All @@ -114,22 +115,16 @@ _zstd_set_d_parameters(ZstdDecompressor *self, PyObject *options)
}

Py_INCREF(key);
Py_INCREF(value);
int key_v = PyLong_AsInt(key);
Py_DECREF(key);
if (key_v == -1 && PyErr_Occurred()) {
if (PyErr_ExceptionMatches(PyExc_OverflowError)) {
PyErr_SetString(PyExc_ValueError,
"dictionary key must be less than 2**31");
}
return -1;
}

Py_INCREF(value);
int value_v = PyLong_AsInt(value);
Py_DECREF(value);
if (value_v == -1 && PyErr_Occurred()) {
if (PyErr_ExceptionMatches(PyExc_OverflowError)) {
PyErr_SetString(PyExc_ValueError,
"dictionary value must be less than 2**31");
}
return -1;
}

Expand Down
Loading